llvm-project/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

439 lines
16 KiB
C++

//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Vector dialect to SPIRV dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
using namespace mlir;
/// Returns the zero-extended value of the leading entry of `attr`.
///
/// `attr` is treated as an array of `IntegerAttr`s. The first element is
/// dereferenced unconditionally, so the array must not be empty.
static uint64_t getFirstIntValue(ArrayAttr attr) {
  auto intValues = attr.getAsValueRange<IntegerAttr>();
  const APInt leading = *intValues.begin();
  return leading.getZExtValue();
}
namespace {
/// Lowers `vector::BitCastOp`.
///
/// When the converted result type equals the type of the already-converted
/// source, the source value is forwarded as-is; otherwise a
/// `spirv::BitcastOp` to the converted result type is emitted.
struct VectorBitcastConvert final
    : public OpConversionPattern<vector::BitCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedResultTy =
        getTypeConverter()->convertType(bitcastOp.getType());
    if (!convertedResultTy)
      return failure();

    Value convertedSrc = adaptor.getSource();

    // Same type on both sides after conversion: the cast is a no-op.
    if (convertedSrc.getType() == convertedResultTy) {
      rewriter.replaceOp(bitcastOp, convertedSrc);
      return success();
    }

    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, convertedResultTy,
                                                  convertedSrc);
    return success();
  }
};
/// Lowers `vector::BroadcastOp`.
///
/// If the result vector type converts to a SPIR-V scalar type, the converted
/// source is forwarded unchanged. Otherwise the converted source is replicated
/// once per result element and assembled with `spirv::CompositeConstructOp`.
///
/// NOTE(review): the replication assumes the source is a scalar; a vector
/// source is not rejected here -- confirm against callers.
struct VectorBroadcastConvert final
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
    if (!resultType)
      return failure();

    // The whole vector collapsed to a scalar during type conversion, so the
    // broadcast is the identity on the converted source.
    if (resultType.isa<spirv::ScalarType>()) {
      rewriter.replaceOp(castOp, adaptor.getSource());
      return success();
    }

    SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
                                 adaptor.getSource());
    // Use the *converted* result type. The constituents come from the adaptor
    // and therefore already have converted types; building the composite with
    // the original vector type would produce a mismatched op whenever the
    // type converter changes the element type.
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
                                                             source);
    return success();
  }
};
/// Lowers `vector::ExtractOp` when the result is a scalar or a vector with at
/// most one element.
///
/// A source that has itself been converted to a SPIR-V scalar is forwarded
/// directly; otherwise the first position index is used to build a
/// `spirv::CompositeExtractOp`.
struct VectorExtractOpConvert final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Extracting a multi-element sub-vector is not handled by this pattern.
    if (auto vecResultTy = extractOp.getType().dyn_cast<VectorType>()) {
      if (vecResultTy.getNumElements() > 1)
        return failure();
    }

    // The result type must be convertible even though only its validity is
    // needed below.
    if (!getTypeConverter()->convertType(extractOp.getType()))
      return failure();

    Value convertedVec = adaptor.getVector();
    if (convertedVec.getType().isa<spirv::ScalarType>()) {
      rewriter.replaceOp(extractOp, convertedVec);
      return success();
    }

    int32_t elementIdx = getFirstIntValue(extractOp.getPosition());
    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
        extractOp, convertedVec, elementIdx);
    return success();
  }
};
/// Lowers `vector::ExtractStridedSliceOp` with a unit stride.
///
/// Only the first offset/size/stride entries are consulted. A result that
/// converts to a SPIR-V scalar becomes a `spirv::CompositeExtractOp`; any other
/// result becomes a `spirv::VectorShuffleOp` of the source with itself that
/// selects the `size` consecutive lanes starting at `offset`.
struct VectorExtractStridedSliceOpConvert final
    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedResultTy =
        getTypeConverter()->convertType(extractOp.getType());
    if (!convertedResultTy)
      return failure();

    // Non-unit strides are not supported.
    if (getFirstIntValue(extractOp.getStrides()) != 1)
      return failure();

    uint64_t offset = getFirstIntValue(extractOp.getOffsets());
    uint64_t size = getFirstIntValue(extractOp.getSizes());
    Value source = adaptor.getOperands()[0];

    // Single-lane result: a plain element extraction is enough.
    if (convertedResultTy.isa<spirv::ScalarType>()) {
      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, source,
                                                             offset);
      return success();
    }

    // Lanes offset, offset + 1, ..., offset + size - 1.
    SmallVector<int32_t, 2> lanes;
    lanes.reserve(size);
    for (uint64_t i = 0; i < size; ++i)
      lanes.push_back(static_cast<int32_t>(offset + i));

    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
        extractOp, convertedResultTy, source, source,
        rewriter.getI32ArrayAttr(lanes));
    return success();
  }
};
/// Lowers `vector::FMAOp` to the fused multiply-add op given by `SPIRVFMAOp`,
/// using the converted result type and the converted lhs/rhs/acc operands.
template <class SPIRVFMAOp>
struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (Type convertedTy = getTypeConverter()->convertType(fmaOp.getType())) {
      rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, convertedTy,
                                              adaptor.getLhs(),
                                              adaptor.getRhs(),
                                              adaptor.getAcc());
      return success();
    }
    // The result type has no SPIR-V equivalent.
    return failure();
  }
};
/// Lowers `vector::InsertOp` of a non-vector source.
///
/// Inserting an int/float scalar into a one-element vector simply forwards the
/// converted source. Otherwise, provided the destination is a valid SPIR-V
/// composite type, the first position index is used to build a
/// `spirv::CompositeInsertOp`. Vector sources are rejected.
struct VectorInsertOpConvert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcTy = insertOp.getSourceType();
    VectorType destVecTy = insertOp.getDestVectorType();

    // Scalar into a size-1 vector: the inserted value is the whole result.
    if (srcTy.isIntOrFloat() && destVecTy.getNumElements() == 1) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    const bool isSupported =
        !srcTy.isa<VectorType>() && spirv::CompositeType::isValid(destVecTy);
    if (!isSupported)
      return failure();

    int32_t elementIdx = getFirstIntValue(insertOp.getPosition());
    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
        insertOp, adaptor.getSource(), adaptor.getDest(), elementIdx);
    return success();
  }
};
/// Lowers `vector::ExtractElementOp`.
///
/// A source already converted to a SPIR-V scalar is forwarded directly. With a
/// constant position the op becomes a `spirv::CompositeExtractOp`; with a
/// runtime position it becomes a `spirv::VectorExtractDynamicOp`.
struct VectorExtractElementOpConvert final
    : public OpConversionPattern<vector::ExtractElementOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedResultTy =
        getTypeConverter()->convertType(extractOp.getType());
    if (!convertedResultTy)
      return failure();

    Value convertedVec = adaptor.getVector();
    if (convertedVec.getType().isa<spirv::ScalarType>()) {
      rewriter.replaceOp(extractOp, convertedVec);
      return success();
    }

    APInt constIndex;
    if (!matchPattern(adaptor.getPosition(), m_ConstantInt(&constIndex))) {
      // Index only known at runtime.
      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
          extractOp, convertedResultTy, convertedVec, adaptor.getPosition());
      return success();
    }

    // Statically known index.
    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
        extractOp, convertedResultTy, convertedVec,
        rewriter.getI32ArrayAttr(
            {static_cast<int>(constIndex.getSExtValue())}));
    return success();
  }
};
/// Lowers `vector::InsertElementOp`.
///
/// If the result vector type converts to a SPIR-V scalar type, the converted
/// source replaces the op. With a constant position the op becomes a
/// `spirv::CompositeInsertOp`; with a runtime position it becomes a
/// `spirv::VectorInsertDynamicOp`.
struct VectorInsertElementOpConvert final
    : public OpConversionPattern<vector::InsertElementOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type vectorType = getTypeConverter()->convertType(insertOp.getType());
    if (!vectorType)
      return failure();

    // The destination collapsed to a scalar, so the inserted value is the
    // entire result.
    if (vectorType.isa<spirv::ScalarType>()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    APInt cstPos;
    if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
          insertOp, adaptor.getSource(), adaptor.getDest(),
          cstPos.getSExtValue());
    else
      // Use the converted destination from the adaptor, as the constant-index
      // branch does. The original operand (`insertOp.getDest()`) may still
      // carry an unconverted type, which would not match `vectorType`.
      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
          insertOp, vectorType, adaptor.getDest(), adaptor.getSource(),
          adaptor.getPosition());
    return success();
  }
};
/// Lowers `vector::InsertStridedSliceOp` with a unit stride.
///
/// Only the first offset/stride entries are consulted. A source that has been
/// converted to a SPIR-V scalar is inserted with `spirv::CompositeInsertOp`.
/// Otherwise a `spirv::VectorShuffleOp` over (dest, source) is emitted whose
/// mask keeps the destination lanes except for the inserted window, which
/// selects lanes from the source.
struct VectorInsertStridedSliceOpConvert final
    : public OpConversionPattern<vector::InsertStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value source = adaptor.getOperands().front();
    Value dest = adaptor.getOperands().back();

    // Non-unit strides are not supported.
    if (getFirstIntValue(insertOp.getStrides()) != 1)
      return failure();

    uint64_t offset = getFirstIntValue(insertOp.getOffsets());
    Type destTy = dest.getType();

    if (source.getType().isa<spirv::ScalarType>()) {
      assert(!dest.getType().isa<spirv::ScalarType>());
      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
          insertOp, destTy, source, dest, rewriter.getI32ArrayAttr(offset));
      return success();
    }

    uint64_t destNumElements = destTy.cast<VectorType>().getNumElements();
    uint64_t srcNumElements =
        source.getType().cast<VectorType>().getNumElements();

    // Start with the identity mask over the destination lanes ...
    SmallVector<int32_t, 2> mask(destNumElements);
    std::iota(mask.begin(), mask.end(), 0);
    // ... then redirect the window [offset, offset + srcNumElements) to the
    // second shuffle operand, whose lanes are numbered from destNumElements.
    auto windowBegin = mask.begin() + offset;
    auto windowEnd = windowBegin + srcNumElements;
    std::iota(windowBegin, windowEnd, destNumElements);

    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
        insertOp, destTy, dest, source, rewriter.getI32ArrayAttr(mask));
    return success();
  }
};
/// Lowers `vector::ReductionOp` on a 1-D vector by extracting every element
/// and combining the extracted values one after another.
///
/// The six template parameters supply the ops used for the float, unsigned and
/// signed max/min combining kinds. ADD and MUL pick an integer or float op
/// based on the converted result type. AND, OR and XOR are reported as
/// unimplemented when a combine step is required.
template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
struct VectorReductionPattern final
: public OpConversionPattern<vector::ReductionOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// The converted result type is also the type of every intermediate combine
// op created below.
Type resultType = typeConverter->convertType(reduceOp.getType());
if (!resultType)
return failure();
// The converted source must still be a rank-1 vector.
auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
if (!srcVectorType || srcVectorType.getRank() != 1)
return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
// Extract all elements of the source vector, in index order.
int numElements = srcVectorType.getDimSize(0);
SmallVector<Value, 4> values;
// Reserve one extra slot when an accumulator operand is present.
values.reserve(numElements + (adaptor.getAcc() != nullptr));
Location loc = reduceOp.getLoc();
for (int i = 0; i < numElements; ++i) {
values.push_back(rewriter.create<spirv::CompositeExtractOp>(
loc, srcVectorType.getElementType(), adaptor.getVector(),
rewriter.getI32ArrayAttr({i})));
}
// The optional accumulator is appended last, so it is combined last.
if (Value acc = adaptor.getAcc())
values.push_back(acc);
// Reduce them left to right, starting from the first extracted element. If
// only one value was collected the loop body never runs and that value is the
// result.
Value result = values.front();
for (Value next : llvm::makeArrayRef(values).drop_front()) {
switch (reduceOp.getKind()) {
// Case for a combining kind that has separate integer (`iop`) and float
// (`fop`) ops in the spirv namespace; the choice is made on `resultType`.
#define INT_AND_FLOAT_CASE(kind, iop, fop) \
case vector::CombiningKind::kind: \
if (resultType.isa<IntegerType>()) { \
result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
} else { \
assert(resultType.isa<FloatType>()); \
result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
} \
break
// Case for a combining kind that maps to exactly one op type (`fop`), here
// always one of the template parameters.
#define INT_OR_FLOAT_CASE(kind, fop) \
case vector::CombiningKind::kind: \
result = rewriter.create<fop>(loc, resultType, result, next); \
break
INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
// Bitwise combining kinds have no lowering here.
case vector::CombiningKind::AND:
case vector::CombiningKind::OR:
case vector::CombiningKind::XOR:
return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
}
}
rewriter.replaceOp(reduceOp, result);
return success();
}
};
/// Lowers `vector::SplatOp`.
///
/// If the result type converts to a SPIR-V scalar type the converted input is
/// forwarded; otherwise the input is repeated once per element of the
/// converted vector type inside a `spirv::CompositeConstructOp`.
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
public:
  using OpConversionPattern<vector::SplatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedTy = getTypeConverter()->convertType(op.getType());
    if (!convertedTy)
      return failure();

    Value scalar = adaptor.getInput();
    if (convertedTy.isa<spirv::ScalarType>()) {
      rewriter.replaceOp(op, scalar);
      return success();
    }

    // One copy of the input per lane of the converted vector.
    SmallVector<Value, 4> constituents(
        convertedTy.cast<VectorType>().getNumElements(), scalar);
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, convertedTy,
                                                             constituents);
    return success();
  }
};
/// Lowers `vector::ShuffleOp` whose result is a valid SPIR-V composite type.
///
/// * First source with more than one element: emitted as a
///   `spirv::VectorShuffleOp` on the converted operands, with the mask copied
///   into an i32 array attribute.
/// * First source with at most one element: every mask entry is used as an
///   index into {v1, v2} and the selected converted operands are assembled
///   with `spirv::CompositeConstructOp`.
///
/// NOTE(review): the second form assumes every mask entry is 0 or 1, i.e. that
/// both sources have one element -- confirm against the op verifier.
struct VectorShuffleOpConvert final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto oldResultType = shuffleOp.getVectorType();
    if (!spirv::CompositeType::isValid(oldResultType))
      return failure();

    // A failed type conversion yields a null Type; do not build ops with it.
    Type newResultType = getTypeConverter()->convertType(oldResultType);
    if (!newResultType)
      return failure();

    auto oldSourceType = shuffleOp.getV1VectorType();
    if (oldSourceType.getNumElements() > 1) {
      SmallVector<int32_t, 4> components = llvm::to_vector<4>(
          llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
            return attr.cast<IntegerAttr>().getValue().getZExtValue();
          }));
      rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
          shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
          rewriter.getI32ArrayAttr(components));
      return success();
    }

    // Single-element sources: the mask selects whole operands.
    SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(oldResultType.getNumElements());
    for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
      newOperands.push_back(oldOperands[i.getZExtValue()]);
    }
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
        shuffleOp, newResultType, newOperands);
    return success();
  }
};
} // namespace
// Template argument list for VectorReductionPattern using the CL-prefixed
// spirv ops, in the order FMax, FMin, UMax, UMin, SMax, SMin.
#define CL_MAX_MIN_OPS \
spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
spirv::CLSMaxOp, spirv::CLSMinOp
// Template argument list for VectorReductionPattern using the GL-prefixed
// spirv ops, in the same order as CL_MAX_MIN_OPS.
#define GL_MAX_MIN_OPS \
spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
spirv::GLSMaxOp, spirv::GLSMinOp
/// Appends every Vector-to-SPIR-V conversion pattern defined in this file to
/// `patterns`, each constructed with `typeConverter` and the pattern set's
/// context. The FMA and reduction patterns are registered once for the GL ops
/// and once for the CL ops.
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  // The groups below keep the patterns in their established registration
  // order.
  patterns.add<VectorBitcastConvert, VectorBroadcastConvert>(typeConverter,
                                                             context);
  patterns.add<VectorExtractElementOpConvert, VectorExtractOpConvert,
               VectorExtractStridedSliceOpConvert>(typeConverter, context);
  patterns.add<VectorFmaOpConvert<spirv::GLFmaOp>,
               VectorFmaOpConvert<spirv::CLFmaOp>>(typeConverter, context);
  patterns.add<VectorInsertElementOpConvert, VectorInsertOpConvert>(
      typeConverter, context);
  patterns.add<VectorReductionPattern<GL_MAX_MIN_OPS>,
               VectorReductionPattern<CL_MAX_MIN_OPS>>(typeConverter, context);
  patterns.add<VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
               VectorSplatPattern>(typeConverter, context);
}