//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel and a reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;
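
/// Editorial sketch (illustrative, not taken from a test): with ratio=4,
/// index=0 and innerParallel=false, a 1-D sum reduction
///   %r = linalg.generic ins(%in : tensor<32xf32>) outs(%out : tensor<f32>)
/// is rewritten so that %in is first expanded to tensor<4x8xf32>, a new
/// generic with iterator types ["parallel", "reduction"] accumulates partial
/// sums into a tensor<4xf32> intermediate pre-filled with the reduction's
/// neutral element (0.0 for addf), and a final generic reduces that
/// intermediate back into the original tensor<f32> result.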
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "More than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  Optional<Attribute> identity = getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          newShape.push_back(ratio); // parallel (insert)
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio); // parallel (insert)
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }
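
  // For example (editorial sketch): with ratio=4 and a 32-element reduction
  // dimension, the reassociation groups the two new dims into the original
  // one and the tensor.expand_shape above splits it into 4x8 (8x4 when
  // control.innerParallel is set).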

  // Calculate the new output map and shape; we insert the new dimension based
  // on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = b.create<tensor::EmptyOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == it.index())
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(it.value());
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      SmallVector<Value>{op.getDpsInitOperands()}, reductionMaps,
      reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
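/// For instance (editorial sketch): given an operand map (d0, d1) -> (d0, d1)
/// with reductionDimPos = 1 and reductionRatio = 4, the returned map is
/// (d0, d1, d2) -> (d0, d1 * 4 + d2).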
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}
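
/// Insert the new parallel loop dimension into `opOperand`'s indexing map.
/// Editorial sketch: for an output map (d0, d1, d2) -> (d0, d1) with
/// reductionDimPos = 2, the result is (d0, d1, d2, d3) -> (d0, d1, d2); the
/// intermediate output gains a dimension indexed by the new parallel loop,
/// which the follow-up reduction op then folds away.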
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation for the scaling-based split reduction: partial
/// results are accumulated along an extra parallel dimension and combined by a
/// follow-up reduction op (see the Step 1 comment below for an example).
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<Attribute> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    Optional<Attribute> neutralElement = getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element.");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra
  // dimension inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
  //   b. O(i, j) += O_i(kk, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitOperands(), neutralElements)) {
    Value rankedTensor = std::get<0>(it)->get();
    auto t = rankedTensor.getType().cast<RankedTensorType>();
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
                                                     t.getElementType(), dims);
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }
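
  // Following the documented example above (editorial note): each
  // tensor<3x5xf32> init becomes a tensor<8x3x5xf32> intermediate filled with
  // the reduction's neutral element.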

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand *o : op.getDpsInitOperands())
    newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  SmallVector<Value> newInputs(op.getDpsInputOperands());
  // Add a single shape-only tensor to carry the dimensions without resorting
  // to more complex inversions.
  newInputs.push_back(b.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
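  // (Editorial note) This tensor is never read; only its static shape
  // {reductionDimSize / splitFactor, splitFactor} feeds the (redDim, redDimP1)
  // indexing map provisioned in Step 2.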
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it : llvm::zip(genericOp->getResults(), op.getDpsInitOperands(),
                           combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it)->get();
    auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOps.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       bool useAlloc = false, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
};

} // namespace
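
// Example usage (editorial sketch; the control callback and driver invocation
// shown here are illustrative and not part of this file):
//   RewritePatternSet patterns(ctx);
//   linalg::populateSplitReductionPattern(
//       patterns, [](LinalgOp op) {
//         return SplitReductionOptions{/*ratio=*/8, /*index=*/0,
//                                      /*innerParallel=*/false};
//       });
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));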
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
}