//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that unifies accesses of multiple aliased
// resources into accesses of a single resource.
//
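//
// For example (an illustrative sketch mirroring this pass's tests; variable
// names are hypothetical), given two variables bound to the same descriptor
// (set, binding):
//
//   spirv.GlobalVariable @a bind(0, 0) {aliased}
//     : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, StorageBuffer>
//   spirv.GlobalVariable @b bind(0, 0) {aliased}
//     : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>>)>, StorageBuffer>
//
// accesses through @a are rewritten to go through the canonical resource @b,
// with indices adjusted accordingly, and @a is erased.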
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <algorithm>
#include <iterator>
namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
} // namespace mlir
#define DEBUG_TYPE "spirv-unify-aliased-resource"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
using AliasedResourceMap =
DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
/// Collects all aliased resources in the given SPIR-V `moduleOp`.
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
AliasedResourceMap aliasedResources;
moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
if (varOp->getAttrOfType<UnitAttr>("aliased")) {
Optional<uint32_t> set = varOp.getDescriptorSet();
Optional<uint32_t> binding = varOp.getBinding();
if (set && binding)
aliasedResources[{*set, *binding}].push_back(varOp);
}
});
return aliasedResources;
}
/// Returns the element type if the given `type` is a runtime array resource:
/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
/// otherwise.
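/// For example (illustrative), given
/// `!spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>>)>, StorageBuffer>`
/// this returns `vector<4xf32>`.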
static Type getRuntimeArrayElementType(Type type) {
auto ptrType = type.dyn_cast<spirv::PointerType>();
if (!ptrType)
return {};
auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
if (!structType || structType.getNumElements() != 1)
return {};
auto rtArrayType =
structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
if (!rtArrayType)
return {};
return rtArrayType.getElementType();
}
/// Given a list of resource element `types`, returns the index of the canonical
/// resource that all resources should be unified into. Returns std::nullopt if
/// unable to unify.
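///
/// For example (illustrative), given element types {f32, vector<4xf32>,
/// vector<2xf32>}, the canonical resource is the vector<2xf32> one: it is the
/// smallest vector, its 64-bit size divides vector<4xf32>'s 128 bits, and its
/// 32-bit element type divides all the scalar bitwidths.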
static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
// scalarNumBits: contains all resources' scalar types' bit counts.
// vectorNumBits: only contains resources whose element types are vectors.
// vectorIndices: each vector's original index in `types`.
SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
scalarNumBits.reserve(types.size());
vectorNumBits.reserve(types.size());
vectorIndices.reserve(types.size());
for (const auto &indexedTypes : llvm::enumerate(types)) {
spirv::SPIRVType type = indexedTypes.value();
assert(type.isScalarOrVector());
if (auto vectorType = type.dyn_cast<VectorType>()) {
if (vectorType.getNumElements() % 2 != 0)
return std::nullopt; // Odd-sized vectors have special layout requirements.
Optional<int64_t> numBytes = type.getSizeInBytes();
if (!numBytes)
return std::nullopt;
scalarNumBits.push_back(
vectorType.getElementType().getIntOrFloatBitWidth());
vectorNumBits.push_back(*numBytes * 8);
vectorIndices.push_back(indexedTypes.index());
} else {
scalarNumBits.push_back(type.getIntOrFloatBitWidth());
}
}
if (!vectorNumBits.empty()) {
// Choose the *vector* with the smallest bitwidth as the canonical resource,
// so that we can still keep vectorized load/store and avoid partial updates
// to large vectors.
auto *minVal = std::min_element(vectorNumBits.begin(), vectorNumBits.end());
// Make sure that the canonical resource's bitwidth divides all the others'.
// Without this, we cannot properly adjust the indices later.
if (llvm::any_of(vectorNumBits,
[&](int bits) { return bits % *minVal != 0; }))
return std::nullopt;
// Require all scalar bitwidths to be a multiple of the chosen vector's
// element type bitwidth to avoid reading/writing subcomponents.
int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
int baseNumBits = scalarNumBits[index];
if (llvm::any_of(scalarNumBits,
[&](int bits) { return bits % baseNumBits != 0; }))
return std::nullopt;
return index;
}
// All element types are scalars. Choose the smallest bitwidth as the
// canonical resource to avoid subcomponent load/store.
auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end());
if (llvm::any_of(scalarNumBits,
[minVal](int64_t bit) { return bit % *minVal != 0; }))
return std::nullopt;
return std::distance(scalarNumBits.begin(), minVal);
}
static bool areSameBitwidthScalarType(Type a, Type b) {
return a.isIntOrFloat() && b.isIntOrFloat() &&
a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
}
//===----------------------------------------------------------------------===//
// Analysis
//===----------------------------------------------------------------------===//
namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spirv.GlobalVariable ops that have a
/// descriptor set and binding number. Such resources are of the type
/// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case where the struct contains a single
/// runtime array.
class ResourceAliasAnalysis {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)
explicit ResourceAliasAnalysis(Operation *);
/// Returns true if the given `op` can be rewritten to use a canonical
/// resource.
bool shouldUnify(Operation *op) const;
/// Returns all descriptors and their corresponding aliased resources.
const AliasedResourceMap &getResourceMap() const { return resourceMap; }
/// Returns the canonical resource for the given descriptor/variable.
spirv::GlobalVariableOp
getCanonicalResource(const Descriptor &descriptor) const;
spirv::GlobalVariableOp
getCanonicalResource(spirv::GlobalVariableOp varOp) const;
/// Returns the element type for the given variable.
spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
private:
/// Given the descriptor and aliased resources bound to it, analyze whether we
/// can unify them and record if so.
void recordIfUnifiable(const Descriptor &descriptor,
ArrayRef<spirv::GlobalVariableOp> resources);
/// Mapping from a descriptor to all aliased resources bound to it.
AliasedResourceMap resourceMap;
/// Mapping from a descriptor to the chosen canonical resource.
DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;
/// Mapping from an aliased resource to its descriptor.
DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;
/// Mapping from an aliased resource to its element (scalar/vector) type.
DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
// Collect all aliased resources first and put them into different sets
// according to the descriptor.
AliasedResourceMap aliasedResources =
collectAliasedResources(cast<spirv::ModuleOp>(root));
// For each descriptor, analyze whether we can unify the aliased resources
// bound to it; if so, record the deduced canonical resource.
for (const auto &descriptorResource : aliasedResources) {
recordIfUnifiable(descriptorResource.first, descriptorResource.second);
}
}
bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
auto canonicalOp = getCanonicalResource(varOp);
return canonicalOp && varOp != canonicalOp;
}
if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
auto *varOp =
SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
return shouldUnify(varOp);
}
if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
return shouldUnify(acOp.getBasePtr().getDefiningOp());
if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
return shouldUnify(loadOp.getPtr().getDefiningOp());
if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
return shouldUnify(storeOp.getPtr().getDefiningOp());
return false;
}
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
const Descriptor &descriptor) const {
auto varIt = canonicalResourceMap.find(descriptor);
if (varIt == canonicalResourceMap.end())
return {};
return varIt->second;
}
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
spirv::GlobalVariableOp varOp) const {
auto descriptorIt = descriptorMap.find(varOp);
if (descriptorIt == descriptorMap.end())
return {};
return getCanonicalResource(descriptorIt->second);
}
spirv::SPIRVType
ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
auto it = elementTypeMap.find(varOp);
if (it == elementTypeMap.end())
return {};
return it->second;
}
void ResourceAliasAnalysis::recordIfUnifiable(
const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
// Collect the element types for all resources in the current set.
SmallVector<spirv::SPIRVType> elementTypes;
for (spirv::GlobalVariableOp resource : resources) {
Type elementType = getRuntimeArrayElementType(resource.getType());
if (!elementType)
return; // Unexpected resource variable type.
auto type = elementType.cast<spirv::SPIRVType>();
if (!type.isScalarOrVector())
return; // Unexpected resource element type.
elementTypes.push_back(type);
}
Optional<int> index = deduceCanonicalResource(elementTypes);
if (!index)
return;
// Update internal data structures for later use.
resourceMap[descriptor].assign(resources.begin(), resources.end());
canonicalResourceMap[descriptor] = resources[*index];
for (const auto &resource : llvm::enumerate(resources)) {
descriptorMap[resource.value()] = descriptor;
elementTypeMap[resource.value()] = elementTypes[resource.index()];
}
}
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
ConvertAliasResource(const ResourceAliasAnalysis &analysis,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
protected:
const ResourceAliasAnalysis &analysis;
};
struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
using ConvertAliasResource::ConvertAliasResource;
LogicalResult
matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Just remove the aliased resource. Users will be rewritten to use the
// canonical one.
rewriter.eraseOp(varOp);
return success();
}
};
struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
using ConvertAliasResource::ConvertAliasResource;
LogicalResult
matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Rewrite the AddressOf op to get the address of the canonical resource.
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
auto srcVarOp = cast<spirv::GlobalVariableOp>(
SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
return success();
}
};
struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
using ConvertAliasResource::ConvertAliasResource;
LogicalResult
matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
if (!addressOp)
return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
auto srcVarOp = cast<spirv::GlobalVariableOp>(
SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
if (srcElemType == dstElemType ||
areSameBitwidthScalarType(srcElemType, dstElemType)) {
// The source and destination element types have the same bitwidth, so the
// indices stay the same.
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), adaptor.getIndices());
return success();
}
Location loc = acOp.getLoc();
auto i32Type = rewriter.getI32Type();
if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
// The source indices are for a buffer with scalar element types. Rewrite
// them into a buffer with vector element types. We need to scale the last
// index for the vector as a whole, then add one level of index for inside
// the vector.
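// For example (illustrative): accessing f32 element #6 of a buffer unified
// into vector<4xf32> yields the index pair (6 / 4, 6 % 4) = (1, 2).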
int srcNumBytes = *srcElemType.getSizeInBytes();
int dstNumBytes = *dstElemType.getSizeInBytes();
assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
int ratio = dstNumBytes / srcNumBytes;
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
indices.back() =
rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
indices.push_back(
rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
return success();
}
if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
(srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
// The source indices are for a buffer with larger bitwidth scalar/vector
// element types. Rewrite them into a buffer with smaller bitwidth element
// types. We only need to scale the last index.
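// For example (illustrative): accessing i64 element #5 of a buffer unified
// into i32 elements yields the scaled last index 5 * 2 = 10.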
int srcNumBytes = *srcElemType.getSizeInBytes();
int dstNumBytes = *dstElemType.getSizeInBytes();
assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
int ratio = srcNumBytes / dstNumBytes;
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
indices.back() =
rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
return success();
}
return rewriter.notifyMatchFailure(
acOp, "unsupported src/dst types for spirv.AccessChain");
}
};
struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
using ConvertAliasResource::ConvertAliasResource;
LogicalResult
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
Location loc = loadOp.getLoc();
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
if (srcElemType == dstElemType) {
rewriter.replaceOp(loadOp, newLoadOp->getResults());
return success();
}
if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
newLoadOp.getValue());
rewriter.replaceOp(loadOp, castOp->getResults());
return success();
}
if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
(srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
// The source and destination have scalar types of different bitwidths, or
// vector types of different component counts. For such cases, we load
// multiple smaller bitwidth values and construct a larger bitwidth one.
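// For example (illustrative): loading an i64 from a buffer unified into
// i32 elements becomes two i32 loads at consecutive indices, a
// spirv.CompositeConstruct into vector<2xi32>, and a spirv.Bitcast back to
// i64.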
int srcNumBytes = *srcElemType.getSizeInBytes();
int dstNumBytes = *dstElemType.getSizeInBytes();
assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
int ratio = srcNumBytes / dstNumBytes;
if (ratio > 4)
return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
SmallVector<Value> components;
components.reserve(ratio);
components.push_back(newLoadOp);
auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
if (!acOp)
return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");
auto i32Type = rewriter.getI32Type();
Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
auto indices = llvm::to_vector<4>(acOp.getIndices());
for (int i = 1; i < ratio; ++i) {
// Load all subsequent components belonging to this element.
indices.back() = rewriter.create<spirv::IAddOp>(
loc, i32Type, indices.back(), oneValue);
auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
loc, acOp.getBasePtr(), indices);
// Assuming little endian, this reads lower-ordered bits of the number
// to lower-numbered components of the vector.
components.push_back(
rewriter.create<spirv::LoadOp>(loc, componentAcOp));
}
// Create a vector of the components and then cast back to the larger
// bitwidth element type. For spirv.bitcast, the lower-numbered components
// of the vector map to lower-ordered bits of the larger bitwidth element
// type.
Type vectorType = srcElemType;
if (!srcElemType.isa<VectorType>())
vectorType = VectorType::get({ratio}, dstElemType);
Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
loc, vectorType, components);
if (!srcElemType.isa<VectorType>())
vectorValue =
rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
rewriter.replaceOp(loadOp, vectorValue);
return success();
}
return rewriter.notifyMatchFailure(
loadOp, "unsupported src/dst types for spirv.Load");
}
};
struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
using ConvertAliasResource::ConvertAliasResource;
LogicalResult
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
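// Note: only same-bitwidth scalar stores are handled below. Other cases
// (e.g., storing an i64 into a buffer unified to use i32 elements) would
// require splitting the value into multiple stores and are rejected here.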
auto srcElemType =
storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
auto dstElemType =
adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
if (!areSameBitwidthScalarType(srcElemType, dstElemType))
return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
Location loc = storeOp.getLoc();
Value value = adaptor.getValue();
if (srcElemType != dstElemType)
value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
value, storeOp->getAttrs());
return success();
}
};
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
namespace {
class UnifyAliasedResourcePass final
: public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
UnifyAliasedResourcePass> {
public:
explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
: getTargetEnvFn(std::move(getTargetEnv)) {}
void runOnOperation() override;
private:
spirv::GetTargetEnvFn getTargetEnvFn;
};
} // namespace
void UnifyAliasedResourcePass::runOnOperation() {
spirv::ModuleOp moduleOp = getOperation();
MLIRContext *context = &getContext();
if (getTargetEnvFn) {
// This pass is only needed for targeting WebGPU, Metal, or layering Vulkan
// on Metal via MoltenVK, where we need to translate SPIR-V into WGSL or
// MSL. The translation has limitations.
spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
bool isVulkanOnAppleDevices =
clientAPI == spirv::ClientAPI::Vulkan &&
targetEnv.getVendorID() == spirv::Vendor::Apple;
if (clientAPI != spirv::ClientAPI::WebGPU &&
clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
return;
}
// Analyze aliased resources first.
ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
ConversionTarget target(*context);
target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
spirv::AccessChainOp, spirv::LoadOp,
spirv::StoreOp>(
[&analysis](Operation *op) { return !analysis.shouldUnify(op); });
target.addLegalDialect<spirv::SPIRVDialect>();
// Run patterns to rewrite usages of non-canonical resources.
RewritePatternSet patterns(context);
patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
ConvertLoad, ConvertStore>(analysis, context);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
return signalPassFailure();
// Drop the aliased attribute if only a single resource is bound to a
// descriptor. We need to re-collect the map here given that the above
// conversion is best-effort; certain sets may not be converted.
AliasedResourceMap resourceMap = collectAliasedResources(moduleOp);
for (const auto &dr : resourceMap) {
const auto &resources = dr.second;
if (resources.size() == 1)
resources.front()->removeAttr("aliased");
}
}
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
}