llvm-project/mlir/lib/Transforms/Utils/CommutativityUtils.cpp

315 lines
11 KiB
C++

//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a commutativity utility pattern and a function to
// populate this pattern. The function is intended to be used inside passes to
// simplify the matching of commutative operations by fixing the order of their
// operands.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/CommutativityUtils.h"
#include <queue>
using namespace mlir;
/// The possible "types" of ancestors. Here, an ancestor is an op or a block
/// argument present in the backward slice of a value.
enum AncestorType {
/// Pertains to a block argument.
BLOCK_ARGUMENT,
/// Pertains to a non-constant-like op.
NON_CONSTANT_OP,
/// Pertains to a constant-like op.
CONSTANT_OP
};
/// Stores the "key" associated with an ancestor.
struct AncestorKey {
/// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
/// the ancestor.
AncestorType type;
/// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
/// `CONSTANT_OP`. Else, holds "".
StringRef opName;
/// Constructor for `AncestorKey`.
AncestorKey(Operation *op) {
if (!op) {
type = BLOCK_ARGUMENT;
} else {
type =
op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
opName = op->getName().getStringRef();
}
}
/// Overloaded operator `<` for `AncestorKey`.
///
/// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
/// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
/// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
/// ones are the ones with smaller op names (lexicographically).
///
/// TODO: Include other information like attributes, value type, etc., to
/// enhance this comparison. For example, currently this comparison doesn't
/// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
/// `addi (in i64)`. Such an enhancement should only be done if the need
/// arises.
bool operator<(const AncestorKey &key) const {
return std::tie(type, opName) < std::tie(key.type, key.opName);
}
};
/// Stores a commutative operand along with its BFS traversal information.
struct CommutativeOperand {
/// Stores the operand.
Value operand;
/// Stores the queue of ancestors of the operand's BFS traversal at a
/// particular point in time.
std::queue<Operation *> ancestorQueue;
/// Stores the list of ancestors that have been visited by the BFS traversal
/// at a particular point in time.
DenseSet<Operation *> visitedAncestors;
/// Stores the operand's "key". This "key" is defined as a list of the
/// "AncestorKeys" associated with the ancestors of this operand, in a
/// breadth-first order.
///
/// So, if an operand, say `A`, was produced as follows:
///
/// `<block argument>` `<block argument>`
/// \ /
/// \ /
/// `arith.subi` `arith.constant`
/// \ /
/// `arith.addi`
/// |
/// returns `A`
///
/// Then, the ancestors of `A`, in the breadth-first order are:
/// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
/// `<block argument>`.
///
/// Thus, the "key" associated with operand `A` is:
/// {
/// {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
/// {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
/// {type: `CONSTANT_OP`, opName: "arith.constant"},
/// {type: `BLOCK_ARGUMENT`, opName: ""},
/// {type: `BLOCK_ARGUMENT`, opName: ""}
/// }
SmallVector<AncestorKey, 4> key;
/// Push an ancestor into the operand's BFS information structure. This
/// entails it being pushed into the queue (always) and inserted into the
/// "visited ancestors" list (iff it is an op rather than a block argument).
void pushAncestor(Operation *op) {
ancestorQueue.push(op);
if (op)
visitedAncestors.insert(op);
}
/// Refresh the key.
///
/// Refreshing a key entails making it up-to-date with the operand's BFS
/// traversal that has happened till that point in time, i.e, appending the
/// existing key with the front ancestor's "AncestorKey". Note that a key
/// directly reflects the BFS and thus needs to be refreshed during the
/// progression of the traversal.
void refreshKey() {
if (ancestorQueue.empty())
return;
Operation *frontAncestor = ancestorQueue.front();
AncestorKey frontAncestorKey(frontAncestor);
key.push_back(frontAncestorKey);
}
/// Pop the front ancestor, if any, from the queue and then push its adjacent
/// unvisited ancestors, if any, to the queue (this is the main body of the
/// BFS algorithm).
void popFrontAndPushAdjacentUnvisitedAncestors() {
if (ancestorQueue.empty())
return;
Operation *frontAncestor = ancestorQueue.front();
ancestorQueue.pop();
if (!frontAncestor)
return;
for (Value operand : frontAncestor->getOperands()) {
Operation *operandDefOp = operand.getDefiningOp();
if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
pushAncestor(operandDefOp);
}
}
};
/// Sorts the operands of `op` in ascending order of the "key" associated with
/// each operand iff `op` is commutative. This is a stable sort.
///
/// After the application of this pattern, since the commutative operands now
/// have a deterministic order in which they occur in an op, the matching of
/// large DAGs becomes much simpler, i.e., requires much less number of checks
/// to be written by a user in her/his pattern matching function.
///
/// Some examples of such a sorting:
///
/// Assume that the sorting is being applied to `foo.commutative`, which is a
/// commutative op.
///
/// Example 1:
///
/// %1 = foo.const 0
/// %2 = foo.mul <block argument>, <block argument>
/// %3 = foo.commutative %1, %2
///
/// Here,
/// 1. The key associated with %1 is:
/// `{
/// {CONSTANT_OP, "foo.const"}
/// }`
/// 2. The key associated with %2 is:
/// `{
/// {NON_CONSTANT_OP, "foo.mul"},
/// {BLOCK_ARGUMENT, ""},
/// {BLOCK_ARGUMENT, ""}
/// }`
///
/// The key of %2 < the key of %1
/// Thus, the sorted `foo.commutative` is:
/// %3 = foo.commutative %2, %1
///
/// Example 2:
///
/// %1 = foo.const 0
/// %2 = foo.mul <block argument>, <block argument>
/// %3 = foo.mul %2, %1
/// %4 = foo.add %2, %1
/// %5 = foo.commutative %1, %2, %3, %4
///
/// Here,
/// 1. The key associated with %1 is:
/// `{
/// {CONSTANT_OP, "foo.const"}
/// }`
/// 2. The key associated with %2 is:
/// `{
/// {NON_CONSTANT_OP, "foo.mul"},
/// {BLOCK_ARGUMENT, ""}
/// }`
/// 3. The key associated with %3 is:
/// `{
/// {NON_CONSTANT_OP, "foo.mul"},
/// {NON_CONSTANT_OP, "foo.mul"},
/// {CONSTANT_OP, "foo.const"},
/// {BLOCK_ARGUMENT, ""},
/// {BLOCK_ARGUMENT, ""}
/// }`
/// 4. The key associated with %4 is:
/// `{
/// {NON_CONSTANT_OP, "foo.add"},
/// {NON_CONSTANT_OP, "foo.mul"},
/// {CONSTANT_OP, "foo.const"},
/// {BLOCK_ARGUMENT, ""},
/// {BLOCK_ARGUMENT, ""}
/// }`
///
/// Thus, the sorted `foo.commutative` is:
/// %5 = foo.commutative %4, %3, %2, %1
class SortCommutativeOperands : public RewritePattern {
public:
SortCommutativeOperands(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Custom comparator for two commutative operands, which returns true iff
// the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
// i.e.,
// 1. In the first unequal pair of corresponding AncestorKeys, the
// AncestorKey in `constCommOperandA` is smaller, or,
// 2. Both the AncestorKeys in every pair are the same and the size of
// `constCommOperandA`'s "key" is smaller.
auto commutativeOperandComparator =
[](const std::unique_ptr<CommutativeOperand> &constCommOperandA,
const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
if (constCommOperandA->operand == constCommOperandB->operand)
return false;
auto &commOperandA =
const_cast<std::unique_ptr<CommutativeOperand> &>(
constCommOperandA);
auto &commOperandB =
const_cast<std::unique_ptr<CommutativeOperand> &>(
constCommOperandB);
// Iteratively perform the BFS's of both operands until an order among
// them can be determined.
unsigned keyIndex = 0;
while (true) {
if (commOperandA->key.size() <= keyIndex) {
if (commOperandA->ancestorQueue.empty())
return true;
commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
commOperandA->refreshKey();
}
if (commOperandB->key.size() <= keyIndex) {
if (commOperandB->ancestorQueue.empty())
return false;
commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
commOperandB->refreshKey();
}
if (commOperandA->ancestorQueue.empty() ||
commOperandB->ancestorQueue.empty())
return commOperandA->key.size() < commOperandB->key.size();
if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
return true;
if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
return false;
keyIndex++;
}
};
// If `op` is not commutative, do nothing.
if (!op->hasTrait<OpTrait::IsCommutative>())
return failure();
// Populate the list of commutative operands.
SmallVector<Value, 2> operands = op->getOperands();
SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands;
for (Value operand : operands) {
std::unique_ptr<CommutativeOperand> commOperand =
std::make_unique<CommutativeOperand>();
commOperand->operand = operand;
commOperand->pushAncestor(operand.getDefiningOp());
commOperand->refreshKey();
commOperands.push_back(std::move(commOperand));
}
// Sort the operands.
std::stable_sort(commOperands.begin(), commOperands.end(),
commutativeOperandComparator);
SmallVector<Value, 2> sortedOperands;
for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
sortedOperands.push_back(commOperand->operand);
if (sortedOperands == operands)
return failure();
rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
return success();
}
};
void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
patterns.add<SortCommutativeOperands>(patterns.getContext());
}