llvm-project/mlir/lib/Analysis/Presburger/SlowMPInt.cpp

283 lines
9.3 KiB
C++

//===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/SlowMPInt.h"
#include "llvm/Support/MathExtras.h"
using namespace mlir;
using namespace presburger;
using namespace detail;
/// Constructs a SlowMPInt from a 64-bit integer. The backing APInt is created
/// 64 bits wide, with `val` passed as a signed value.
SlowMPInt::SlowMPInt(int64_t val)
    : val(/*numBits=*/64, val, /*isSigned=*/true) {}
/// Default constructor: delegates to the int64_t constructor with value zero.
SlowMPInt::SlowMPInt() : SlowMPInt(/*val=*/0) {}
/// Constructs a SlowMPInt that stores a copy of the given APInt.
SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {}
/// Assigns a 64-bit integer by building a temporary SlowMPInt from `val` and
/// assigning it to this object. Returns a reference to this object.
SlowMPInt &SlowMPInt::operator=(int64_t val) {
  *this = SlowMPInt(val);
  return *this;
}
/// Converts to int64_t using the sign-extended value of the backing APInt.
/// NOTE(review): presumably the value must fit in 64 bits for this to be
/// meaningful -- confirm against APInt::getSExtValue().
SlowMPInt::operator int64_t() const {
  const int64_t narrowed = val.getSExtValue();
  return narrowed;
}
/// Computes a hash for `x` by hashing its backing APInt.
/// NOTE(review): operator== sign-extends both sides before comparing, so two
/// SlowMPInts of different widths can compare equal; whether they also hash
/// equal depends on how the APInt hash treats bit width -- confirm.
llvm::hash_code detail::hash_value(const SlowMPInt &x) {
  llvm::hash_code code = hash_value(x.val);
  return code;
}
/// ---------------------------------------------------------------------------
/// Printing.
/// ---------------------------------------------------------------------------
/// Writes the backing APInt to `os` using its stream insertion operator.
void SlowMPInt::print(llvm::raw_ostream &os) const {
  os << val;
}
void SlowMPInt::dump() const { print(llvm::errs()); }
/// Stream insertion: forwards to SlowMPInt::print and returns the same stream
/// so that insertions can be chained.
llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os,
                                      const SlowMPInt &x) {
  x.print(os);
  return os;
}
/// ---------------------------------------------------------------------------
/// Convenience operator overloads for int64_t.
/// ---------------------------------------------------------------------------
/// In-place addition of a 64-bit integer: `b` is converted to a SlowMPInt and
/// added to `a`. Returns the reference produced by the SlowMPInt overload.
SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a += rhs;
}
/// In-place subtraction of a 64-bit integer: `b` is converted to a SlowMPInt
/// and subtracted from `a`.
SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a -= rhs;
}
/// In-place multiplication by a 64-bit integer: `b` is converted to a
/// SlowMPInt and `a` is multiplied by it.
SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a *= rhs;
}
/// In-place division by a 64-bit integer: `b` is converted to a SlowMPInt and
/// `a` is divided by it.
SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a /= rhs;
}
/// In-place remainder by a 64-bit integer: `b` is converted to a SlowMPInt and
/// `a` is replaced by `a % b`.
SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a %= rhs;
}
/// Equality against a 64-bit integer, evaluated after converting `b` to a
/// SlowMPInt.
bool detail::operator==(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a == rhs;
}
/// Inequality against a 64-bit integer, evaluated after converting `b` to a
/// SlowMPInt.
bool detail::operator!=(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a != rhs;
}
/// Greater-than against a 64-bit integer, evaluated after converting `b` to a
/// SlowMPInt.
bool detail::operator>(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a > rhs;
}
/// Less-than against a 64-bit integer, evaluated after converting `b` to a
/// SlowMPInt.
bool detail::operator<(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a < rhs;
}
/// Less-than-or-equal against a 64-bit integer, evaluated after converting
/// `b` to a SlowMPInt.
bool detail::operator<=(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a <= rhs;
}
/// Greater-than-or-equal against a 64-bit integer, evaluated after converting
/// `b` to a SlowMPInt.
bool detail::operator>=(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a >= rhs;
}
/// Returns `a + b`, with `b` converted to a SlowMPInt first.
SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a + rhs;
}
/// Returns `a - b`, with `b` converted to a SlowMPInt first.
SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a - rhs;
}
/// Returns `a * b`, with `b` converted to a SlowMPInt first.
SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a * rhs;
}
/// Returns `a / b`, with `b` converted to a SlowMPInt first.
SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a / rhs;
}
/// Returns `a % b`, with `b` converted to a SlowMPInt first.
SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) {
  const SlowMPInt rhs(b);
  return a % rhs;
}
/// Equality with a 64-bit integer on the left, evaluated after converting `a`
/// to a SlowMPInt.
bool detail::operator==(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs == b;
}
/// Inequality with a 64-bit integer on the left, evaluated after converting
/// `a` to a SlowMPInt.
bool detail::operator!=(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs != b;
}
/// Greater-than with a 64-bit integer on the left, evaluated after converting
/// `a` to a SlowMPInt.
bool detail::operator>(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs > b;
}
/// Less-than with a 64-bit integer on the left, evaluated after converting
/// `a` to a SlowMPInt.
bool detail::operator<(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs < b;
}
/// Less-than-or-equal with a 64-bit integer on the left, evaluated after
/// converting `a` to a SlowMPInt.
bool detail::operator<=(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs <= b;
}
/// Greater-than-or-equal with a 64-bit integer on the left, evaluated after
/// converting `a` to a SlowMPInt.
bool detail::operator>=(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs >= b;
}
/// Returns `a + b`, with `a` converted to a SlowMPInt first.
SlowMPInt detail::operator+(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs + b;
}
/// Returns `a - b`, with `a` converted to a SlowMPInt first.
SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs - b;
}
/// Returns `a * b`, with `a` converted to a SlowMPInt first.
SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs * b;
}
/// Returns `a / b`, with `a` converted to a SlowMPInt first.
SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs / b;
}
/// Returns `a % b`, with `a` converted to a SlowMPInt first.
SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) {
  const SlowMPInt lhs(a);
  return lhs % b;
}
/// Returns the larger of the bit widths of `a` and `b`. Callers use this as
/// the common width to sign-extend both operands to before operating on them.
static unsigned getMaxWidth(const APInt &a, const APInt &b) {
  const unsigned widthA = a.getBitWidth();
  const unsigned widthB = b.getBitWidth();
  return widthA < widthB ? widthB : widthA;
}
/// ---------------------------------------------------------------------------
/// Comparison operators.
/// ---------------------------------------------------------------------------
// TODO: consider instead making APInt::compare available and using that.
/// Equality. Both operands are sign-extended to the wider of the two bit
/// widths so that the APInt comparison sees equal-width values.
bool SlowMPInt::operator==(const SlowMPInt &o) const {
  const unsigned commonWidth = getMaxWidth(val, o.val);
  const APInt lhs = val.sext(commonWidth);
  const APInt rhs = o.val.sext(commonWidth);
  return lhs == rhs;
}
/// Inequality. Both operands are sign-extended to the wider of the two bit
/// widths before being compared.
bool SlowMPInt::operator!=(const SlowMPInt &o) const {
  const unsigned commonWidth = getMaxWidth(val, o.val);
  const APInt lhs = val.sext(commonWidth);
  const APInt rhs = o.val.sext(commonWidth);
  return lhs != rhs;
}
/// Signed greater-than. Both operands are sign-extended to the wider of the
/// two bit widths and compared with APInt::sgt.
bool SlowMPInt::operator>(const SlowMPInt &o) const {
  const unsigned commonWidth = getMaxWidth(val, o.val);
  const APInt lhs = val.sext(commonWidth);
  const APInt rhs = o.val.sext(commonWidth);
  return lhs.sgt(rhs);
}
/// Signed less-than. Both operands are sign-extended to the wider of the two
/// bit widths and compared with APInt::slt.
bool SlowMPInt::operator<(const SlowMPInt &o) const {
  const unsigned commonWidth = getMaxWidth(val, o.val);
  const APInt lhs = val.sext(commonWidth);
  const APInt rhs = o.val.sext(commonWidth);
  return lhs.slt(rhs);
}
/// Signed less-than-or-equal. Both operands are sign-extended to the wider of
/// the two bit widths and compared with APInt::sle.
bool SlowMPInt::operator<=(const SlowMPInt &o) const {
  const unsigned commonWidth = getMaxWidth(val, o.val);
  const APInt lhs = val.sext(commonWidth);
  const APInt rhs = o.val.sext(commonWidth);
  return lhs.sle(rhs);
}
/// Signed greater-than-or-equal. Both operands are sign-extended to the wider
/// of the two bit widths and compared with APInt::sge.
bool SlowMPInt::operator>=(const SlowMPInt &o) const {
  const unsigned commonWidth = getMaxWidth(val, o.val);
  const APInt lhs = val.sext(commonWidth);
  const APInt rhs = o.val.sext(commonWidth);
  return lhs.sge(rhs);
}
/// ---------------------------------------------------------------------------
/// Arithmetic operators.
/// ---------------------------------------------------------------------------
/// Bring a and b to have the same width and then call op(a, b, overflow).
/// If the overflow bit becomes set, resize a and b to double the width and
/// call op(a, b, overflow), returning its result. The operation with double
/// widths should not also overflow.
///
/// This helper is private to this file, so it is given internal linkage
/// (like getMaxWidth) instead of being exported as a global symbol.
///
/// \param a, b  operands; each is sign-extended to the common width.
/// \param op    callable computing the result and reporting overflow through
///              its third argument.
/// \returns the result of `op` at the first width that did not overflow.
static APInt runOpWithExpandOnOverflow(
    const APInt &a, const APInt &b,
    llvm::function_ref<APInt(const APInt &, const APInt &, bool &overflow)>
        op) {
  // Initialised defensively; `op` is expected to write it on every call.
  bool overflow = false;
  unsigned width = getMaxWidth(a, b);
  APInt ret = op(a.sext(width), b.sext(width), overflow);
  if (!overflow)
    return ret;

  // First attempt overflowed: retry once with twice as many bits.
  width *= 2;
  ret = op(a.sext(width), b.sext(width), overflow);
  assert(!overflow && "double width should be sufficient to avoid overflow!");
  return ret;
}
/// Addition. Uses APInt::sadd_ov through runOpWithExpandOnOverflow, which
/// retries at double width if the signed addition overflows.
SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const {
  auto addReportingOverflow = [](const APInt &x, const APInt &y,
                                 bool &overflow) {
    return x.sadd_ov(y, overflow);
  };
  return SlowMPInt(
      runOpWithExpandOnOverflow(val, o.val, addReportingOverflow));
}
/// Subtraction. Uses APInt::ssub_ov through runOpWithExpandOnOverflow, which
/// retries at double width if the signed subtraction overflows.
SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const {
  auto subReportingOverflow = [](const APInt &x, const APInt &y,
                                 bool &overflow) {
    return x.ssub_ov(y, overflow);
  };
  return SlowMPInt(
      runOpWithExpandOnOverflow(val, o.val, subReportingOverflow));
}
/// Multiplication. Uses APInt::smul_ov through runOpWithExpandOnOverflow,
/// which retries at double width if the signed multiplication overflows.
SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const {
  auto mulReportingOverflow = [](const APInt &x, const APInt &y,
                                 bool &overflow) {
    return x.smul_ov(y, overflow);
  };
  return SlowMPInt(
      runOpWithExpandOnOverflow(val, o.val, mulReportingOverflow));
}
/// Division. Uses APInt::sdiv_ov through runOpWithExpandOnOverflow, which
/// retries at double width if the signed division overflows.
SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const {
  auto divReportingOverflow = [](const APInt &x, const APInt &y,
                                 bool &overflow) {
    return x.sdiv_ov(y, overflow);
  };
  return SlowMPInt(
      runOpWithExpandOnOverflow(val, o.val, divReportingOverflow));
}
/// Returns the absolute value of `x`: `x` itself when it is non-negative,
/// otherwise its negation.
SlowMPInt detail::abs(const SlowMPInt &x) {
  if (x >= 0)
    return x;
  return -x;
}
/// Returns `lhs / rhs` rounded up (APInt::Rounding::UP).
/// Division by -1 is answered with the negation operator instead.
/// NOTE(review): presumably because unary minus widens on overflow whereas the
/// fixed-width rounding division below would not -- confirm.
SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
  if (rhs == -1)
    return -lhs;

  const unsigned commonWidth = getMaxWidth(lhs.val, rhs.val);
  const APInt dividend = lhs.val.sext(commonWidth);
  const APInt divisor = rhs.val.sext(commonWidth);
  return SlowMPInt(
      llvm::APIntOps::RoundingSDiv(dividend, divisor, APInt::Rounding::UP));
}
/// Returns `lhs / rhs` rounded down (APInt::Rounding::DOWN).
/// Division by -1 is answered with the negation operator instead.
/// NOTE(review): presumably because unary minus widens on overflow whereas the
/// fixed-width rounding division below would not -- confirm.
SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
  if (rhs == -1)
    return -lhs;

  const unsigned commonWidth = getMaxWidth(lhs.val, rhs.val);
  const APInt dividend = lhs.val.sext(commonWidth);
  const APInt divisor = rhs.val.sext(commonWidth);
  return SlowMPInt(
      llvm::APIntOps::RoundingSDiv(dividend, divisor, APInt::Rounding::DOWN));
}
/// Returns `lhs` modulo `rhs`. The RHS is always expected to be positive, and
/// the result is always non-negative.
SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) {
  assert(rhs >= 1 && "mod is only supported for positive divisors!");
  // Compute the signed remainder once; it may be negative, in which case
  // adding the (positive) divisor shifts it into the non-negative range.
  SlowMPInt rem = lhs % rhs;
  if (rem < 0)
    return rem + rhs;
  return rem;
}
/// Returns the greatest common divisor of `a` and `b`, both of which must be
/// non-negative (asserted). The operands are sign-extended to a common width
/// before being handed to APIntOps::GreatestCommonDivisor.
SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) {
  assert(a >= 0 && b >= 0 && "operands must be non-negative!");
  const unsigned commonWidth = getMaxWidth(a.val, b.val);
  APInt divisor = llvm::APIntOps::GreatestCommonDivisor(
      a.val.sext(commonWidth), b.val.sext(commonWidth));
  return SlowMPInt(divisor);
}
/// Returns the least common multiple of 'a' and 'b', computed on their
/// absolute values as (|a| * |b|) / gcd(|a|, |b|). If both operands are zero
/// the result is zero.
SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) {
  SlowMPInt x = abs(a);
  SlowMPInt y = abs(b);
  SlowMPInt divisor = gcd(x, y);
  // The gcd is zero only when both operands are zero; dividing by it would be
  // a division by zero, so return zero directly in that case.
  if (divisor == 0)
    return SlowMPInt(0);
  return (x * y) / divisor;
}
/// Signed remainder (APInt::srem) of this value by `o`, after sign-extending
/// both operands to their common width. The result may be negative; mod()
/// adjusts for that. This operation cannot overflow.
SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const {
  // Use the shared width helper, as the other operators in this file do.
  unsigned width = getMaxWidth(val, o.val);
  return SlowMPInt(val.sext(width).srem(o.val.sext(width)));
}
/// Unary negation. Overflow only occurs when the value is the minimum
/// possible value for its width; in that case the value is first
/// sign-extended to twice its width and then negated.
SlowMPInt SlowMPInt::operator-() const {
  if (!val.isMinSignedValue())
    return SlowMPInt(-val);

  const APInt widened = val.sext(2 * val.getBitWidth());
  return SlowMPInt(-widened);
}
/// ---------------------------------------------------------------------------
/// Assignment operators, preincrement, predecrement.
/// ---------------------------------------------------------------------------
/// Adds `o` to this value in place, implemented as assignment of
/// `*this + o`. Returns a reference to this object.
SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) {
  return *this = *this + o;
}
/// Subtracts `o` from this value in place, implemented as assignment of
/// `*this - o`. Returns a reference to this object.
SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) {
  return *this = *this - o;
}
/// Multiplies this value by `o` in place, implemented as assignment of
/// `*this * o`. Returns a reference to this object.
SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) {
  return *this = *this * o;
}
/// Divides this value by `o` in place, implemented as assignment of
/// `*this / o`. Returns a reference to this object.
SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) {
  return *this = *this / o;
}
/// Replaces this value with its remainder by `o`, implemented as assignment
/// of `*this % o`. Returns a reference to this object.
SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) {
  return *this = *this % o;
}
/// Pre-increment: adds one to this value and returns a reference to it.
SlowMPInt &SlowMPInt::operator++() {
  return *this += 1;
}
/// Pre-decrement: subtracts one from this value and returns a reference to it.
SlowMPInt &SlowMPInt::operator--() {
  return *this -= 1;
}