186 lines
5.5 KiB
C++
186 lines
5.5 KiB
C++
//===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Analysis/DataFlowFramework.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// This analysis state represents an integer that is XOR'd with other states.
|
|
class FooState : public AnalysisState {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
|
|
|
|
using AnalysisState::AnalysisState;
|
|
|
|
/// Returns true if the state is uninitialized.
|
|
bool isUninitialized() const { return !state; }
|
|
|
|
/// Print the integer value or "none" if uninitialized.
|
|
void print(raw_ostream &os) const override {
|
|
if (state)
|
|
os << *state;
|
|
else
|
|
os << "none";
|
|
}
|
|
|
|
/// Join the state with another. If either is unintialized, take the
|
|
/// initialized value. Otherwise, XOR the integer values.
|
|
ChangeResult join(const FooState &rhs) {
|
|
if (rhs.isUninitialized())
|
|
return ChangeResult::NoChange;
|
|
return join(*rhs.state);
|
|
}
|
|
ChangeResult join(uint64_t value) {
|
|
if (isUninitialized()) {
|
|
state = value;
|
|
return ChangeResult::Change;
|
|
}
|
|
uint64_t before = *state;
|
|
state = before ^ value;
|
|
return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
|
|
}
|
|
|
|
/// Set the value of the state directly.
|
|
ChangeResult set(const FooState &rhs) {
|
|
if (state == rhs.state)
|
|
return ChangeResult::NoChange;
|
|
state = rhs.state;
|
|
return ChangeResult::Change;
|
|
}
|
|
|
|
/// Returns the integer value of the state.
|
|
uint64_t getValue() const { return *state; }
|
|
|
|
private:
|
|
/// An optional integer value.
|
|
Optional<uint64_t> state;
|
|
};
|
|
|
|
/// This analysis computes `FooState` across operations and control-flow edges.
|
|
/// If an op specifies a `foo` integer attribute, the contained value is XOR'd
|
|
/// with the value before the operation.
|
|
class FooAnalysis : public DataFlowAnalysis {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
|
|
|
|
using DataFlowAnalysis::DataFlowAnalysis;
|
|
|
|
LogicalResult initialize(Operation *top) override;
|
|
LogicalResult visit(ProgramPoint point) override;
|
|
|
|
private:
|
|
void visitBlock(Block *block);
|
|
void visitOperation(Operation *op);
|
|
};
|
|
|
|
struct TestFooAnalysisPass
|
|
: public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
|
|
|
|
StringRef getArgument() const override { return "test-foo-analysis"; }
|
|
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
/// Seeds the analysis on `top`.
///
/// `top` must have exactly one region. That region must contain at least one
/// block, and none of its operations may have regions of their own.
///
/// The entry block's state is initialized to 0. Every block and operation in
/// the region is then visited once so that each computes its initial state.
///
/// Returns failure (with an emitted diagnostic) if `top` violates any of the
/// requirements above.
LogicalResult FooAnalysis::initialize(Operation *top) {
  if (top->getNumRegions() != 1)
    return top->emitError("expected a single region top-level op");

  auto &region = top->getRegion(0);
  // A bodiless op (e.g. a `func.func` declaration) has an empty region.
  // Calling `front()` on it would be undefined behavior, so reject it here.
  if (region.empty())
    return top->emitError("expected at least one block in the region");

  // Initialize the top-level state. The ChangeResult is deliberately ignored
  // during initialization.
  (void)getOrCreate<FooState>(&region.front())->join(0);

  // Visit all nested blocks and operations.
  for (Block &block : region) {
    visitBlock(&block);
    for (Operation &op : block) {
      if (op.getNumRegions())
        return op.emitError("unexpected op with regions");
      visitOperation(&op);
    }
  }
  return success();
}
|
|
|
|
/// Solver entry point.
///
/// Dispatches `point` to the operation or block visitor. Any other kind of
/// program point is reported as an error at the point's location.
LogicalResult FooAnalysis::visit(ProgramPoint point) {
  if (Operation *op = point.dyn_cast<Operation *>())
    visitOperation(op);
  else if (Block *block = point.dyn_cast<Block *>())
    visitBlock(block);
  else
    return emitError(point.getLoc(), "unknown point kind");
  return success();
}
|
|
|
|
/// Recomputes the entry state of `block` by joining the states recorded at
/// the terminators of all of its predecessors.
///
/// The entry block is skipped; its state is seeded by `initialize`.
void FooAnalysis::visitBlock(Block *block) {
  if (block->isEntryBlock())
    return;

  FooState *blockState = getOrCreate<FooState>(block);
  ChangeResult changed = ChangeResult::NoChange;
  for (Block *predecessor : block->getPredecessors()) {
    // Reading the terminator's state through getOrCreateFor ties this block
    // to that terminator.
    Operation *terminator = predecessor->getTerminator();
    const FooState *incoming = getOrCreateFor<FooState>(block, terminator);
    changed |= blockState->join(*incoming);
  }
  propagateIfChanged(blockState, changed);
}
|
|
|
|
/// Recomputes the state after `op`.
///
/// The incoming state is taken from the preceding operation in the block, or
/// from the block's entry state when `op` is the first operation. That state
/// is copied, and then XOR'd with the value of `op`'s integer `foo`
/// attribute when present.
void FooAnalysis::visitOperation(Operation *op) {
  FooState *state = getOrCreate<FooState>(op);

  Operation *prevOp = op->getPrevNode();
  const FooState *incoming =
      prevOp ? getOrCreateFor<FooState>(op, prevOp)
             : getOrCreateFor<FooState>(op, op->getBlock());
  ChangeResult result = state->set(*incoming);

  if (IntegerAttr fooAttr = op->getAttrOfType<IntegerAttr>("foo"))
    result |= state->join(fooAttr.getUInt());

  propagateIfChanged(state, result);
}
|
|
|
|
void TestFooAnalysisPass::runOnOperation() {
|
|
func::FuncOp func = getOperation();
|
|
DataFlowSolver solver;
|
|
solver.load<FooAnalysis>();
|
|
if (failed(solver.initializeAndRun(func)))
|
|
return signalPassFailure();
|
|
|
|
raw_ostream &os = llvm::errs();
|
|
os << "function: @" << func.getSymName() << "\n";
|
|
|
|
func.walk([&](Operation *op) {
|
|
auto tag = op->getAttrOfType<StringAttr>("tag");
|
|
if (!tag)
|
|
return;
|
|
const FooState *state = solver.lookupState<FooState>(op);
|
|
assert(state && !state->isUninitialized());
|
|
os << tag.getValue() << " -> " << state->getValue() << "\n";
|
|
});
|
|
}
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
|
|
} // namespace test
|
|
} // namespace mlir
|