llvm-project/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp

389 lines
14 KiB
C++

//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
/// Rewrite pattern for operations of kind `Op` that produce a vector result.
/// The rewrite logic lives in the out-of-line `matchAndRewrite` definition;
/// constructors are inherited from `OpRewritePattern<Op>`.
template <typename Op>
struct VecOpToScalarOp : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
};
/// Callback used by the lowering patterns to look up the pre-generated
/// FuncOp that implements a power operation for the given type.
/// NOTE(review): `function_ref` is a non-owning reference, so the callable
/// passed in is expected to outlive every pattern that stores it.
using GetPowerFuncCallbackTy = function_ref<func::FuncOp(Type)>;

/// Pattern that replaces a scalar `math::IPowIOp` with a call to an
/// outlined software implementation obtained through the callback.
struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
  /// \p context is forwarded to the base pattern; \p cb is stored and queried
  /// for the callee each time the pattern is applied.
  IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
      : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}

  /// Replace the scalar IPowI operation with a call to the local function
  /// implementing the power operation. Operations whose base operand is not
  /// of an integer type (e.g. vector forms) are not matched by this pattern.
  LogicalResult matchAndRewrite(math::IPowIOp op,
                                PatternRewriter &rewriter) const final;

private:
  /// Lookup of the outlined implementation by operand type.
  GetPowerFuncCallbackTy getFuncOpCallback;
};
} // namespace
/// Unroll a vector operation into a sequence of scalar operations.
///
/// For every element position of the result vector, the corresponding element
/// of each operand is extracted, a scalar `Op` is created on those elements,
/// and its result is inserted into an accumulator vector that finally
/// replaces the original operation.
///
/// The accumulator is seeded with a zero splat. The zero attribute is chosen
/// according to the result element type, so the pattern works for both
/// integer and floating-point results (an IntegerAttr cannot be built for a
/// floating-point type).
///
/// Fails to match when the result of \p op is not a vector.
template <typename Op>
LogicalResult
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
  Type opType = op.getType();
  Location loc = op.getLoc();
  auto vecType = opType.template dyn_cast<VectorType>();

  if (!vecType)
    return rewriter.notifyMatchFailure(op, "not a vector operation");
  if (!vecType.hasRank())
    return rewriter.notifyMatchFailure(op, "unknown vector rank");

  ArrayRef<int64_t> shape = vecType.getShape();
  int64_t numElements = vecType.getNumElements();

  // Pick a zero initializer that is valid for the result element type.
  Type resultElementType = vecType.getElementType();
  Attribute initValueAttr;
  if (resultElementType.isa<FloatType>())
    initValueAttr = FloatAttr::get(resultElementType, 0.0);
  else
    initValueAttr = IntegerAttr::get(resultElementType, 0);

  Value result = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(vecType, initValueAttr));

  // Visit the elements in linearized order and translate each linear index
  // back into an n-D position for the extract/insert operations.
  SmallVector<int64_t> strides = computeStrides(shape);
  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
    SmallVector<int64_t> positions = delinearize(strides, linearIndex);

    SmallVector<Value> operands;
    for (Value input : op->getOperands())
      operands.push_back(
          rewriter.create<vector::ExtractOp>(loc, input, positions));

    Value scalarOp = rewriter.create<Op>(loc, resultElementType, operands);
    result =
        rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
  }

  rewriter.replaceOp(op, result);
  return success();
}
/// Create a private function carrying the `llvm.linkage = linkonce_odr`
/// attribute that implements the integer power operation for
/// \p elementType, and append it to the end of the body of \p module.
/// \p elementType must be an IntegerType (asserted below).
///
/// The function is named `__mlir_math_ipowi_<elementType>` and has the type
/// '(elementType, elementType) -> elementType'; its first argument is the
/// base 'b' and the second one is the power 'p'. The created FuncOp is
/// returned.
///
/// The generated body is equivalent to the following pseudo-code:
///
/// template <typename T>
/// T __mlir_math_ipowi_*(T b, T p) {
/// if (p == T(0))
/// return T(1);
/// if (p < T(0)) {
/// if (b == T(0))
/// return T(1) / T(0); // trigger div-by-zero
/// if (b == T(1))
/// return T(1);
/// if (b == T(-1)) {
/// if (p & T(1))
/// return T(-1);
/// return T(1);
/// }
/// return T(0);
/// }
/// T result = T(1);
/// while (true) {
/// if (p & T(1))
/// result *= b;
/// p >>= T(1);
/// if (p == T(0))
/// return result;
/// b *= b;
/// }
/// }
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
assert(elementType.isa<IntegerType>() &&
"non-integer element type for IPowIOp");
// The new function is inserted at the end of the module body.
ImplicitLocOpBuilder builder =
ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
// The function name is the common prefix followed by '_' and the printed
// form of the element type.
std::string funcName("__mlir_math_ipowi");
llvm::raw_string_ostream nameOS(funcName);
nameOS << '_' << elementType;
FunctionType funcType = FunctionType::get(
builder.getContext(), {elementType, elementType}, elementType);
auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
Attribute linkage =
LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
funcOp->setAttr("llvm.linkage", linkage);
funcOp.setPrivate();
Block *entryBlock = funcOp.addEntryBlock();
Region *funcBody = entryBlock->getParent();
// 'b' (base) and 'p' (power) arguments of the generated function.
Value bArg = funcOp.getArgument(0);
Value pArg = funcOp.getArgument(1);
builder.setInsertionPointToEnd(entryBlock);
// Constants T(0), T(1) and T(-1) shared by all the blocks below.
Value zeroValue = builder.create<arith::ConstantOp>(
elementType, builder.getIntegerAttr(elementType, 0));
Value oneValue = builder.create<arith::ConstantOp>(
elementType, builder.getIntegerAttr(elementType, 1));
Value minusOneValue = builder.create<arith::ConstantOp>(
elementType,
builder.getIntegerAttr(elementType,
APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
/*isSigned=*/true)));
// The code below follows one scheme for every 'if': the compare is created
// first, then the 'then' and fall-through blocks are created, and only then
// the conditional branch is appended to the block holding the compare.
// The operations created right after each createBlock() call are expected
// to land in the newly created block.
// if (p == T(0))
// return T(1);
auto pIsZero =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
Block *thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(oneValue);
Block *fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == T(0)).
builder.setInsertionPointToEnd(pIsZero->getBlock());
builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
// if (p < T(0)) {
// NOTE: the predicate used is 'sle' rather than 'slt'. This block is only
// reached when (p == T(0)) is false, so the two predicates select the same
// paths here.
builder.setInsertionPointToEnd(fallthroughBlock);
auto pIsNeg =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
// if (b == T(0))
// The block created here is not stored in a variable; it is referenced
// later via bIsZero->getBlock().
builder.createBlock(funcBody);
auto bIsZero =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
// return T(1) / T(0);
thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(
builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(0)).
builder.setInsertionPointToEnd(bIsZero->getBlock());
builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
// if (b == T(1))
builder.setInsertionPointToEnd(fallthroughBlock);
auto bIsOne =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
// return T(1);
thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(oneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(1)).
builder.setInsertionPointToEnd(bIsOne->getBlock());
builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
// if (b == T(-1)) {
builder.setInsertionPointToEnd(fallthroughBlock);
auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
bArg, minusOneValue);
// if (p & T(1))
// As above, the block holding this compare is referenced later via
// pIsOdd->getBlock().
builder.createBlock(funcBody);
auto pIsOdd = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
zeroValue);
// return T(-1);
thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(minusOneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p & T(1)).
builder.setInsertionPointToEnd(pIsOdd->getBlock());
builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
// return T(1);
// } // b == T(-1)
builder.setInsertionPointToEnd(fallthroughBlock);
builder.create<func::ReturnOp>(oneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(-1)).
builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
fallthroughBlock);
// return T(0);
// } // (p < T(0))
builder.setInsertionPointToEnd(fallthroughBlock);
builder.create<func::ReturnOp>(zeroValue);
// The loop header takes three block arguments that carry the current values
// of 'result', 'b' and 'p' (in this order) between iterations.
Block *loopHeader = builder.createBlock(
funcBody, funcBody->end(), {elementType, elementType, elementType},
{builder.getLoc(), builder.getLoc(), builder.getLoc()});
// Set up conditional branch for (p < T(0)).
builder.setInsertionPointToEnd(pIsNeg->getBlock());
// Set initial values of 'result', 'b' and 'p' for the loop.
builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
ValueRange{oneValue, bArg, pArg});
// T result = T(1);
// while (true) {
// if (p & T(1))
// result *= b;
// p >>= T(1);
// if (p == T(0))
// return result;
// b *= b;
// }
Value resultTmp = loopHeader->getArgument(0);
Value baseTmp = loopHeader->getArgument(1);
Value powerTmp = loopHeader->getArgument(2);
builder.setInsertionPointToEnd(loopHeader);
// if (p & T(1))
auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::ne,
builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
thenBlock = builder.createBlock(funcBody);
// result *= b;
Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
// The merge block takes a single argument: either the updated 'result'
// (coming from the 'then' block) or the unchanged one (coming directly
// from the loop header).
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
builder.getLoc());
builder.setInsertionPointToEnd(thenBlock);
builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
// Set up conditional branch for (p & T(1)).
builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
resultTmp);
// Merged 'result'.
newResultTmp = fallthroughBlock->getArgument(0);
// p >>= T(1);
// The shift is the unsigned one (arith::ShRUIOp); the loop is only entered
// from the false side of the pIsNeg branch above.
builder.setInsertionPointToEnd(fallthroughBlock);
Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
// if (p == T(0))
auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
newPowerTmp, zeroValue);
// return result;
thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(newResultTmp);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == T(0)).
builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
// b *= b;
// }
builder.setInsertionPointToEnd(fallthroughBlock);
Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
// Pass new values for 'result', 'b' and 'p' to the loop header.
builder.create<cf::BranchOp>(
ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
return funcOp;
}
/// Replace a scalar IPowI operation with a `func.call` to the outlined
/// implementation returned by the stored callback for the type of the base
/// operand. The match fails when the base operand is not of an integer type
/// or when the callback does not provide an implementation.
LogicalResult
IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
                                 PatternRewriter &rewriter) const {
  Type firstOperandType = op->getOperand(0).getType();
  if (!firstOperandType.isa<IntegerType>())
    return rewriter.notifyMatchFailure(op, "non-integer base operand");

  // The callee is expected to have been generated before the patterns run;
  // this pattern never creates it.
  func::FuncOp callee = getFuncOpCallback(firstOperandType);
  if (callee) {
    rewriter.replaceOpWithNewOp<func::CallOp>(op, callee, op.getOperands());
    return success();
  }
  return rewriter.notifyMatchFailure(op, "missing software implementation");
}
namespace {
/// Pass that converts power operations of the Math dialect into calls to
/// outlined software implementations generated inside the processed module.
struct ConvertMathToFuncsPass
    : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
  ConvertMathToFuncsPass() = default;

  /// Generate the outlined implementations, then rewrite the operations.
  void runOnOperation() override;

private:
  /// Outlined software implementations of power operations, keyed by the
  /// element type of the result of the operations they implement.
  DenseMap<Type, func::FuncOp> powerFuncs;

  /// Walk the module and populate `powerFuncs` with one implementation per
  /// element type used by the power operations found there.
  void preprocessPowOperations();
};
} // namespace
/// Generate outlined implementations for all the power operations found in
/// the module being processed and record them in the `powerFuncs` map, keyed
/// by the element type of the operation's result. One function is generated
/// per distinct element type.
void ConvertMathToFuncsPass::preprocessPowOperations() {
  ModuleOp module = getOperation();

  // The map is state of the pass instance. Drop whatever a previous run (or
  // the instance this pass was cloned from) left there: those FuncOps belong
  // to another module, and keeping them would both skip the generation for
  // the current module and hand out stale handles.
  powerFuncs.clear();

  module.walk([&](math::IPowIOp op) {
    Type resultType = getElementTypeOrSelf(op.getResult().getType());

    // Generate the software implementation of this operation,
    // if it has not been generated yet.
    auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{});
    if (entry.second)
      entry.first->second = createElementIPowIFunc(&module, resultType);
  });
}
void ConvertMathToFuncsPass::runOnOperation() {
ModuleOp module = getOperation();
// Create outlined implementations for power operations.
preprocessPowOperations();
RewritePatternSet patterns(&getContext());
patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext());
// For the given Type Returns FuncOp stored in powerFuncs map.
auto getPowerFuncOpByType = [&](Type type) -> func::FuncOp {
auto it = powerFuncs.find(type);
if (it == powerFuncs.end())
return {};
return it->second;
};
patterns.add<IPowIOpLowering>(patterns.getContext(), getPowerFuncOpByType);
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
func::FuncDialect, vector::VectorDialect>();
target.addIllegalOp<math::IPowIOp>();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
/// Create a new instance of the Math-to-Funcs conversion pass.
std::unique_ptr<Pass> mlir::createConvertMathToFuncsPass() {
  auto pass = std::make_unique<ConvertMathToFuncsPass>();
  return pass;
}