//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Pattern to decompose a GenericOp that has more than two statements
/// into one GenericOp with the first statement (i.e. peeled operation), and
/// a second GenericOp with the remaining statements (i.e. residual operations).
/// - The results added to the first GenericOp have the same shape as the
/// iteration space of the original GenericOp. The body of the op yields as
/// many values as the original op plus all the results of the peeled
/// operation.
/// - The second GenericOp has as many operands as the original operation plus
/// all the results of the first Generic Op. It has the same number of yields as
/// the original op.
/// - If a result of the peeled operation was yielded by the original
/// GenericOp, the uses of the corresponding result are replaced with the
/// corresponding result of the first GenericOp created.
///
/// Example
///
/// ```mlir
/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
/// outs(%init0, %init1 : ...) {
/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
/// %0 = <s0> %b0, %b1 : ...
/// %1 = <s1> %0, %b2 : ...
/// linalg.yield %0, %1 : ...
/// } -> (..., ...)
/// return %result#0, %result#1
/// ```
///
/// gets split into
///
/// ```mlir
/// %init = tensor.empty ...
/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
/// outs(%init0, %init1, %init : ...) {
/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
/// %0 = <s0> %b0, %b1 : ...
/// linalg.yield %0, %..., %0 : ...
/// } -> (..., ..., ...)
/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
/// outs(%init0, %init1 : ...) {
/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
/// %1 = <s1> %b3, %b2 : ...
/// linalg.yield %..., %1 : ...
/// } -> (..., ...)
/// return %op0#0, %op1#1
/// ```
///
/// After canonicalization this is expected to be
///
/// ```mlir
/// %init = tensor.empty ...
/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
/// outs(%init : ...) {
/// ^bb0(%b0: ... , %b1: ... , %b2: ...):
/// %0 = <s0> %b0, %b1 : ...
/// linalg.yield %0 : ...
/// } -> ...
/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
/// outs(%init1 : ...) {
/// ^bb0(%b0: ... , %b1: ... , %b2: ...):
/// %1 = <s1> %b1, %b0 : ...
/// linalg.yield %1 : ...
/// } -> ...
/// return %op0, %op1
/// ```
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override;
private:
/// Helper method to create a generic op for the peeled scalar operation. The
/// created op has an empty region.
GenericOp createPeeledGenericOp(GenericOp genericOp,
PatternRewriter &rewriter) const;
/// Helper method to create a generic op for the residual scalar operations.
/// The created op has an empty region; the remaining statements of the
/// original op are moved into it by the caller.
GenericOp createResidualGenericOp(GenericOp genericOp,
GenericOp peeledGenericOp,
PatternRewriter &rewriter) const;
};
} // namespace
/// Helper method to compute the loop ranges of a generic op.
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
GenericOp op) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
Location loc = op.getLoc();
auto allShapesSizes =
cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
AffineMap map = op.getShapesToLoopsMap();
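// Apply the (shapes -> loops) map to the flattened list of operand dimension
// sizes to recover the size of each loop, folding to constants where possible.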
IRRewriter rewriter(b);
return makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
allShapesSizes);
}
/// Helper method to permute the list of `values` based on the `map`.
SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
AffineMap map) {
assert(map.isPermutation());
SmallVector<OpFoldResult> permutedValues(values.size());
for (const auto &position :
llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
return expr.cast<AffineDimExpr>().getPosition();
})))
permutedValues[position.value()] = values[position.index()];
return permutedValues;
}
/// Get zero value for an element type.
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
assert(elementType.isIntOrIndexOrFloat() &&
"expected scalar type while computing zero value");
if (elementType.isa<IntegerType>())
return b.create<arith::ConstantIntOp>(loc, 0, elementType);
if (elementType.isIndex())
return b.create<arith::ConstantIndexOp>(loc, 0);
// Assume float.
auto floatType = elementType.cast<FloatType>();
return b.create<arith::ConstantFloatOp>(
loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}
GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
PatternRewriter &rewriter) const {
Block *body = genericOp.getBody();
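// The operation to peel off is the first statement in the body of the generic
// op; everything that follows it stays in the residual op.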
Operation *peeledScalarOperation = &(*body->begin());
SmallVector<AffineMap> peeledGenericOpIndexingMaps =
genericOp.getIndexingMapsArray();
/// Compute the loop ranges of the operation. This is the shape of the results
/// added to the generic op created for the peeled operation.
Location loc = genericOp.getLoc();
SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
SmallVector<Value> newInitValues;
SmallVector<Type> newResultTypes;
// Add as many new results as the number of results of the peeled scalar op.
for (auto scalarOpResult : peeledScalarOperation->getResults()) {
// If the result is yielded by the original op, use the operand, indexing
// map and result type that correspond to the yielded value.
Optional<unsigned> resultNumber;
for (auto *user : scalarOpResult.getUsers()) {
if (auto yieldOp = dyn_cast<YieldOp>(user)) {
// Find the first use of the `scalarOpResult` in the yield op.
for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
if (yieldOperand.get() == scalarOpResult) {
resultNumber = yieldOperand.getOperandNumber();
break;
}
}
assert(resultNumber && "unable to find use of a value in its user");
break;
}
}
if (resultNumber) {
newInitValues.push_back(
genericOp.getDpsInitOperand(*resultNumber)->get());
OpResult result = genericOp.getResult(*resultNumber).cast<OpResult>();
newResultTypes.push_back(result.getType());
peeledGenericOpIndexingMaps.push_back(
genericOp.getIndexingMapMatchingResult(result));
continue;
}
// Fall-back path: use a `tensor.empty` op and an identity indexing map.
AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
Value emptyTensor =
rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
newInitValues.push_back(emptyTensor);
newResultTypes.push_back(emptyTensor.getType());
peeledGenericOpIndexingMaps.push_back(indexingMap);
}
/// Create the peeled generic op with an empty body.
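// The original outs and result types come first so the result numbering of the
// original op is preserved; the results for the peeled statement follow them.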
SmallVector<Value> outsOperands = genericOp.getOutputs();
outsOperands.append(newInitValues.begin(), newInitValues.end());
SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
resultTypes.append(newResultTypes.begin(), newResultTypes.end());
auto indexingMapAttr =
rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
return rewriter.create<GenericOp>(
loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
[](OpBuilder, Location, ValueRange) {});
}
GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
GenericOp peeledGenericOp,
PatternRewriter &rewriter) const {
/// Append the newly added results of the peeledGenericOp as `ins` operands of
/// the residual generic op.
SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
unsigned origNumResults = genericOp.getNumResults();
unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
SmallVector<Value> extraIns;
for (auto resultNum :
llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
extraIns.push_back(peeledGenericOp->getResult(resultNum));
residualGenericOpOperands.append(extraIns);
/// Add indexing maps for the newly added operands. Use the same maps as those
/// used for the new results of the peeledGenericOp.
auto indexingMaps = llvm::to_vector(
llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
return genericOp.getMatchingIndexingMap(operand);
}));
for (auto resultNum :
llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
indexingMaps.push_back(
peeledGenericOp.getIndexingMapMatchingResult(result));
}
for (OpOperand *outOperand : genericOp.getDpsInitOperands())
indexingMaps.push_back(genericOp.getMatchingIndexingMap(outOperand));
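// The map order must match the operand order: original inputs, then the
// appended results of the peeled op, then the original outs.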
auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
return rewriter.create<GenericOp>(
genericOp->getLoc(), genericOp->getResultTypes(),
residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
[](OpBuilder, Location, ValueRange) {});
}
LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const {
/// For now only match on operations where the iterator types are all parallel.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
return rewriter.notifyMatchFailure(genericOp,
"unhandled decomposition of operation "
"with non-parallel iterator types");
}
// TODO: this could be generalized to handle `linalg.generic` with buffer
// operands too but requires allocation for intermediates. Punt on this for
// now.
if (!genericOp.hasTensorSemantics()) {
return rewriter.notifyMatchFailure(
genericOp, "only operations with tensor semantics are handled");
}
if (llvm::any_of(genericOp.getDpsInitOperands(), [&](OpOperand *outOperand) {
return !genericOp.getMatchingIndexingMap(outOperand).isPermutation();
})) {
return rewriter.notifyMatchFailure(
genericOp, "unhandled decomposition of generic op with out operand not "
"accessed using a permutation");
}
/// If the op has only a single statement (apart from the yield), do nothing.
Block *body = genericOp.getBody();
if (body->getOperations().size() <= 2) {
return rewriter.notifyMatchFailure(genericOp,
"operation has less than 3 statements");
}
/// Check that the peeled statement has a scalar element type.
if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
[](Type t) { return !t.isIntOrIndexOrFloat(); })) {
return rewriter.notifyMatchFailure(
&(*body->getOperations().begin()),
"expected return type to be only int, index or float");
}
GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
GenericOp residualGenericOp =
createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
/// Move the first statement of the original operation into the body of the
/// generic op for the peeled operation.
Block *peeledGenericOpBody = peeledGenericOp.getBody();
Block *residualGenericOpBody = residualGenericOp.getBody();
assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
"expected split generic ops to have empty region");
peeledGenericOpBody->getOperations().splice(
peeledGenericOpBody->begin(), body->getOperations(), body->begin());
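// Everything that remains in the original body, including the terminator,
// moves into the residual op.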
residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
body->getOperations());
Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
auto *yieldOp = residualGenericOpBody->getTerminator();
{
// Yield all the results of the peeled scalar operation.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToEnd(peeledGenericOpBody);
SmallVector<Value> yieldedVals;
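// For each value yielded by the original op: if it is produced by the peeled
// statement, yield it directly; otherwise yield a zero placeholder of the
// matching element type. The placeholder results are not used by the rewrite
// (see the replacements computed below) and are expected to be removed by the
// erase-unused-results patterns when those are enabled.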
for (auto origYield : yieldOp->getOperands()) {
if (origYield.getDefiningOp() == peeledScalarOperation) {
yieldedVals.push_back(origYield);
} else {
yieldedVals.push_back(
getZero(rewriter, genericOp.getLoc(), origYield.getType()));
}
}
yieldedVals.append(llvm::to_vector(
llvm::map_range(peeledScalarOperation->getResults(),
[](OpResult opr) -> Value { return opr; })));
rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
}
/// In the split operations, replace uses of the original operation's block
/// arguments with the block arguments of the newly created operations.
unsigned origNumInputs = genericOp.getNumDpsInputs();
for (const auto &inputBlockArg :
llvm::enumerate(genericOp.getBody()->getArguments())) {
Value residualOpReplacementArg =
residualGenericOpBody->getArgument(inputBlockArg.index());
inputBlockArg.value().replaceUsesWithIf(
residualOpReplacementArg, [&](OpOperand &use) {
return use.getOwner()->getBlock() == residualGenericOpBody;
});
Value peeledOpReplacementArg =
peeledGenericOpBody->getArgument(inputBlockArg.index());
inputBlockArg.value().replaceUsesWithIf(
peeledOpReplacementArg, [&](OpOperand &use) {
return use.getOwner()->getBlock() == peeledGenericOpBody;
});
}
/// Before fixing up the residual operation, track what values are yielded. If
/// any of those come from the peeled scalar operation, the uses of the
/// corresponding result have to be remapped to the result of the generic op
/// for the peeled operation.
SmallVector<Value> replacements;
for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
OpResult opr = yieldValue.value().dyn_cast<OpResult>();
if (!opr || opr.getOwner() != peeledScalarOperation)
replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
else
replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
}
/// Update all uses of the peeled scalar operation's results in the residual op
/// to the newly added block arguments.
{
SmallVector<Value> scalarReplacements;
unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
scalarReplacements.reserve(peeledScalarOpNumResults);
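// The new results of the peeled op were appended as `ins` of the residual op
// directly after the original inputs, so the matching block arguments start at
// `origNumInputs`.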
for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
scalarReplacements.push_back(
residualGenericOpBody->getArgument(num + origNumInputs));
bool allUsesReplaced = false;
rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
residualGenericOpBody, &allUsesReplaced);
assert(!allUsesReplaced &&
"peeled scalar operation is erased when it wasn't expected to be");
}
// Replace the original operation
rewriter.replaceOp(genericOp, replacements);
return success();
}
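// A minimal usage sketch (not part of this file): a pass could apply these
// patterns roughly as follows, assuming `funcOp` and `context` are provided by
// the surrounding pass and the greedy pattern rewrite driver is used:
//
//   RewritePatternSet patterns(context);
//   linalg::populateDecomposeLinalgOpsPattern(patterns,
//                                             /*removeDeadArgsAndResults=*/true);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();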
void mlir::linalg::populateDecomposeLinalgOpsPattern(
RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
patterns.insert<DecomposeLinalgOp>(patterns.getContext());
// Add the patterns to clean up the dead operands and results.
if (removeDeadArgsAndResults)
populateEraseUnusedOperandsAndResultsPatterns(patterns);
}