259 lines
10 KiB
C++
259 lines
10 KiB
C++
//===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// These rewriters lower from the Tosa to the Arith dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace tosa;
|
|
|
|
namespace {
|
|
|
|
class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
|
|
public:
|
|
using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::ConstOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
Type matchContainerType(Type element, Type container) {
|
|
if (auto shapedTy = container.dyn_cast<ShapedType>())
|
|
return shapedTy.clone(element);
|
|
|
|
return element;
|
|
}
|
|
|
|
Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
|
|
if (auto shapedTy = type.dyn_cast<ShapedType>()) {
|
|
Type eTy = shapedTy.getElementType();
|
|
APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
|
|
return DenseIntElementsAttr::get(shapedTy, valueInt);
|
|
}
|
|
|
|
return rewriter.getIntegerAttr(type, value);
|
|
}
|
|
|
|
Value getConstantValue(Location loc, Type type, int64_t value,
|
|
PatternRewriter &rewriter) {
|
|
return rewriter.create<arith::ConstantOp>(
|
|
loc, getConstantAttr(type, value, rewriter));
|
|
}
|
|
|
|
// This converts the TOSA ApplyScale operator to a set of arithmetic ops,
|
|
// using 64-bit operations to perform the necessary multiply, bias, and shift.
|
|
class ApplyScaleGenericOpConverter
|
|
: public OpRewritePattern<tosa::ApplyScaleOp> {
|
|
public:
|
|
using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
Value value = op.getValue();
|
|
Value multiplier32 = op.getMultiplier();
|
|
|
|
Type resultTy = op.getType();
|
|
Type valueTy = value.getType();
|
|
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
|
|
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
|
|
|
|
Value zero = getConstantValue(loc, valueTy, 0, rewriter);
|
|
Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
|
|
Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
|
|
|
|
Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
|
|
|
|
// Compute the multiplication in 64-bits then select the high / low parts.
|
|
Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
|
|
Value multiplier64 =
|
|
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
|
|
Value multiply64 =
|
|
rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
|
|
|
|
// Apply normal rounding.
|
|
Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
|
|
Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
|
|
round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
|
|
multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
|
|
|
|
// Apply double rounding if necessary.
|
|
if (op.getDoubleRound()) {
|
|
int64_t roundInt = 1 << 30;
|
|
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
|
|
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
|
|
Value positive = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sge, value, zero);
|
|
Value dir =
|
|
rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
|
|
Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
|
|
Value valid = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
|
|
multiply64 =
|
|
rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
|
|
}
|
|
|
|
Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
|
|
Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
|
|
|
|
rewriter.replaceOp(op, result32);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
|
|
public:
|
|
using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
|
|
Type resultTy = op.getType();
|
|
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
|
|
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
|
|
|
|
Value value = op.getValue();
|
|
if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
|
|
return failure();
|
|
}
|
|
|
|
Value value32 = op.getValue();
|
|
Value multiplier32 = op.getMultiplier();
|
|
Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
|
|
|
|
// Constants used during the scaling operation.
|
|
Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
|
|
Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
|
|
Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
|
|
Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
|
|
Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
|
|
Value thirtyTwo64 = getConstantValue(loc, i64Ty, 32, rewriter);
|
|
|
|
// Compute the multiplication in 64-bits then select the high / low parts.
|
|
Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32);
|
|
Value multiplier64 =
|
|
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
|
|
Value multiply64 =
|
|
rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
|
|
|
|
// Grab out the high/low of the computation
|
|
Value high64 =
|
|
rewriter.create<arith::ShRUIOp>(loc, multiply64, thirtyTwo64);
|
|
Value high32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, high64);
|
|
Value low32 = rewriter.create<arith::MulIOp>(loc, value32, multiplier32);
|
|
|
|
// Determine the direction and amount to shift the high bits.
|
|
Value shiftOver32 = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
|
|
Value roundHighBits = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
|
|
|
|
Value shiftHighL =
|
|
rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
|
|
Value shiftHighR =
|
|
rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
|
|
|
|
shiftHighL =
|
|
rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
|
|
shiftHighR =
|
|
rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
|
|
|
|
// Conditionally perform our double round.
|
|
if (op.getDoubleRound()) {
|
|
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
|
|
Value valuePositive = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::sge, value32, zero32);
|
|
|
|
Value roundDir =
|
|
rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
|
|
roundDir =
|
|
rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
|
|
|
|
Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
|
|
Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
|
|
Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
|
|
|
|
Value shiftRound =
|
|
rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
|
|
|
|
low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
|
|
high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
|
|
}
|
|
|
|
// Conditionally apply rounding in the low bits.
|
|
{
|
|
Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
|
|
Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
|
|
roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
|
|
roundBit);
|
|
|
|
Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
|
|
Value wasRounded = rewriter.create<arith::CmpIOp>(
|
|
loc, arith::CmpIPredicate::ugt, low32, newLow32);
|
|
low32 = newLow32;
|
|
|
|
Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
|
|
high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
|
|
}
|
|
|
|
// Conditionally apply rounding in the high bits.
|
|
{
|
|
Value shiftSubOne =
|
|
rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
|
|
Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
|
|
roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
|
|
zero32);
|
|
high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
|
|
}
|
|
|
|
// Combine the correct high/low bits into the final rescale result.
|
|
high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
|
|
high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
|
|
low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
|
|
low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
|
|
|
|
// Apply the rounding behavior and shift to the final alignment.
|
|
Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
|
|
|
|
// Truncate if necessary.
|
|
if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
|
|
result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
|
|
}
|
|
|
|
rewriter.replaceOp(op, result);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void mlir::tosa::populateTosaToArithConversionPatterns(
|
|
RewritePatternSet *patterns) {
|
|
patterns->add<ConstOpConverter>(patterns->getContext());
|
|
}
|
|
|
|
void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
|
|
RewritePatternSet *patterns, bool include32Bit) {
|
|
patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
|
|
if (include32Bit) {
|
|
patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
|
|
}
|
|
}
|