// Source path: llvm-project/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
// (file-viewer metadata: 164 lines, 5.9 KiB, C++)

//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <optional>
namespace mlir::memref {
#define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace mlir::memref
using namespace mlir;
namespace {
//===----------------------------------------------------------------------===//
// ConvertMemRefAlloc
//===----------------------------------------------------------------------===//
/// Rewrites `memref.alloc` so that its result type is the one produced by the
/// type converter. The dynamic sizes, symbol operands and alignment of the
/// original allocation are forwarded unchanged.
struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedTy = getTypeConverter()->convertType(op.getType());
    if (!convertedTy) {
      // Report why the pattern did not apply instead of rewriting.
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only the result type changes; every other piece of the allocation is
    // taken from the (already converted) adaptor.
    auto dynSizes = adaptor.getDynamicSizes();
    auto symOperands = adaptor.getSymbolOperands();
    auto alignment = adaptor.getAlignmentAttr();
    rewriter.replaceOpWithNewOp<memref::AllocOp>(op, convertedTy, dynSizes,
                                                 symOperands, alignment);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//
/// Rewrites `memref.load` so that the loaded value has the type produced by
/// the type converter, reading from the converted memref operand.
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type loadedTy = getTypeConverter()->convertType(op.getType());
    if (loadedTy) {
      // Re-issue the load with the converted result type and operands.
      rewriter.replaceOpWithNewOp<memref::LoadOp>(
          op, loadedTy, adaptor.getMemref(), adaptor.getIndices());
      return success();
    }

    // The result type could not be converted; the diagnostic reports the
    // memref type being loaded from.
    return rewriter.notifyMatchFailure(
        op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                    op.getMemRefType()));
  }
};
//===----------------------------------------------------------------------===//
// ConvertMemRefStore
//===----------------------------------------------------------------------===//
/// Rewrites `memref.store` to use the converted value and memref operands.
/// A store has no result, so the type converter is consulted only to reject
/// memref types it cannot convert; the converted type itself is not used.
struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memrefTy = op.getMemRefType();
    bool convertible =
        static_cast<bool>(getTypeConverter()->convertType(memrefTy));
    if (!convertible)
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", memrefTy));

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices());
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
/// Pass that emulates integer types wider than `widestIntSupported` bits by
/// applying the arith and memref wide-integer emulation patterns with an
/// `arith::WideIntEmulationConverter`.
struct EmulateWideIntPass final
    : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> {
  using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    unsigned widestBitWidth = widestIntSupported;

    // Reject unsupported widths up front. Emit a diagnostic so the failure is
    // not silent: the width must be a power of two and at least 2 bits.
    if (!llvm::isPowerOf2_32(widestBitWidth) || widestBitWidth < 2) {
      op->emitError("expected the widest supported integer bit width to be a "
                    "power of two that is at least 2, got ")
          << widestBitWidth;
      signalPassFailure();
      return;
    }

    MLIRContext *ctx = op->getContext();
    arith::WideIntEmulationConverter typeConverter(widestBitWidth);
    memref::populateMemRefWideIntEmulationConversions(typeConverter);

    // Ops from these dialects are legal only when the type converter accepts
    // all of their types.
    ConversionTarget target(*ctx);
    target.addDynamicallyLegalDialect<
        arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>(
        [&typeConverter](Operation *candidate) {
          return typeConverter.isLegal(candidate);
        });

    RewritePatternSet patterns(ctx);
    // Add common patterns to support constants, functions, etc.
    arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
    memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);

    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//
void memref::populateMemRefWideIntEmulationPatterns(
    arith::WideIntEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // Register the `memref.alloc`, `memref.load` and `memref.store` conversion
  // patterns. All of them share the caller-provided type converter.
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
      typeConverter, ctx);
}
void memref::populateMemRefWideIntEmulationConversions(
    arith::WideIntEmulationConverter &typeConverter) {
  // Convert memrefs whose element type is an integer wider than the target
  // supports by converting that element type. Memrefs of non-integers, or of
  // integers the target already supports, are returned unchanged.
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        if (intTy.getWidth() <= typeConverter.getMaxTargetIntBitWidth())
          return ty;

        Type newElemTy = typeConverter.convertType(intTy);
        // Signal a hard failure with a null type. Returning `std::nullopt`
        // would only mean "try another conversion", which lets a catch-all
        // conversion (if one is registered on the converter) pass the
        // unconvertible memref through unchanged.
        if (!newElemTy)
          return nullptr;

        // Keep the existing shape (std::nullopt) and swap the element type.
        return ty.cloneWith(std::nullopt, newElemTy);
      });
}