//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that unifies access of multiple aliased resources
// into access of one single resource.
//
//===----------------------------------------------------------------------===//
|
|
|
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
|
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/SymbolTable.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include <algorithm>
|
|
#include <iterator>
|
|
|
|
namespace mlir {
namespace spirv {
// Pulls in the auto-generated pass base class declaration for this pass.
#define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
} // namespace mlir

#define DEBUG_TYPE "spirv-unify-aliased-resource"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

// A (descriptor set #, binding #) pair uniquely identifying a resource slot.
using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
// Maps each descriptor to all aliased global variables bound to it.
using AliasedResourceMap =
    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
|
|
|
|
/// Collects all aliased resources in the given SPIR-V `moduleOp`.
|
|
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
|
|
AliasedResourceMap aliasedResources;
|
|
moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
|
|
if (varOp->getAttrOfType<UnitAttr>("aliased")) {
|
|
Optional<uint32_t> set = varOp.getDescriptorSet();
|
|
Optional<uint32_t> binding = varOp.getBinding();
|
|
if (set && binding)
|
|
aliasedResources[{*set, *binding}].push_back(varOp);
|
|
}
|
|
});
|
|
return aliasedResources;
|
|
}
|
|
|
|
/// Returns the element type if the given `type` is a runtime array resource:
|
|
/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
|
|
/// otherwise.
|
|
static Type getRuntimeArrayElementType(Type type) {
|
|
auto ptrType = type.dyn_cast<spirv::PointerType>();
|
|
if (!ptrType)
|
|
return {};
|
|
|
|
auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
|
|
if (!structType || structType.getNumElements() != 1)
|
|
return {};
|
|
|
|
auto rtArrayType =
|
|
structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
|
|
if (!rtArrayType)
|
|
return {};
|
|
|
|
return rtArrayType.getElementType();
|
|
}
|
|
|
|
/// Given a list of resource element `types`, returns the index of the canonical
|
|
/// resource that all resources should be unified into. Returns std::nullopt if
|
|
/// unable to unify.
|
|
static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
|
|
// scalarNumBits: contains all resources' scalar types' bit counts.
|
|
// vectorNumBits: only contains resources whose element types are vectors.
|
|
// vectorIndices: each vector's original index in `types`.
|
|
SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
|
|
scalarNumBits.reserve(types.size());
|
|
vectorNumBits.reserve(types.size());
|
|
vectorIndices.reserve(types.size());
|
|
|
|
for (const auto &indexedTypes : llvm::enumerate(types)) {
|
|
spirv::SPIRVType type = indexedTypes.value();
|
|
assert(type.isScalarOrVector());
|
|
if (auto vectorType = type.dyn_cast<VectorType>()) {
|
|
if (vectorType.getNumElements() % 2 != 0)
|
|
return std::nullopt; // Odd-sized vector has special layout
|
|
// requirements.
|
|
|
|
Optional<int64_t> numBytes = type.getSizeInBytes();
|
|
if (!numBytes)
|
|
return std::nullopt;
|
|
|
|
scalarNumBits.push_back(
|
|
vectorType.getElementType().getIntOrFloatBitWidth());
|
|
vectorNumBits.push_back(*numBytes * 8);
|
|
vectorIndices.push_back(indexedTypes.index());
|
|
} else {
|
|
scalarNumBits.push_back(type.getIntOrFloatBitWidth());
|
|
}
|
|
}
|
|
|
|
if (!vectorNumBits.empty()) {
|
|
// Choose the *vector* with the smallest bitwidth as the canonical resource,
|
|
// so that we can still keep vectorized load/store and avoid partial updates
|
|
// to large vectors.
|
|
auto *minVal = std::min_element(vectorNumBits.begin(), vectorNumBits.end());
|
|
// Make sure that the canonical resource's bitwidth is divisible by others.
|
|
// With out this, we cannot properly adjust the index later.
|
|
if (llvm::any_of(vectorNumBits,
|
|
[&](int bits) { return bits % *minVal != 0; }))
|
|
return std::nullopt;
|
|
|
|
// Require all scalar type bit counts to be a multiple of the chosen
|
|
// vector's primitive type to avoid reading/writing subcomponents.
|
|
int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
|
|
int baseNumBits = scalarNumBits[index];
|
|
if (llvm::any_of(scalarNumBits,
|
|
[&](int bits) { return bits % baseNumBits != 0; }))
|
|
return std::nullopt;
|
|
|
|
return index;
|
|
}
|
|
|
|
// All element types are scalars. Then choose the smallest bitwidth as the
|
|
// cannonical resource to avoid subcomponent load/store.
|
|
auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end());
|
|
if (llvm::any_of(scalarNumBits,
|
|
[minVal](int64_t bit) { return bit % *minVal != 0; }))
|
|
return std::nullopt;
|
|
return std::distance(scalarNumBits.begin(), minVal);
|
|
}
|
|
|
|
/// Returns true when both `a` and `b` are scalar int/float types carrying the
/// same number of bits.
static bool areSameBitwidthScalarType(Type a, Type b) {
  if (!a.isIntOrFloat() || !b.isIntOrFloat())
    return false;
  return a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Analysis
//===----------------------------------------------------------------------===//

namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spirv.GlobalVariable ops that have a
/// descriptor set and binding number. Such resources are of the type
/// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case that there is a single runtime array
/// inside the struct.
class ResourceAliasAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)

  /// Builds the analysis for the given operation, which is expected to be a
  /// spirv.module (the constructor casts it).
  explicit ResourceAliasAnalysis(Operation *);

  /// Returns true if the given `op` can be rewritten to use a canonical
  /// resource.
  bool shouldUnify(Operation *op) const;

  /// Returns all descriptors and their corresponding aliased resources.
  const AliasedResourceMap &getResourceMap() const { return resourceMap; }

  /// Returns the canonical resource for the given descriptor/variable.
  /// Returns a null op when the descriptor/variable was not recorded as
  /// unifiable.
  spirv::GlobalVariableOp
  getCanonicalResource(const Descriptor &descriptor) const;
  spirv::GlobalVariableOp
  getCanonicalResource(spirv::GlobalVariableOp varOp) const;

  /// Returns the element type for the given variable, or a null type if the
  /// variable was not recorded.
  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;

private:
  /// Given the descriptor and aliased resources bound to it, analyze whether we
  /// can unify them and record if so.
  void recordIfUnifiable(const Descriptor &descriptor,
                         ArrayRef<spirv::GlobalVariableOp> resources);

  /// Mapping from a descriptor to all aliased resources bound to it.
  AliasedResourceMap resourceMap;

  /// Mapping from a descriptor to the chosen canonical resource.
  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;

  /// Mapping from an aliased resource to its descriptor.
  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;

  /// Mapping from an aliased resource to its element (scalar/vector) type.
  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace
|
|
|
|
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
  // Bucket all aliased resources by their (set, binding) descriptor first.
  AliasedResourceMap aliasedResources =
      collectAliasedResources(cast<spirv::ModuleOp>(root));

  // Then inspect each bucket: when its resources can be unified, a canonical
  // resource is chosen and the internal maps are populated.
  for (const auto &[descriptor, resources] : aliasedResources)
    recordIfUnifiable(descriptor, resources);
}
|
|
|
|
bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
  // A global variable should be unified when a *different* canonical
  // resource was recorded for it.
  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) {
    spirv::GlobalVariableOp canonical = getCanonicalResource(globalVar);
    return canonical && globalVar != canonical;
  }

  // An AddressOf op follows the decision of the variable it references.
  if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
    auto enclosingModule = addressOp->getParentOfType<spirv::ModuleOp>();
    Operation *referenced =
        SymbolTable::lookupSymbolIn(enclosingModule, addressOp.getVariable());
    return shouldUnify(referenced);
  }

  // Access chains, loads, and stores follow the op defining their pointer.
  if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
    return shouldUnify(acOp.getBasePtr().getDefiningOp());
  if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
    return shouldUnify(loadOp.getPtr().getDefiningOp());
  if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
    return shouldUnify(storeOp.getPtr().getDefiningOp());

  return false;
}
|
|
|
|
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    const Descriptor &descriptor) const {
  // DenseMap::lookup returns a default-constructed (null) op when the
  // descriptor was not recorded, matching the documented contract.
  return canonicalResourceMap.lookup(descriptor);
}
|
|
|
|
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    spirv::GlobalVariableOp varOp) const {
  // Resolve the variable to its descriptor first; an unrecorded variable has
  // no canonical resource.
  auto it = descriptorMap.find(varOp);
  if (it == descriptorMap.end())
    return {};
  return getCanonicalResource(it->second);
}
|
|
|
|
spirv::SPIRVType
ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
  // DenseMap::lookup returns a default-constructed (null) type when the
  // variable was not recorded, matching the documented contract.
  return elementTypeMap.lookup(varOp);
}
|
|
|
|
void ResourceAliasAnalysis::recordIfUnifiable(
|
|
const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
|
|
// Collect the element types for all resources in the current set.
|
|
SmallVector<spirv::SPIRVType> elementTypes;
|
|
for (spirv::GlobalVariableOp resource : resources) {
|
|
Type elementType = getRuntimeArrayElementType(resource.getType());
|
|
if (!elementType)
|
|
return; // Unexpected resource variable type.
|
|
|
|
auto type = elementType.cast<spirv::SPIRVType>();
|
|
if (!type.isScalarOrVector())
|
|
return; // Unexpected resource element type.
|
|
|
|
elementTypes.push_back(type);
|
|
}
|
|
|
|
Optional<int> index = deduceCanonicalResource(elementTypes);
|
|
if (!index)
|
|
return;
|
|
|
|
// Update internal data structures for later use.
|
|
resourceMap[descriptor].assign(resources.begin(), resources.end());
|
|
canonicalResourceMap[descriptor] = resources[*index];
|
|
for (const auto &resource : llvm::enumerate(resources)) {
|
|
descriptorMap[resource.value()] = descriptor;
|
|
elementTypeMap[resource.value()] = elementTypes[resource.index()];
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

/// Base class for all conversion patterns in this pass: stores a reference to
/// the resource alias analysis so derived patterns can query canonical
/// resources and element types.
template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
                       MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}

protected:
  /// The shared analysis; held by reference, so it must outlive the pattern.
  const ResourceAliasAnalysis &analysis;
};
|
|
|
|
struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The aliased variable is simply dropped; its users are redirected to the
    // canonical resource by the other patterns.
    rewriter.eraseOp(varOp);
    return success();
  }
};
|
|
|
|
struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Resolve the referenced (aliased) variable, then rebuild the AddressOf
    // op so it takes the address of the canonical resource instead.
    auto enclosingModule = addressOp->getParentOfType<spirv::ModuleOp>();
    auto aliasedVar = cast<spirv::GlobalVariableOp>(
        SymbolTable::lookupSymbolIn(enclosingModule, addressOp.getVariable()));
    spirv::GlobalVariableOp canonicalVar =
        analysis.getCanonicalResource(aliasedVar);
    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, canonicalVar);
    return success();
  }
};
|
|
|
|
/// Rewrites spirv.AccessChain ops so their indices address the canonical
/// resource's element type instead of the aliased resource's element type.
struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
    if (!addressOp)
      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");

    // Look up the source (aliased) variable and its canonical counterpart to
    // compare their element types.
    auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
    auto srcVarOp = cast<spirv::GlobalVariableOp>(
        SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);

    spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
    spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);

    if (srcElemType == dstElemType ||
        areSameBitwidthScalarType(srcElemType, dstElemType)) {
      // We have the same bitwidth for source and destination element types.
      // The indices stay the same.
      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), adaptor.getIndices());
      return success();
    }

    Location loc = acOp.getLoc();
    auto i32Type = rewriter.getI32Type();

    if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
      // The source indices are for a buffer with scalar element types. Rewrite
      // them into a buffer with vector element types. We need to scale the last
      // index for the vector as a whole, then add one level of index for inside
      // the vector.
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
      int ratio = dstNumBytes / srcNumBytes;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, i32Type, rewriter.getI32IntegerAttr(ratio));

      // last index -> (last index / ratio, last index % ratio).
      auto indices = llvm::to_vector<4>(acOp.getIndices());
      Value oldIndex = indices.back();
      indices.back() =
          rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
      indices.push_back(
          rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), indices);
      return success();
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
        (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
      // The source indices are for a buffer with larger bitwidth scalar/vector
      // element types. Rewrite them into a buffer with smaller bitwidth element
      // types. We only need to scale the last index.
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
      int ratio = srcNumBytes / dstNumBytes;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, i32Type, rewriter.getI32IntegerAttr(ratio));

      // last index -> last index * ratio.
      auto indices = llvm::to_vector<4>(acOp.getIndices());
      Value oldIndex = indices.back();
      indices.back() =
          rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), indices);
      return success();
    }

    return rewriter.notifyMatchFailure(
        acOp, "unsupported src/dst types for spirv.AccessChain");
  }
};
|
|
|
|
/// Rewrites spirv.Load ops to load from the canonical resource, bitcasting
/// or assembling multiple component loads when the element types differ.
struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Element types seen through the original pointer vs. the converted
    // (canonical) pointer.
    auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
    auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
    auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
    auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();

    Location loc = loadOp.getLoc();
    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
    if (srcElemType == dstElemType) {
      // Same element type: the new load's result is used directly.
      rewriter.replaceOp(loadOp, newLoadOp->getResults());
      return success();
    }

    if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
      // Same bitwidth but different scalar type (e.g. i32 vs f32): one
      // bitcast recovers the expected type.
      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
                                                      newLoadOp.getValue());
      rewriter.replaceOp(loadOp, castOp->getResults());

      return success();
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
        (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
      // The source and destination have scalar types of different bitwidths, or
      // vector types of different component counts. For such cases, we load
      // multiple smaller bitwidth values and construct a larger bitwidth one.

      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
      int ratio = srcNumBytes / dstNumBytes;
      // spirv.CompositeConstruct below builds a vector; vectors beyond 4
      // components are not handled here.
      if (ratio > 4)
        return rewriter.notifyMatchFailure(loadOp, "more than 4 components");

      SmallVector<Value> components;
      components.reserve(ratio);
      components.push_back(newLoadOp);

      // The remaining components are loaded through fresh access chains that
      // step the last index forward one element at a time.
      auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
      if (!acOp)
        return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");

      auto i32Type = rewriter.getI32Type();
      Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
      auto indices = llvm::to_vector<4>(acOp.getIndices());
      for (int i = 1; i < ratio; ++i) {
        // Load all subsequent components belonging to this element.
        indices.back() = rewriter.create<spirv::IAddOp>(
            loc, i32Type, indices.back(), oneValue);
        auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
            loc, acOp.getBasePtr(), indices);
        // Assuming little endian, this reads lower-ordered bits of the number
        // to lower-numbered components of the vector.
        components.push_back(
            rewriter.create<spirv::LoadOp>(loc, componentAcOp));
      }

      // Create a vector of the components and then cast back to the larger
      // bitwidth element type. For spirv.bitcast, the lower-numbered components
      // of the vector map to lower-ordered bits of the larger bitwidth element
      // type.
      Type vectorType = srcElemType;
      if (!srcElemType.isa<VectorType>())
        vectorType = VectorType::get({ratio}, dstElemType);
      Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
          loc, vectorType, components);
      if (!srcElemType.isa<VectorType>())
        vectorValue =
            rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
      rewriter.replaceOp(loadOp, vectorValue);
      return success();
    }

    return rewriter.notifyMatchFailure(
        loadOp, "unsupported src/dst types for spirv.Load");
  }
};
|
|
|
|
/// Rewrites spirv.Store ops to store into the canonical resource. Only
/// same-bitwidth scalar element types are supported for stores.
struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Element type behind the original (aliased) pointer.
    Type srcElemType =
        storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
    // Element type behind the converted (canonical) pointer.
    Type dstElemType =
        adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();

    if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
      return rewriter.notifyMatchFailure(storeOp, "not scalar type");
    if (!areSameBitwidthScalarType(srcElemType, dstElemType))
      return rewriter.notifyMatchFailure(storeOp, "different bitwidth");

    Location loc = storeOp.getLoc();
    Value storedValue = adaptor.getValue();
    // Differing scalar types of equal width need a bitcast before the store.
    if (srcElemType != dstElemType)
      storedValue =
          rewriter.create<spirv::BitcastOp>(loc, dstElemType, storedValue);
    // Rebuild the store against the canonical pointer, keeping the original
    // attributes (e.g. memory access flags).
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
                                                storedValue,
                                                storeOp->getAttrs());
    return success();
  }
};
|
|
|
|
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//

namespace {
/// Pass that replaces accesses to aliased resources with accesses to one
/// chosen canonical resource per (set, binding) descriptor.
class UnifyAliasedResourcePass final
    : public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
          UnifyAliasedResourcePass> {
public:
  explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
      : getTargetEnvFn(std::move(getTargetEnv)) {}

  void runOnOperation() override;

private:
  /// Optional callback to query the target environment. When set, the pass
  /// only acts for WebGPU/Metal-oriented targets (see runOnOperation); when
  /// null, it always acts.
  spirv::GetTargetEnvFn getTargetEnvFn;
};
} // namespace
|
|
|
|
void UnifyAliasedResourcePass::runOnOperation() {
|
|
spirv::ModuleOp moduleOp = getOperation();
|
|
MLIRContext *context = &getContext();
|
|
|
|
if (getTargetEnvFn) {
|
|
// This pass is only needed for targeting WebGPU, Metal, or layering Vulkan
|
|
// on Metal via MoltenVK, where we need to translate SPIR-V into WGSL or
|
|
// MSL. The translation has limitations.
|
|
spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
|
|
spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
|
|
bool isVulkanOnAppleDevices =
|
|
clientAPI == spirv::ClientAPI::Vulkan &&
|
|
targetEnv.getVendorID() == spirv::Vendor::Apple;
|
|
if (clientAPI != spirv::ClientAPI::WebGPU &&
|
|
clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
|
|
return;
|
|
}
|
|
|
|
// Analyze aliased resources first.
|
|
ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
|
|
|
|
ConversionTarget target(*context);
|
|
target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
|
|
spirv::AccessChainOp, spirv::LoadOp,
|
|
spirv::StoreOp>(
|
|
[&analysis](Operation *op) { return !analysis.shouldUnify(op); });
|
|
target.addLegalDialect<spirv::SPIRVDialect>();
|
|
|
|
// Run patterns to rewrite usages of non-canonical resources.
|
|
RewritePatternSet patterns(context);
|
|
patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
|
|
ConvertLoad, ConvertStore>(analysis, context);
|
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
|
|
return signalPassFailure();
|
|
|
|
// Drop aliased attribute if we only have one single bound resource for a
|
|
// descriptor. We need to re-collect the map here given in the above the
|
|
// conversion is best effort; certain sets may not be converted.
|
|
AliasedResourceMap resourceMap =
|
|
collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
|
|
for (const auto &dr : resourceMap) {
|
|
const auto &resources = dr.second;
|
|
if (resources.size() == 1)
|
|
resources.front()->removeAttr("aliased");
|
|
}
|
|
}
|
|
|
|
/// Factory for the pass. `getTargetEnv` may be null, in which case the pass
/// acts unconditionally (see UnifyAliasedResourcePass::runOnOperation).
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
  return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
}
|