135 lines
5.8 KiB
C++
135 lines
5.8 KiB
C++
//===- Split.cpp - Structured op splitting --------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/Interfaces/TilingInterface.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
/// Creates a part of the given `op` split along the iteration space `dimension`
|
|
/// with the given `size` and an optional `offset` (default 0). Makes slices
|
|
/// of operands, using the input operands of the original op and the output
|
|
/// operands provided as `resultOperands`. Expects `offsets` and `sizes` to
|
|
/// define the shape of the iteration space of the original op. Returns the
|
|
/// split-out op as well as the output operand values updated with the partial
|
|
/// results produced by this op through `results`.
|
|
static TilingInterface
|
|
createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
|
|
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
|
ValueRange resultOperands, unsigned dimension,
|
|
OpFoldResult size, OpFoldResult offset,
|
|
SmallVectorImpl<Value> &results) {
|
|
// Iteration space of the current part.
|
|
SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(sizes);
|
|
SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(offsets);
|
|
sizesCopy[dimension] = size;
|
|
offsetsCopy[dimension] = offset;
|
|
|
|
// Create the part as it it were a single tile.
|
|
SmallVector<Operation *> tiled =
|
|
op.getTiledImplementation(b, offsetsCopy, sizesCopy);
|
|
assert(tiled.size() == 1 && "expected a single result from tiling");
|
|
auto part = cast<TilingInterface>(tiled.front());
|
|
|
|
// Insert the results back and populate the `results` list.
|
|
for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) {
|
|
SmallVector<OpFoldResult> resultOffsets, resultSizes;
|
|
if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy,
|
|
resultOffsets, resultSizes)))
|
|
return nullptr;
|
|
SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
|
|
b.getIndexAttr(1));
|
|
Value inserted = b.create<tensor::InsertSliceOp>(
|
|
loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes,
|
|
resultStrides);
|
|
results.push_back(inserted);
|
|
}
|
|
|
|
return part;
|
|
}
|
|
|
|
std::pair<TilingInterface, TilingInterface>
|
|
linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
|
|
OpFoldResult splitPoint) {
|
|
// Compute the iteration space.
|
|
SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter);
|
|
|
|
// Bail out on dimension overflow.
|
|
if (dimension >= iterationSpace.size())
|
|
return std::make_pair(op, TilingInterface());
|
|
|
|
SmallVector<OpFoldResult> offsets = llvm::to_vector(llvm::map_range(
|
|
iterationSpace, [](const Range &range) { return range.offset; }));
|
|
SmallVector<OpFoldResult> sizes = llvm::to_vector(llvm::map_range(
|
|
iterationSpace, [](const Range &range) { return range.size; }));
|
|
|
|
// Adjust the split point so that it doesn't overflow the size.
|
|
AffineExpr d0, d1, d2;
|
|
bindDims(rewriter.getContext(), d0, d1, d2);
|
|
OpFoldResult minSplitPoint = makeComposedFoldedAffineMin(
|
|
rewriter, op.getLoc(),
|
|
AffineMap::inferFromExprList(ArrayRef<AffineExpr>{d0, d1 + d2}).front(),
|
|
{splitPoint, offsets[dimension], sizes[dimension]});
|
|
|
|
// Compute the size of the second part. Return early if the second part would
|
|
// have an empty iteration space.
|
|
OpFoldResult remainingSize = makeComposedFoldedAffineApply(
|
|
rewriter, op.getLoc(), d0 + d1 - d2,
|
|
{iterationSpace[dimension].offset, iterationSpace[dimension].size,
|
|
minSplitPoint});
|
|
if (auto attr = remainingSize.dyn_cast<Attribute>()) {
|
|
if (attr.cast<IntegerAttr>().getValue().isZero())
|
|
return {op, TilingInterface()};
|
|
}
|
|
|
|
// Compute destination tensors.
|
|
SmallVector<Value> destinationTensors;
|
|
LogicalResult destStatus = tensor::getOrCreateDestinations(
|
|
rewriter, op.getLoc(), op, destinationTensors);
|
|
(void)destStatus;
|
|
assert(succeeded(destStatus) && "failed to get destination tensors");
|
|
|
|
// Create the first part.
|
|
SmallVector<Value> firstResults;
|
|
TilingInterface firstPart = createSplitPart(
|
|
rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension,
|
|
minSplitPoint, iterationSpace[dimension].offset, firstResults);
|
|
|
|
// Need to pretend that the original op now takes as operands firstResults,
|
|
// otherwise tiling interface implementation will take the wrong value to
|
|
// produce data tiles.
|
|
rewriter.updateRootInPlace(op, [&]() {
|
|
unsigned numTotalOperands = op->getNumOperands();
|
|
unsigned numOutputOperands = firstResults.size();
|
|
op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
|
|
firstResults);
|
|
});
|
|
|
|
// Create the second part.
|
|
OpFoldResult totalOffset = makeComposedFoldedAffineApply(
|
|
rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint});
|
|
SmallVector<Value> secondResults;
|
|
TilingInterface secondPart =
|
|
createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
|
|
dimension, remainingSize, totalOffset, secondResults);
|
|
|
|
// Replace the original op with the results of the two newly created ops.
|
|
rewriter.replaceOp(op, secondResults);
|
|
return {firstPart, secondPart};
|
|
}
|