llvm-project/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

198 lines
7.9 KiB
C++

//===- VectorUtils.cpp - MLIR Utilities for VectorOps ------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utility methods for working with the Vector dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
/// Materializes the size of dimension `dim` of `source`, folding when
/// possible. Ranked and unranked memrefs get a memref::DimOp; ranked and
/// unranked tensors get a tensor::DimOp. Any other type is a programming
/// error.
Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                      int64_t dim) {
  Type sourceType = source.getType();
  const bool isMemRef = sourceType.isa<UnrankedMemRefType, MemRefType>();
  if (isMemRef)
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  const bool isTensor = sourceType.isa<UnrankedTensorType, RankedTensorType>();
  if (!isTensor)
    llvm_unreachable("Expected MemRefType or TensorType");
  return b.createOrFold<tensor::DimOp>(loc, source, dim);
}
/// Constructs a permutation map from memref indices to vector dimensions.
///
/// The implementation uses the knowledge of the mapping of enclosing loops to
/// vector dimensions. `enclosingLoopToVectorDim` carries this information as a
/// map with:
///   - keys representing "vectorized enclosing loops" (each key is cast to an
///     AffineForOp unconditionally, so it must be one);
///   - values representing the corresponding vector dimension.
/// The algorithm traverses the "vectorized enclosing loops" and, for each of
/// them, extracts the at-most-one MemRef index that *varies* along said loop,
/// i.e. the index that `getInvariantAccesses` does not report as invariant
/// with respect to the loop's induction variable. This index is guaranteed to
/// be at most one by construction: otherwise the MemRef is not vectorizable.
/// If such a variant index is found at position `d` of `indices`, the affine
/// dim expression for `d` is placed in the permutation map at the loop's
/// vector dimension.
/// If every index is invariant along the loop, the constant 0 stays at that
/// vector dimension, which corresponds to a vector broadcast along it.
///
/// Returns a map with `indices.size()` dims, no symbols and one result per
/// entry of `enclosingLoopToVectorDim`. Returns an empty AffineMap if
/// `enclosingLoopToVectorDim` is empty, signalling that no permutation map can
/// be constructed given `enclosingLoopToVectorDim`.
///
/// Examples can be found in the documentation of `makePermutationMap`, in the
/// header file.
static AffineMap makePermutationMap(
    ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
  if (enclosingLoopToVectorDim.empty())
    return AffineMap();
  MLIRContext *context =
      enclosingLoopToVectorDim.begin()->getFirst()->getContext();
  // Every result starts as the constant 0 (broadcast) until a variant index is
  // found for the corresponding vector dimension. Affine expressions are
  // uniqued, so comparing against `zero` is a pointer comparison.
  AffineExpr zero = getAffineConstantExpr(0, context);
  SmallVector<AffineExpr> perm(enclosingLoopToVectorDim.size(), zero);
  unsigned numIndices = indices.size();
  for (const auto &kvp : enclosingLoopToVectorDim) {
    unsigned vectorDim = kvp.second;
    assert(vectorDim < perm.size());
    auto invariants = getInvariantAccesses(
        cast<AffineForOp>(kvp.first).getInductionVar(), indices);
    unsigned countInvariantIndices = 0;
    for (unsigned dim = 0; dim < numIndices; ++dim) {
      if (!invariants.count(indices[dim])) {
        // `indices[dim]` varies with this loop: it drives `vectorDim`.
        assert(perm[vectorDim] == zero &&
               "permutationMap already has an entry along dim");
        perm[vectorDim] = getAffineDimExpr(dim, context);
      } else {
        ++countInvariantIndices;
      }
    }
    // Either no index varies with the loop (broadcast) or exactly one does.
    // When numIndices == 0, the first clause holds, so the unsigned wrap of
    // `numIndices - 1` is harmless.
    assert((countInvariantIndices == numIndices ||
            countInvariantIndices == numIndices - 1) &&
           "Vectorization prerequisite violated: at most 1 index may be "
           "variant wrt a vectorized loop");
    (void)countInvariantIndices;
  }
  return AffineMap::get(numIndices, 0, perm, context);
}
/// Walks the chain of operations enclosing `block`, starting at its direct
/// parent op and moving outwards, and collects every ancestor that is an op
/// of type `T`. The SetVector keeps insertion order, so the closest matching
/// ancestor comes first.
/// TODO: could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
template <typename T>
static SetVector<Operation *> getParentsOfType(Block *block) {
  SetVector<Operation *> matchingAncestors;
  for (Operation *ancestor = block->getParentOp(); ancestor != nullptr;
       ancestor = ancestor->getParentOp()) {
    if (!isa<T>(ancestor))
      continue;
    assert(matchingAncestors.count(ancestor) == 0 && "Already inserted");
    matchingAncestors.insert(ancestor);
  }
  return matchingAncestors;
}
/// Collects the AffineForOp ops that enclose `block`, ordered from the
/// closest one to the farthest one.
static SetVector<Operation *> getEnclosingforOps(Block *block) {
  SetVector<Operation *> enclosingForOps =
      getParentsOfType<AffineForOp>(block);
  return enclosingForOps;
}
/// Restricts `loopToVectorDim` to the AffineForOp loops that actually enclose
/// `insertPoint`. It then builds the permutation map for `indices` from that
/// restricted mapping.
AffineMap mlir::makePermutationMap(
    Block *insertPoint, ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
  DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
  for (Operation *loop : getEnclosingforOps(insertPoint)) {
    auto entry = loopToVectorDim.find(loop);
    if (entry == loopToVectorDim.end())
      continue;
    enclosingLoopToVectorDim.insert(*entry);
  }
  return ::makePermutationMap(indices, enclosingLoopToVectorDim);
}
/// Convenience overload that uses the block containing `op` as the insertion
/// point when building the permutation map.
AffineMap mlir::makePermutationMap(
    Operation *op, ArrayRef<Value> indices,
    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
  Block *containingBlock = op->getBlock();
  return makePermutationMap(containingBlock, indices, loopToVectorDim);
}
/// Returns true if `op` operates on a super-vector of `subVectorType`, i.e. if
/// the relevant vector type of `op` has a shape ratio with respect to
/// `subVectorType` (as computed by `computeShapeRatio`). The relevant type is
/// the transferred vector type for vector transfer ops, and the single result
/// type for other ops.
///
/// Returns false, without a shape check, for:
///   - ops with zero results (an error is emitted unless the op is a
///     func.return);
///   - non-transfer ops with more than one result (an error is emitted);
///   - ops whose single result is not a vector.
bool matcher::operatesOnSuperVectorsOf(Operation &op,
                                       VectorType subVectorType) {
  // First, extract the vector type and distinguish between:
  //   a. ops that *must* lower a super-vector (i.e. vector.transfer_read,
  //      vector.transfer_write); and
  //   b. ops that *may* lower a super-vector (all other ops).
  // The ops that *may* lower a super-vector only do so if the super-vector to
  // sub-vector ratio exists. The ops that *must* lower a super-vector are
  // explicitly checked for this property.
  // TODO: there should be a single function for all ops to do this so we
  // do not have to special case. Maybe a trait, or just a method, unclear atm.
  bool mustDivide = false;
  (void)mustDivide; // Only read by the assert below; silences NDEBUG warnings.
  VectorType superVectorType;
  if (auto transfer = dyn_cast<VectorTransferOpInterface>(op)) {
    superVectorType = transfer.getVectorType();
    mustDivide = true;
  } else if (op.getNumResults() == 0) {
    if (!isa<func::ReturnOp>(op)) {
      op.emitError("NYI: assuming only return operations can have 0 "
                   "results at this point");
    }
    return false;
  } else if (op.getNumResults() == 1) {
    if (auto v = op.getResult(0).getType().dyn_cast<VectorType>()) {
      superVectorType = v;
    } else {
      // Not a vector type.
      return false;
    }
  } else {
    // Not a vector.transfer and has more than 1 result, fail hard for now to
    // wake us up when something changes.
    op.emitError("NYI: operation has more than 1 result");
    return false;
  }

  // Get the ratio.
  auto ratio =
      computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());

  // Sanity check: transfer ops must always be evenly divisible.
  assert((ratio || !mustDivide) &&
         "vector.transfer operation in which super-vector size is not an"
         " integer multiple of sub-vector size");

  // This catches cases that are not strictly necessary to have multiplicity but
  // still aren't divisible by the sub-vector shape.
  // This could be useful information if we wanted to reshape at the level of
  // the vector type (but we would have to look at the compute and distinguish
  // between parallel, reduction and possibly other cases).
  return ratio.has_value();
}