164 lines
6.8 KiB
C++
164 lines
6.8 KiB
C++
//===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
|
|
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
|
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
// TODO: Fix the LLVM utilities for looking up functions to take Operation*
|
|
// with SymbolTable trait instead of ModuleOp and make similar change here. This
|
|
// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
|
|
// of getParentOfType<ModuleOp> to pass down the operation.
|
|
LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter,
|
|
ModuleOp module, Type indexType) {
|
|
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
|
|
|
|
if (useGenericFn)
|
|
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
|
|
|
|
return LLVM::lookupOrCreateMallocFn(module, indexType);
|
|
}
|
|
|
|
LLVM::LLVMFuncOp getAlignedAllocFn(LLVMTypeConverter *typeConverter,
|
|
ModuleOp module, Type indexType) {
|
|
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
|
|
|
|
if (useGenericFn)
|
|
return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
|
|
|
|
return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
|
|
}
|
|
|
|
} // end namespace
|
|
|
|
Value AllocationOpLLVMLowering::createAligned(
|
|
ConversionPatternRewriter &rewriter, Location loc, Value input,
|
|
Value alignment) {
|
|
Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
|
|
Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
|
|
Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
|
|
Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
|
|
return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
|
|
}
|
|
|
|
std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
|
|
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
|
|
Operation *op, Value alignment) const {
|
|
if (alignment) {
|
|
// Adjust the allocation size to consider alignment.
|
|
sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
|
|
}
|
|
|
|
MemRefType memRefType = getMemRefResultType(op);
|
|
// Allocate the underlying buffer.
|
|
Type elementPtrType = this->getElementPtrType(memRefType);
|
|
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
|
|
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
|
|
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
|
|
Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
|
|
results.getResult());
|
|
|
|
Value alignedPtr = allocatedPtr;
|
|
if (alignment) {
|
|
// Compute the aligned pointer.
|
|
Value allocatedInt =
|
|
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
|
|
Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
|
|
alignedPtr =
|
|
rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
|
|
}
|
|
|
|
return std::make_tuple(allocatedPtr, alignedPtr);
|
|
}
|
|
|
|
unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
|
|
MemRefType memRefType, Operation *op,
|
|
const DataLayout *defaultLayout) const {
|
|
const DataLayout *layout = defaultLayout;
|
|
if (const DataLayoutAnalysis *analysis =
|
|
getTypeConverter()->getDataLayoutAnalysis()) {
|
|
layout = &analysis->getAbove(op);
|
|
}
|
|
Type elementType = memRefType.getElementType();
|
|
if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
|
|
return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
|
|
*layout);
|
|
if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
|
|
return getTypeConverter()->getUnrankedMemRefDescriptorSize(
|
|
memRefElementType, *layout);
|
|
return layout->getTypeSize(elementType);
|
|
}
|
|
|
|
bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
|
|
MemRefType type, uint64_t factor, Operation *op,
|
|
const DataLayout *defaultLayout) const {
|
|
uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
|
|
for (unsigned i = 0, e = type.getRank(); i < e; i++) {
|
|
if (ShapedType::isDynamic(type.getDimSize(i)))
|
|
continue;
|
|
sizeDivisor = sizeDivisor * type.getDimSize(i);
|
|
}
|
|
return sizeDivisor % factor == 0;
|
|
}
|
|
|
|
Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
|
|
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
|
|
Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
|
|
Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
|
|
|
|
MemRefType memRefType = getMemRefResultType(op);
|
|
// Function aligned_alloc requires size to be a multiple of alignment; we pad
|
|
// the size to the next multiple if necessary.
|
|
if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
|
|
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
|
|
|
|
Type elementPtrType = this->getElementPtrType(memRefType);
|
|
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
|
|
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
|
|
auto results = rewriter.create<LLVM::CallOp>(
|
|
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
|
|
Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
|
|
results.getResult());
|
|
|
|
return allocatedPtr;
|
|
}
|
|
|
|
LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
|
|
Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
MemRefType memRefType = getMemRefResultType(op);
|
|
if (!isConvertibleAndHasIdentityMaps(memRefType))
|
|
return rewriter.notifyMatchFailure(op, "incompatible memref type");
|
|
auto loc = op->getLoc();
|
|
|
|
// Get actual sizes of the memref as values: static sizes are constant
|
|
// values and dynamic sizes are passed to 'alloc' as operands. In case of
|
|
// zero-dimensional memref, assume a scalar (size 1).
|
|
SmallVector<Value, 4> sizes;
|
|
SmallVector<Value, 4> strides;
|
|
Value sizeBytes;
|
|
this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
|
|
strides, sizeBytes);
|
|
|
|
// Allocate the underlying buffer.
|
|
auto [allocatedPtr, alignedPtr] =
|
|
this->allocateBuffer(rewriter, loc, sizeBytes, op);
|
|
|
|
// Create the MemRef descriptor.
|
|
auto memRefDescriptor = this->createMemRefDescriptor(
|
|
loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
|
|
|
|
// Return the final value of the descriptor.
|
|
rewriter.replaceOp(op, {memRefDescriptor});
|
|
return success();
|
|
}
|