//===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrites based on the basic rules of algebra
// (commutativity, associativity, etc.) and strength reductions for math
// operations.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include <climits>
|
|
|
|
using namespace mlir;
|
|
|
|
//----------------------------------------------------------------------------//
// PowFOp strength reduction.
//----------------------------------------------------------------------------//
|
|
|
|
namespace {
|
|
struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(math::PowFOp op,
|
|
PatternRewriter &rewriter) const final;
|
|
};
|
|
} // namespace
|
|
|
|
/// Strength-reduces `math.powf(x, c)` when the exponent `c` is a compile-time
/// constant (a scalar `FloatAttr` or a splat `DenseFPElementsAttr`) exactly
/// equal to one of 1.0, 2.0, 3.0, -1.0, 0.5, -0.5 or 0.75. Every other
/// exponent, including non-constant values and non-splat vector constants, is
/// rejected without creating any IR.
///
/// NOTE(review): the 0.5 / -0.5 rewrites use sqrt / rsqrt. Under IEEE-754,
/// sqrt(-0.0) is -0.0, whereas C's pow(-0.0, 0.5) is +0.0. Confirm that this
/// matches the intended `math.powf` semantics.
LogicalResult
PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
                                       PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value base = op.getLhs();

  // Resolve the exponent to a single floating-point value. Accept either a
  // scalar constant or a splat vector constant; give up on anything else.
  FloatAttr exponentAttr;
  if (!matchPattern(op.getRhs(), m_Constant(&exponentAttr))) {
    DenseFPElementsAttr splatAttr;
    if (!matchPattern(op.getRhs(), m_Constant(&splatAttr)) ||
        !splatAttr.isSplat())
      return failure();
    exponentAttr = splatAttr.getSplatValue<FloatAttr>();
  }
  const APFloat exponent = exponentAttr.getValue();
  auto exponentIs = [&exponent](double candidate) {
    return exponent.isExactlyValue(candidate);
  };

  // pow(x, 1.0) => x
  if (exponentIs(1.0)) {
    rewriter.replaceOp(op, base);
    return success();
  }

  // pow(x, 2.0) => x * x
  if (exponentIs(2.0)) {
    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({base, base}));
    return success();
  }

  // pow(x, 3.0) => x * (x * x)
  if (exponentIs(3.0)) {
    Value squared =
        rewriter.create<arith::MulFOp>(loc, ValueRange({base, base}));
    rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
                                               ValueRange({base, squared}));
    return success();
  }

  // pow(x, -1.0) => 1.0 / x. The constant is broadcast when the result type
  // is a vector.
  if (exponentIs(-1.0)) {
    Value one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
    if (auto vectorType = op.getType().dyn_cast<VectorType>())
      one = rewriter.create<vector::BroadcastOp>(loc, vectorType, one);
    rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({one, base}));
    return success();
  }

  // pow(x, 0.5) => sqrt(x)
  if (exponentIs(0.5)) {
    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, base);
    return success();
  }

  // pow(x, -0.5) => rsqrt(x)
  if (exponentIs(-0.5)) {
    rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, base);
    return success();
  }

  // pow(x, 0.75) => sqrt(x) * sqrt(sqrt(x)), i.e. x^(1/2) * x^(1/4)
  if (exponentIs(0.75)) {
    Value halfPower = rewriter.create<math::SqrtOp>(loc, base);
    Value quarterPower = rewriter.create<math::SqrtOp>(loc, halfPower);
    rewriter.replaceOpWithNewOp<arith::MulFOp>(
        op, ValueRange{halfPower, quarterPower});
    return success();
  }

  return failure();
}
|
|
|
|
//----------------------------------------------------------------------------//
// FPowIOp/IPowIOp strength reduction.
//----------------------------------------------------------------------------//
|
|
|
|
namespace {
|
|
template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
|
|
struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
|
|
|
|
unsigned exponentThreshold;
|
|
|
|
public:
|
|
PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
|
|
PatternBenefit benefit = 1,
|
|
ArrayRef<StringRef> generatedNames = {})
|
|
: OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
|
|
exponentThreshold(exponentThreshold) {}
|
|
|
|
LogicalResult matchAndRewrite(PowIOpTy op,
|
|
PatternRewriter &rewriter) const final;
|
|
};
|
|
} // namespace
|
|
|
|
template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
|
|
LogicalResult
|
|
PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
|
|
PowIOpTy op, PatternRewriter &rewriter) const {
|
|
Location loc = op.getLoc();
|
|
Value base = op.getLhs();
|
|
|
|
IntegerAttr scalarExponent;
|
|
DenseIntElementsAttr vectorExponent;
|
|
|
|
bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
|
|
bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
|
|
|
|
// Simplify cases with known exponent value.
|
|
int64_t exponentValue = 0;
|
|
if (isScalar)
|
|
exponentValue = scalarExponent.getInt();
|
|
else if (isVector && vectorExponent.isSplat())
|
|
exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
|
|
else
|
|
return failure();
|
|
|
|
// Maybe broadcasts scalar value into vector type compatible with `op`.
|
|
auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
|
|
if (auto vec = op.getType().template dyn_cast<VectorType>())
|
|
return rewriter.create<vector::BroadcastOp>(loc, vec, value);
|
|
return value;
|
|
};
|
|
|
|
Value one;
|
|
Type opType = getElementTypeOrSelf(op.getType());
|
|
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
|
|
one = rewriter.create<arith::ConstantOp>(
|
|
loc, rewriter.getFloatAttr(opType, 1.0));
|
|
else
|
|
one = rewriter.create<arith::ConstantOp>(
|
|
loc, rewriter.getIntegerAttr(opType, 1));
|
|
|
|
// Replace `[fi]powi(x, 0)` with `1`.
|
|
if (exponentValue == 0) {
|
|
rewriter.replaceOp(op, bcast(one));
|
|
return success();
|
|
}
|
|
|
|
bool exponentIsNegative = false;
|
|
if (exponentValue < 0) {
|
|
exponentIsNegative = true;
|
|
exponentValue *= -1;
|
|
}
|
|
|
|
// Bail out if `abs(exponent)` exceeds the threshold.
|
|
if (exponentValue > exponentThreshold)
|
|
return failure();
|
|
|
|
// Inverse the base for negative exponent, i.e. for
|
|
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
|
|
if (exponentIsNegative)
|
|
base = rewriter.create<DivOpTy>(loc, bcast(one), base);
|
|
|
|
Value result = base;
|
|
// Transform to naive sequence of multiplications:
|
|
// * For positive exponent case replace:
|
|
// `[fi]powi(x, positive_exponent)`
|
|
// with:
|
|
// x * x * x * ...
|
|
// * For negative exponent case replace:
|
|
// `[fi]powi(x, negative_exponent)`
|
|
// with:
|
|
// (1 / x) * (1 / x) * (1 / x) * ...
|
|
for (unsigned i = 1; i < exponentValue; ++i)
|
|
result = rewriter.create<MulOpTy>(loc, result, base);
|
|
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
|
|
//----------------------------------------------------------------------------//
|
|
|
|
/// Registers the Math dialect algebraic-simplification patterns: `powf`
/// strength reduction, plus `powi` strength reduction for both the integer
/// (`ipowi`) and floating-point (`fpowi`) forms, using the default exponent
/// threshold.
void mlir::populateMathAlgebraicSimplificationPatterns(
    RewritePatternSet &patterns) {
  // The integer and floating-point `powi` rewrites share one template,
  // instantiated with the matching division and multiplication ops.
  using IPowIReduction =
      PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>;
  using FPowIReduction =
      PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>;

  MLIRContext *context = patterns.getContext();
  patterns.add<PowFStrengthReduction, IPowIReduction, FPowIReduction>(context);
}
|