# Writing DataFlow Analyses in MLIR

Writing dataflow analyses in MLIR, or in any compiler, can often seem daunting
and complex. A dataflow analysis generally involves propagating information
about the IR across various types of control flow constructs, of which MLIR has
many (Block-based branches, Region-based branches, CallGraph, etc.), and it
isn't always clear how best to go about performing the propagation. To help
with writing these types of analyses in MLIR, this document details several
utilities that simplify the process and make it a bit more approachable.

## Forward Dataflow Analysis

One type of dataflow analysis is a forward propagation analysis. This type of
analysis, as the name may suggest, propagates information forward (e.g. from
definitions to uses). To provide a bit of concrete context, let's go over
writing a simple forward dataflow analysis in MLIR. Let's say for this analysis
that we want to propagate information about a special "metadata" dictionary
attribute. The contents of this attribute are simply a set of metadata that
describe a specific value, e.g. `metadata = { likes_pizza = true }`. We will
collect the `metadata` for operations in the IR and propagate them about.

### Lattices

Before going into how one might set up the analysis itself, it is important to
first introduce the concept of a `Lattice` and how we will use it for the
analysis. A lattice represents all of the possible values or results of the
analysis for a given value. A lattice element holds the set of information
computed by the analysis for a given value, and is what gets propagated across
the IR. For our analysis, this would correspond to the `metadata` dictionary
attribute.

Regardless of the value held within, every type of lattice contains two special
element states:

*   `uninitialized`

    -   The element has not been initialized.

*   `top`/`overdefined`/`unknown`

    -   The element encompasses every possible value.
    -   This is a very conservative state, and essentially means "I can't make
        any assumptions about the value, it could be anything".

These two states are important when merging information as part of the
analysis, an operation we will refer to as a `join` in the rest of this
document. Lattice elements are `join`ed whenever there are two different source
points, such as an argument to a block with multiple predecessors. One
important note about the `join` operation is that it is required to be
monotonic (see the `join` method in the example below for more information).
This ensures that `join`ing elements is consistent. The two special states
mentioned above have unique properties during a `join`:

*   `uninitialized`

    -   If one of the elements is `uninitialized`, the other element is used.
    -   `uninitialized` in the context of a `join` essentially means "take the
        other thing".

*   `top`/`overdefined`/`unknown`

    -   If one of the elements being joined is `overdefined`, the result is
        `overdefined`.

For our analysis in MLIR, we will need to define a class representing the value
held by an element of the lattice used by our dataflow analysis:

```c++
/// The value of our lattice represents the inner structure of a DictionaryAttr,
/// for the `metadata`.
struct MetadataLatticeValue {
  MetadataLatticeValue() = default;
  /// Compute a lattice value from the provided dictionary.
  MetadataLatticeValue(DictionaryAttr attr) {
    // Copy each (name, value) entry of the dictionary into our map.
    for (NamedAttribute entry : attr)
      metadata.try_emplace(entry.getName(), entry.getValue());
  }

  /// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown`
  /// state, for our value type. The resultant state should not assume any
  /// information about the state of the IR.
  static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) {
    // The `top`/`overdefined`/`unknown` state is when we know nothing about any
    // metadata, i.e. an empty dictionary.
    return MetadataLatticeValue();
  }
  /// Return a pessimistic value state for our value type using only information
  /// about the state of the provided IR. This is similar to the above method,
  /// but may produce a slightly more refined result. This is okay, as the
  /// information is already encoded as fact in the IR.
  static MetadataLatticeValue getPessimisticValueState(Value value) {
    // Check to see if the parent operation has metadata.
    if (Operation *parentOp = value.getDefiningOp()) {
      if (auto metadata = parentOp->getAttrOfType<DictionaryAttr>("metadata"))
        return MetadataLatticeValue(metadata);

      // If no metadata is present, fall back to the
      // `top`/`overdefined`/`unknown` state.
    }
    return MetadataLatticeValue();
  }

  /// This method conservatively joins the information held by `lhs` and `rhs`
  /// into a new value. This method is required to be monotonic. `monotonicity`
  /// is implied by the satisfaction of the following axioms:
  ///   * idempotence:   join(x,x) == x
  ///   * commutativity: join(x,y) == join(y,x)
  ///   * associativity: join(x,join(y,z)) == join(join(x,y),z)
  ///
  /// When the above axioms are satisfied, we achieve `monotonicity`:
  ///   * monotonicity: join(x, join(x,y)) == join(x,y)
  static MetadataLatticeValue join(const MetadataLatticeValue &lhs,
                                   const MetadataLatticeValue &rhs) {
    // To join `lhs` and `rhs` we will define a simple policy, which is that we
    // only keep information that is the same. This means that we only keep
    // facts that are true in both.
    MetadataLatticeValue result;
    for (const auto &lhsIt : lhs.metadata) {
      // As noted above, we only merge if the values are the same.
      auto it = rhs.metadata.find(lhsIt.first);
      if (it == rhs.metadata.end() || it->second != lhsIt.second)
        continue;
      result.metadata.insert({lhsIt.first, lhsIt.second});
    }
    return result;
  }

  /// A simple comparator that checks to see if this value is equal to the one
  /// provided.
  bool operator==(const MetadataLatticeValue &rhs) const {
    if (metadata.size() != rhs.metadata.size())
      return false;
    // Check that the 'rhs' maps every key to the same attribute as we do.
    return llvm::all_of(metadata, [&](const auto &it) {
      return rhs.metadata.lookup(it.first) == it.second;
    });
  }

  /// Our value represents the combined metadata, which is originally a
  /// DictionaryAttr, so we use a map.
  DenseMap<StringAttr, Attribute> metadata;
};
```
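To make the `join` policy and its axioms concrete, here is a minimal standalone
sketch that can be compiled without MLIR. It assumes a hypothetical `Metadata`
alias in which plain strings stand in for `StringAttr` and `Attribute`; the
helper `joinAxiomsHold` is likewise invented for illustration. As in the
lattice value above, the empty map plays the role of the
`top`/`overdefined`/`unknown` state, so joining anything with it yields the
empty map.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for MetadataLatticeValue: plain strings replace
// StringAttr/Attribute so that the sketch compiles without MLIR.
using Metadata = std::map<std::string, std::string>;

// Keep only the entries that are present, with equal values, in both inputs.
Metadata join(const Metadata &lhs, const Metadata &rhs) {
  Metadata result;
  for (const auto &entry : lhs) {
    auto it = rhs.find(entry.first);
    if (it != rhs.end() && it->second == entry.second)
      result.insert(entry);
  }
  return result;
}

// Check the axioms listed in the `join` documentation on three sample values.
bool joinAxiomsHold() {
  Metadata x = {{"likes_pizza", "true"}, {"likes_tea", "true"}};
  Metadata y = {{"likes_pizza", "true"}, {"likes_tea", "false"}};
  Metadata z = {{"likes_pizza", "true"}};
  bool idempotent = join(x, x) == x;
  bool commutative = join(x, y) == join(y, x);
  bool associative = join(x, join(y, z)) == join(join(x, y), z);
  bool monotonic = join(x, join(x, y)) == join(x, y);
  return idempotent && commutative && associative && monotonic;
}
```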

One interesting thing to note above is that we don't have an explicit method for
the `uninitialized` state. This state is handled by the `LatticeElement` class,
which manages a lattice value for a given IR entity. A quick overview of this
class, and the API that will be interesting to us while writing our analysis, is
shown below:

```c++
/// This class represents a lattice element holding a specific value of type
/// `ValueT`.
template <typename ValueT>
class LatticeElement ... {
public:
  /// Return the value held by this element. This requires that a value is
  /// known, i.e. not `uninitialized`.
  ValueT &getValue();
  const ValueT &getValue() const;

  /// Join the information contained in the 'rhs' element into this
  /// element. Returns if the state of the current element changed.
  ChangeResult join(const LatticeElement<ValueT> &rhs);

  /// Join the information contained in the 'rhs' value into this
  /// lattice. Returns if the state of the current lattice changed.
  ChangeResult join(const ValueT &rhs);

  /// Mark the lattice element as having reached a pessimistic fixpoint. This
  /// means that the lattice may potentially have conflicting value states, and
  /// only the conservatively known value state should be relied on.
  ChangeResult markPessimisticFixPoint();
};
```
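The way an element wrapper can layer the `uninitialized` state on top of a
value-level `join` can be sketched with a simplified standalone model. This is
not MLIR's `LatticeElement`: `ToyLatticeElement`, `Facts`, `joinValues`, and
the local `ChangeResult` enum are hypothetical names, and a set of strings
(with the empty set acting as `top`) stands in for the lattice value.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-ins; these are not MLIR's actual classes.
enum class ChangeResult { NoChange, Change };

// A set of facts; the empty set plays the role of `top`/`overdefined`.
using Facts = std::set<std::string>;

// Value-level join: keep only the facts present in both inputs.
Facts joinValues(const Facts &lhs, const Facts &rhs) {
  Facts result;
  for (const std::string &fact : lhs)
    if (rhs.count(fact))
      result.insert(fact);
  return result;
}

// Simplified element wrapper: `initialized == false` models `uninitialized`.
class ToyLatticeElement {
public:
  bool isUninitialized() const { return !initialized; }
  const Facts &getValue() const { return value; }

  // Join a known value into this element and report whether it changed.
  ChangeResult join(const Facts &rhs) {
    // `uninitialized` joined with anything means "take the other thing".
    if (!initialized) {
      initialized = true;
      value = rhs;
      return ChangeResult::Change;
    }
    Facts joined = joinValues(value, rhs);
    if (joined == value)
      return ChangeResult::NoChange;
    value = joined;
    return ChangeResult::Change;
  }

  // Join another element; an `uninitialized` right-hand side is a no-op.
  ChangeResult join(const ToyLatticeElement &rhs) {
    return rhs.isUninitialized() ? ChangeResult::NoChange : join(rhs.value);
  }

private:
  bool initialized = false;
  Facts value;
};
```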

With our lattice defined, we can now define the driver that will compute and
propagate our lattice across the IR.

### ForwardDataFlowAnalysis Driver

The `ForwardDataFlowAnalysis` class represents the driver of the dataflow
analysis, and performs all of the related analysis computation. When defining
our analysis, we will inherit from this class and implement some of its hooks.
Before that, let's look at a quick overview of this class and some of the
important API for our analysis:

```c++
/// This class represents the main driver of the forward dataflow analysis. It
/// takes as a template parameter the value type of lattice being computed.
template <typename ValueT>
class ForwardDataFlowAnalysis : ... {
public:
  ForwardDataFlowAnalysis(MLIRContext *context);

  /// Compute the analysis on operations rooted under the given top-level
  /// operation. Note that the top-level operation is not visited.
  void run(Operation *topLevelOp);

  /// Return the lattice element attached to the given value. If a lattice has
  /// not been added for the given value, a new 'uninitialized' value is
  /// inserted and returned.
  LatticeElement<ValueT> &getLatticeElement(Value value);

  /// Return the lattice element attached to the given value, or nullptr if no
  /// lattice element for the value has yet been created.
  LatticeElement<ValueT> *lookupLatticeElement(Value value);

  /// Mark all of the lattice elements for the given range of Values as having
  /// reached a pessimistic fixpoint.
  ChangeResult markAllPessimisticFixPoint(ValueRange values);

protected:
  /// Visit the given operation, and join any necessary analysis state
  /// into the lattice elements for the results and block arguments owned by
  /// this operation using the provided set of operand lattice elements
  /// (all pointer values are guaranteed to be non-null). Returns if any result
  /// or block argument value lattice elements changed during the visit. The
  /// lattice element for a result or block argument value can be obtained, and
  /// join'ed into, by using `getLatticeElement`.
  virtual ChangeResult visitOperation(
      Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) = 0;
};
```

NOTE: Some API has been omitted for our example. The `ForwardDataFlowAnalysis`
class contains various other hooks that allow for injecting custom behavior
when applicable.

The main API that we are responsible for defining is the `visitOperation`
method. This method is responsible for computing new lattice elements for the
results and block arguments owned by the given operation. This is where we will
inject the lattice element computation logic, also known as the transfer
function for the operation, that is specific to our analysis. A simple
implementation for our example is shown below:

```c++
class MetadataAnalysis : public ForwardDataFlowAnalysis<MetadataLatticeValue> {
public:
  using ForwardDataFlowAnalysis<MetadataLatticeValue>::ForwardDataFlowAnalysis;

  ChangeResult visitOperation(
      Operation *op,
      ArrayRef<LatticeElement<MetadataLatticeValue> *> operands) override {
    DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");

    // If we have no metadata for this operation, we will conservatively mark
    // all of the results as having reached a pessimistic fixpoint.
    if (!metadata)
      return markAllPessimisticFixPoint(op->getResults());

    // Otherwise, we will compute a lattice value for the metadata and join it
    // into the current lattice element for all of our results.
    MetadataLatticeValue latticeValue(metadata);
    ChangeResult result = ChangeResult::NoChange;
    for (Value value : op->getResults()) {
      // We grab the lattice element for `value` via `getLatticeElement` and
      // then join it with the lattice value for this operation's metadata. Note
      // that during the analysis phase, it is fine to freely create a new
      // lattice element for a value. This is why we don't use the
      // `lookupLatticeElement` method here.
      result |= getLatticeElement(value).join(latticeValue);
    }
    return result;
  }
};
```
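To give some intuition for why a transfer function reports whether anything
changed, the following toy sketch iterates a worklist until a fixpoint is
reached. It is not how `ForwardDataFlowAnalysis` is implemented; `ToyNode`,
`ToyState`, `joinInto`, and `runToFixpoint` are hypothetical names, and a set
of strings stands in for the lattice value. A node either carries its own facts
(like an operation with a `metadata` attribute) or merges the states of its
operands (like a block argument with several predecessors).

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <set>
#include <string>
#include <vector>

using Facts = std::set<std::string>;

// Hypothetical toy IR node, not an MLIR class.
struct ToyNode {
  std::vector<std::size_t> operands; // Indices of nodes feeding this one.
  bool hasOwnFacts = false;
  Facts ownFacts;
};

// Per-node analysis state; `initialized == false` models `uninitialized`.
struct ToyState {
  bool initialized = false;
  Facts facts;
};

// Join `rhs` into `state`, keeping only shared facts. Returns true on change.
bool joinInto(ToyState &state, const Facts &rhs) {
  if (!state.initialized) {
    state.initialized = true;
    state.facts = rhs;
    return true;
  }
  Facts joined;
  for (const std::string &fact : state.facts)
    if (rhs.count(fact))
      joined.insert(fact);
  bool changed = joined != state.facts;
  state.facts = joined;
  return changed;
}

// Revisit nodes until no state changes, i.e. until a fixpoint is reached.
std::vector<ToyState> runToFixpoint(const std::vector<ToyNode> &nodes) {
  std::vector<ToyState> states(nodes.size());
  std::deque<std::size_t> worklist;
  for (std::size_t i = 0; i < nodes.size(); ++i)
    worklist.push_back(i);
  while (!worklist.empty()) {
    std::size_t current = worklist.front();
    worklist.pop_front();
    const ToyNode &node = nodes[current];
    // Transfer function: use the node's own facts or merge its operands.
    bool changed = false;
    if (node.hasOwnFacts) {
      changed = joinInto(states[current], node.ownFacts);
    } else {
      for (std::size_t operand : node.operands)
        if (states[operand].initialized)
          changed |= joinInto(states[current], states[operand].facts);
    }
    if (!changed)
      continue;
    // A changed state may affect every node that uses it as an operand.
    for (std::size_t user = 0; user < nodes.size(); ++user)
      for (std::size_t operand : nodes[user].operands)
        if (operand == current)
          worklist.push_back(user);
  }
  return states;
}
```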

With that, we have all of the necessary components to compute our analysis.
After the analysis has been computed, we can grab any computed information for
values by using `lookupLatticeElement`. We use this function over
`getLatticeElement` as the analysis is not guaranteed to visit all values, e.g.
if the value is in an unreachable block, and we don't want to create a new
uninitialized lattice element in this case. See below for a quick example:

```c++
void MyPass::runOnOperation() {
  MetadataAnalysis analysis(&getContext());
  analysis.run(getOperation());
  ...
}

void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) {
  LatticeElement<MetadataLatticeValue> *latticeElement =
      analysis.lookupLatticeElement(value);

  // If we don't have an element, the `value` wasn't visited during our analysis
  // meaning that it could be dead. We need to treat this conservatively.
  if (!latticeElement)
    return;

  // Our lattice element has a value, use it:
  MetadataLatticeValue &metadata = latticeElement->getValue();
  ...
}
```