//===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements constant folding on Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  struct APIntOrFloat {
    Optional<APInt> apInt;
    Optional<APFloat> apFloat;
  };
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Mixed and buffer semantics aren't supported.
    if (!genericOp.hasTensorSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumDpsInits() != 1)
      return failure();

    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output types to be static given that we are generating
    // constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
          return input.getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](Value value) {
      return value.getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::all_equal(
            llvm::map_range(genericOp->getOperands(), getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, and then do affine apply on them, and then match back the
    // constant again.
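    // For instance, (d0, d1, d2) -> (d2, d0, d1) is a permutation map, while
    // maps such as (d0, d1) -> (d0) or (d0, d1) -> (d0 + d1) are not.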
    if (!llvm::all_of(genericOp.getIndexingMapsArray(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand *operand : genericOp.getDpsInitOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumDpsInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
      if (!matchPattern(en.value()->get(),
                        m_Constant(&inputValues[en.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (OpOperand *operand : genericOp.getDpsInputOperands()) {
      if (!controlFn(operand))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();
    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the two cases below, but they are uniqued in and live as long as
    // the MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
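    // For example, the map (d0, d1, d2) -> (d2, d0, d1) yields the positions
    // {2, 0, 1}.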
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate space for the compute function inputs. Initial values do not
    // matter here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputs(), [](Value value) {
          return value.getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, `dstLinearIndex` in place.
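    //
    // An illustrative example (not from the original comment): with loop
    // bounds [2, 3], an identity input map over tensor<2x3>, and the output
    // map (d0, d1) -> (d1, d0) over tensor<3x2>, linearIndex 1 delinearizes
    // to indices [0, 1]; the source element keeps linear index 1 while the
    // destination becomes dstIndices [1, 0], i.e. dstLinearIndex 1 * 2 + 0 = 2.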
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };
    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
    return success();
  }
private:
  ControlFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
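// For instance (illustrative IR, not taken from the original file), the
// following 2x3 -> 3x2 transpose whose input %cst is a constant would be
// rewritten into an arith.constant holding the transposed values:
//
// ```mlir
//   %res = linalg.generic {
//       indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
//                        affine_map<(d0, d1) -> (d0, d1)>],
//       iterator_types = ["parallel", "parallel"]}
//       ins(%cst : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
//     ^bb0(%in: f32, %out: f32):
//       linalg.yield %in : f32
//   } -> tensor<3x2xf32>
// ```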
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMapsArray().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.getRegion().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument corresponding to the
    // input.
    for (Value yieldVal : yieldOp.getValues()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), std::nullopt};
      return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
    };
  }

  ControlFusionFn controlFn;
};
} // namespace
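
// Example usage (a minimal sketch, not part of the original file): a rewrite
// pass could register this pattern and drive it to a fixed point with the
// greedy driver. Here `op` and the always-true control function are
// placeholders; real callers pass a policy deciding which operands to fold.
//
// ```c++
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateConstantFoldLinalgOperations(
//       patterns, [](OpOperand *) { return true; });
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
// ```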
void mlir::linalg::populateConstantFoldLinalgOperations(
    RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<FoldConstantTranspose>(context, controlFn);
}