llvm-project/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp

135 lines
5.0 KiB
C++

//===- ReconcileUnrealizedCasts.cpp - Eliminate noop unrealized casts -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
/// Folds the DAGs of `unrealized_conversion_cast`s whose exit types are the
/// same as the input ones.
/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
/// represent a noop within the IR, and thus the initial input values can be
/// propagated.
/// The same does not hold for 'open' chains of casts, such as
/// `A -> B -> C`. In this last case there is no cycle among the types and thus
/// the conversion is incomplete. The same holds for 'closed' chains like
/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
/// operations.
/// Bifurcations (that is, when a chain starts in between another one) are
/// also taken into consideration, and all the above considerations remain
/// valid.
/// Special corner cases such as dead casts or single casts with same input and
/// output types are also covered.
struct UnrealizedConversionCastPassthrough
    : public OpRewritePattern<UnrealizedConversionCastOp> {
  using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
                                PatternRewriter &rewriter) const override {
    // Every cast reachable from `op`, in the order it was first visited.
    // Each traversed cast must take exactly its parent's outputs as inputs,
    // so the reachable casts form a tree rooted at `op`. This preorder
    // therefore lists every parent before its children.
    SmallVector<UnrealizedConversionCastOp, 4> visitOrder;
    // Casts already traversed. `getUsers()` yields a user once per use, so a
    // cast consuming several results of its parent would otherwise be pushed
    // (and its whole subtree re-traversed) several times.
    DenseSet<UnrealizedConversionCastOp> visited;
    // Casts with at least one user that is not an unrealized cast. Their
    // results must be forwarded to the root inputs.
    DenseSet<UnrealizedConversionCastOp> liveNodes;
    // Stack used for the depth-first traversal of the use-def DAG.
    SmallVector<UnrealizedConversionCastOp, 2> visitStack;
    visitStack.push_back(op);

    // Phase 1: validate the whole DAG without mutating the IR, so that a
    // match failure never leaves a partially rewritten chain behind.
    while (!visitStack.empty()) {
      UnrealizedConversionCastOp current = visitStack.pop_back_val();
      if (!visited.insert(current).second)
        continue;
      visitOrder.push_back(current);

      bool isLive = false;
      for (Operation *user : current->getUsers()) {
        auto other = dyn_cast<UnrealizedConversionCastOp>(user);
        if (!other) {
          isLive = true;
          continue;
        }
        if (other.getInputs() != current.getOutputs())
          return rewriter.notifyMatchFailure(
              op, "mismatching values propagation");
        // Continue traversing the DAG of unrealized casts.
        visitStack.push_back(other);
      }

      if (!isLive)
        continue;
      // The cast is live, so its results must have the same types as the
      // root inputs. If they do (e.g. `{A -> B, B -> A}`, but also `{A -> A}`),
      // the cycle is just a no-op and the inputs can be forwarded. If they
      // do not (e.g. `{A -> B, B -> C}`, `{A -> B}`), the cast chain is
      // incomplete.
      bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
      if (!isCycle)
        return rewriter.notifyMatchFailure(op,
                                           "live unrealized conversion cast");
      liveNodes.insert(current);
    }

    // Phase 2: rewrite children before their parents (reverse preorder).
    // When a node is processed, all of its cast users have already been
    // erased or replaced. This keeps the rewrite valid under drivers that
    // erase eagerly and require an erased op to have no remaining uses.
    while (!visitOrder.empty()) {
      UnrealizedConversionCastOp node = visitOrder.pop_back_val();
      if (liveNodes.count(node)) {
        // Live exit node: its result types were checked above to match the
        // root inputs, so forward those inputs to its remaining users.
        rewriter.replaceOp(node, op.getInputs());
      } else {
        // Dead or intermediate node: any users it had were casts of this
        // DAG and have already been rewritten. Erasing (rather than
        // replacing) also avoids a result-count mismatch for dead casts
        // whose arity differs from the root's inputs.
        rewriter.eraseOp(node);
      }
    }
    return success();
  }
};
/// Pass that removes chains of `unrealized_conversion_cast` ops which fold
/// to no-ops. Every unrealized cast is marked illegal. The partial conversion
/// therefore reports failure, and the pass signals failure, when some cast
/// cannot be eliminated by the registered pattern.
struct ReconcileUnrealizedCasts
    : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
  ReconcileUnrealizedCasts() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Only unrealized casts are illegal; nothing else is constrained.
    ConversionTarget target(*ctx);
    target.addIllegalOp<UnrealizedConversionCastOp>();

    RewritePatternSet patterns(ctx);
    populateReconcileUnrealizedCastsPatterns(patterns);

    LogicalResult result =
        applyPartialConversion(getOperation(), target, std::move(patterns));
    if (failed(result))
      signalPassFailure();
  }
};
} // namespace
/// Adds to `patterns` the rewrite that folds no-op chains of unrealized
/// conversion casts. The pattern is created in the context owned by
/// `patterns`.
void mlir::populateReconcileUnrealizedCastsPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<UnrealizedConversionCastPassthrough>(context);
}
/// Creates a fresh, heap-allocated instance of the pass that reconciles
/// unrealized conversion casts; ownership is transferred to the caller.
std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ReconcileUnrealizedCasts>();
  return pass;
}