llvm-project/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

205 lines
8.5 KiB
C++

//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// SPIRV Dialect ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
// See SPV_NV_cooperative_matrix for supported element wise ops.
static void createElementWiseOp(ConversionPatternRewriter &builder,
gpu::SubgroupMmaElementwiseOp op,
spirv::CooperativeMatrixNVType coopType,
ValueRange operands) {
switch (op.getOpType()) {
case gpu::MMAElementwiseOp::ADDF:
builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
return;
case gpu::MMAElementwiseOp::ADDI:
builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
return;
case gpu::MMAElementwiseOp::SUBF:
builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
return;
case gpu::MMAElementwiseOp::SUBI:
builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
return;
case gpu::MMAElementwiseOp::DIVF:
builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
return;
case gpu::MMAElementwiseOp::DIVS:
builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
return;
case gpu::MMAElementwiseOp::DIVU:
builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
return;
case gpu::MMAElementwiseOp::NEGATEF:
builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
return;
case gpu::MMAElementwiseOp::NEGATES:
builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
return;
default:
llvm_unreachable("unknown op");
}
}
namespace {
/// This class implements the conversion of GPU MMA loadOp to
/// CooperativeMatrixLoad op in the SPIRV dialect.
struct WmmaLoadOpToSPIRVLowering
: public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = subgroupMmaLoadMatrixOp->getLoc();
gpu::MMAMatrixType retType =
subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
auto memrefType =
subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast<MemRefType>();
Value bufferPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType,
adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
auto coopType = convertMMAToSPIRVType(retType);
int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
auto i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, IntegerAttr::get(i32Type, stride));
bool useColMajor =
static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
auto columnMajor = rewriter.create<spirv::ConstantOp>(
loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
spirv::MemoryAccessAttr());
return success();
}
};
/// This class implements the conversion of GPU MMA StoreOp to
/// CooperativeMatrixStore op in the SPIRV dialect.
struct WmmaStoreOpToSPIRVLowering
: public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = subgroupMmaStoreMatrixOp->getLoc();
auto memrefType =
subgroupMmaStoreMatrixOp.getDstMemref().getType().cast<MemRefType>();
Value bufferPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType,
adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
auto i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, IntegerAttr::get(i32Type, stride));
auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
coloumnMajor, spirv::MemoryAccessAttr());
return success();
}
};
/// This class implements the conversion of GPU MMA Compute to
/// CooperativeMatrixMulAdd op in the SPIRV dialect.
struct WmmaMmaOpToSPIRVLowering
: public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixMulAddOp>(
subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(),
adaptor.getOpB(), adaptor.getOpC());
return success();
}
};
/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
struct WmmaConstantOpToSPIRVLowering
: public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value cst = adaptor.getOperands()[0];
auto coopType = convertMMAToSPIRVType(
subgroupMmaConstantMatrixOp.getType().cast<gpu::MMAMatrixType>());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
subgroupMmaConstantMatrixOp, coopType, cst);
return success();
}
};
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
struct WmmaElementwiseOpToSPIRVLowering
: public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// All operands should be of cooperative matrix types.
for (Value operand : adaptor.getOperands()) {
if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
return failure();
}
auto coopType = convertMMAToSPIRVType(
subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
adaptor.getOperands());
return success();
}
};
} // namespace
/// Returns the SPIR-V cooperative matrix type (SPV_NV_cooperative_matrix)
/// corresponding to the GPU MMA matrix type `type`.
///
/// The result keeps `type`'s element type, uses subgroup scope, and takes its
/// row and column counts from the first two entries of `type`'s shape.
mlir::spirv::CooperativeMatrixNVType
mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
  ArrayRef<int64_t> retTypeShape = type.getShape();
  Type elementType = type.getElementType();
  return spirv::CooperativeMatrixNVType::get(
      elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
}
/// Appends the GPU subgroup-MMA (WMMA) to SPIR-V lowering patterns defined in
/// this file to `patterns`. Each pattern is constructed with `converter`.
void mlir::populateGpuWMMAToSPIRVConversionPatterns(
    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
               WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
               WmmaElementwiseOpToSPIRVLowering>(converter, context);
}