191 lines
6.6 KiB
C++
191 lines
6.6 KiB
C++
//===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestDialect.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
|
|
static std::string getStageDescription(const WalkStage &stage) {
|
|
if (stage.isBeforeAllRegions())
|
|
return "before all regions";
|
|
if (stage.isAfterAllRegions())
|
|
return "after all regions";
|
|
return "before region #" + std::to_string(stage.getNextRegion());
|
|
}
|
|
|
|
namespace {
|
|
/// This pass exercises generic visitor with void callbacks and prints the order
|
|
/// and stage in which operations are visited.
|
|
struct TestGenericIRVisitorPass
|
|
: public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass)
|
|
|
|
StringRef getArgument() const final { return "test-generic-ir-visitors"; }
|
|
StringRef getDescription() const final { return "Test generic IR visitors."; }
|
|
void runOnOperation() override {
|
|
Operation *outerOp = getOperation();
|
|
int stepNo = 0;
|
|
outerOp->walk([&](Operation *op, const WalkStage &stage) {
|
|
llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
|
|
<< getStageDescription(stage) << "\n";
|
|
});
|
|
|
|
// Exercise static inference of operation type.
|
|
outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
|
|
llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
|
|
<< getStageDescription(stage) << "\n";
|
|
});
|
|
}
|
|
};
|
|
|
|
/// This pass exercises the generic visitor with non-void callbacks and prints
|
|
/// the order and stage in which operations are visited. It will interrupt the
|
|
/// walk based on attributes peesent in the IR.
|
|
struct TestGenericIRVisitorInterruptPass
|
|
: public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestGenericIRVisitorInterruptPass)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-generic-ir-visitors-interrupt";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test generic IR visitors with interrupts.";
|
|
}
|
|
void runOnOperation() override {
|
|
Operation *outerOp = getOperation();
|
|
int stepNo = 0;
|
|
|
|
auto walker = [&](Operation *op, const WalkStage &stage) {
|
|
if (auto interruptBeforeAall =
|
|
op->getAttrOfType<BoolAttr>("interrupt_before_all"))
|
|
if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
|
|
return WalkResult::interrupt();
|
|
|
|
if (auto interruptAfterAll =
|
|
op->getAttrOfType<BoolAttr>("interrupt_after_all"))
|
|
if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
|
|
return WalkResult::interrupt();
|
|
|
|
if (auto interruptAfterRegion =
|
|
op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
|
|
if (stage.isAfterRegion(
|
|
static_cast<int>(interruptAfterRegion.getInt())))
|
|
return WalkResult::interrupt();
|
|
|
|
if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all"))
|
|
if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
|
|
return WalkResult::skip();
|
|
|
|
if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
|
|
if (skipAfterAll.getValue() && stage.isAfterAllRegions())
|
|
return WalkResult::skip();
|
|
|
|
if (auto skipAfterRegion =
|
|
op->getAttrOfType<IntegerAttr>("skip_after_region"))
|
|
if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
|
|
return WalkResult::skip();
|
|
|
|
llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
|
|
<< getStageDescription(stage) << "\n";
|
|
return WalkResult::advance();
|
|
};
|
|
|
|
// Interrupt the walk based on attributes on the operation.
|
|
auto result = outerOp->walk(walker);
|
|
|
|
if (result.wasInterrupted())
|
|
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
|
|
|
|
// Exercise static inference of operation type.
|
|
result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
|
|
return walker(op, stage);
|
|
});
|
|
|
|
if (result.wasInterrupted())
|
|
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
|
|
}
|
|
};
|
|
|
|
struct TestGenericIRBlockVisitorInterruptPass
|
|
: public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
|
|
OperationPass<ModuleOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestGenericIRBlockVisitorInterruptPass)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-generic-ir-block-visitors-interrupt";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test generic IR visitors with interrupts, starting with Blocks.";
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
int stepNo = 0;
|
|
|
|
auto walker = [&](Block *block) {
|
|
for (Operation &op : *block)
|
|
if (op.getAttrOfType<BoolAttr>("interrupt"))
|
|
return WalkResult::interrupt();
|
|
|
|
llvm::outs() << "step " << stepNo++ << "\n";
|
|
return WalkResult::advance();
|
|
};
|
|
|
|
auto result = getOperation()->walk(walker);
|
|
if (result.wasInterrupted())
|
|
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
|
|
}
|
|
};
|
|
|
|
struct TestGenericIRRegionVisitorInterruptPass
|
|
: public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
|
|
OperationPass<ModuleOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestGenericIRRegionVisitorInterruptPass)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-generic-ir-region-visitors-interrupt";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test generic IR visitors with interrupts, starting with Regions.";
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
int stepNo = 0;
|
|
|
|
auto walker = [&](Region *region) {
|
|
for (Operation &op : region->getOps())
|
|
if (op.getAttrOfType<BoolAttr>("interrupt"))
|
|
return WalkResult::interrupt();
|
|
|
|
llvm::outs() << "step " << stepNo++ << "\n";
|
|
return WalkResult::advance();
|
|
};
|
|
|
|
auto result = getOperation()->walk(walker);
|
|
if (result.wasInterrupted())
|
|
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestGenericIRVisitorsPass() {
|
|
PassRegistration<TestGenericIRVisitorPass>();
|
|
PassRegistration<TestGenericIRVisitorInterruptPass>();
|
|
PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
|
|
PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
|
|
}
|
|
|
|
} // namespace test
|
|
} // namespace mlir
|