llvm-project/mlir/lib/Dialect/Func/Transforms/FuncBufferize.cpp

72 lines
2.7 KiB
C++

//===- FuncBufferize.cpp - Bufferization for func ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of func.func's and func.call's.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_FUNCBUFFERIZE
#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::func;
namespace {
/// Pass that runs a full dialect conversion, driven by
/// bufferization::BufferizeTypeConverter, over func.func signatures, func.call
/// ops, branch-interface ops and return ops.
struct FuncBufferizePass : public impl::FuncBufferizeBase<FuncBufferizePass> {
  using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;

  /// Sets up the rewrite patterns and the legality rules, then applies the
  /// conversion to the operation this pass runs on. The pass is marked as
  /// failed when the full conversion does not succeed.
  void runOnOperation() override {
    auto root = getOperation();
    auto *ctx = &getContext();

    bufferization::BufferizeTypeConverter converter;
    RewritePatternSet rewrites(ctx);
    ConversionTarget legality(*ctx);

    // A function is legal once both its signature and the types used in its
    // body region are legal for the converter.
    auto funcIsLegal = [&converter](FuncOp fn) {
      if (!converter.isSignatureLegal(fn.getFunctionType()))
        return false;
      return converter.isLegal(&fn.getBody());
    };

    // A call is legal when the converter accepts the op as-is.
    auto callIsLegal = [&converter](CallOp call) {
      return converter.isLegal(call);
    };

    // Any op without an explicit rule is legal if it is neither a
    // branch-interface op nor return-like, or if the corresponding
    // type-conversion pattern already considers it legal.
    auto unknownOpIsLegal = [&converter](Operation *candidate) {
      if (isNotBranchOpInterfaceOrReturnLikeOp(candidate))
        return true;
      if (isLegalForBranchOpInterfaceTypeConversionPattern(candidate,
                                                           converter))
        return true;
      return isLegalForReturnOpTypeConversionPattern(candidate, converter);
    };

    // Function signatures.
    populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(rewrites,
                                                             converter);
    legality.addDynamicallyLegalOp<FuncOp>(funcIsLegal);

    // Calls.
    populateCallOpTypeConversionPattern(rewrites, converter);
    legality.addDynamicallyLegalOp<CallOp>(callIsLegal);

    // Branches and returns.
    populateBranchOpInterfaceTypeConversionPattern(rewrites, converter);
    populateReturnOpTypeConversionPattern(rewrites, converter);

    // The module itself and the tensor<->memref cast ops are always legal.
    // NOTE(review): presumably the casts are what the type converter
    // materializes at converted/unconverted boundaries -- confirm.
    legality.addLegalOp<ModuleOp, bufferization::ToTensorOp,
                        bufferization::ToMemrefOp>();
    legality.markUnknownOpDynamicallyLegal(unknownOpIsLegal);

    if (failed(applyFullConversion(root, legality, std::move(rewrites))))
      signalPassFailure();
  }
};
} // namespace
/// Factory for the func bufferization pass: returns a newly allocated
/// FuncBufferizePass owned through the generic Pass interface.
std::unique_ptr<Pass> mlir::func::createFuncBufferizePass() {
  std::unique_ptr<Pass> pass = std::make_unique<FuncBufferizePass>();
  return pass;
}