//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include <memory>
|
|
#include <type_traits>
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
// Lowers complex.abs to arith/math ops on the real and imaginary parts.
//
// The magnitude is computed in the scaled form
//   |z| = max * sqrt(1 + (min / max)^2)
// where max and min are the larger and smaller of |re| and |im|. Squaring the
// parts directly (sqrt(re^2 + im^2)) overflows to infinity or underflows to
// zero for inputs whose magnitude is still representable; the scaled form only
// squares a ratio that lies in [0, 1].
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
  using OpConversionPattern<complex::AbsOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    // The result type of complex.abs is used as the type of both parts.
    auto type = op.getType();

    Value real =
        rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
    Value imag =
        rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());

    // Order the absolute values of the two parts.
    Value absReal = rewriter.create<math::AbsFOp>(loc, real);
    Value absImag = rewriter.create<math::AbsFOp>(loc, imag);
    Value realIsLarger = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OGT, absReal, absImag);
    Value max =
        rewriter.create<arith::SelectOp>(loc, realIsLarger, absReal, absImag);
    Value min =
        rewriter.create<arith::SelectOp>(loc, realIsLarger, absImag, absReal);

    // scaled = max * sqrt(1 + (min / max)^2).
    Value ratio = rewriter.create<arith::DivFOp>(loc, min, max);
    Value ratioSqr = rewriter.create<arith::MulFOp>(loc, ratio, ratio);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, type, rewriter.getFloatAttr(type, 1));
    Value onePlusRatioSqr = rewriter.create<arith::AddFOp>(loc, one, ratioSqr);
    Value sqrt = rewriter.create<math::SqrtOp>(loc, onePlusRatioSqr);
    Value scaled = rewriter.create<arith::MulFOp>(loc, max, sqrt);

    // For NaN-free inputs the ratio is NaN only when both parts are zero or
    // both are infinite (0/0 or inf/inf). In both situations min == max and
    // that common value is the magnitude. If an input is NaN, the NaN held in
    // `scaled` is kept.
    Value inputsOrdered = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, real, imag);
    Value scaledIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, scaled, scaled);
    Value degenerateRatio =
        rewriter.create<arith::AndIOp>(loc, inputsOrdered, scaledIsNaN);

    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, degenerateRatio, min,
                                                 scaled);
    return success();
  }
};
|
|
|
|
// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
|
|
struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
|
|
using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
|
|
auto type = op.getType().cast<ComplexType>();
|
|
Type elementType = type.getElementType();
|
|
|
|
Value lhs = adaptor.getLhs();
|
|
Value rhs = adaptor.getRhs();
|
|
|
|
Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
|
|
Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
|
|
Value rhsSquaredPlusLhsSquared =
|
|
b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
|
|
Value sqrtOfRhsSquaredPlusLhsSquared =
|
|
b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
|
|
|
|
Value zero =
|
|
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
|
|
Value one = b.create<arith::ConstantOp>(elementType,
|
|
b.getFloatAttr(elementType, 1));
|
|
Value i = b.create<complex::CreateOp>(type, zero, one);
|
|
Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
|
|
Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
|
|
|
|
Value divResult =
|
|
b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
|
|
Value logResult = b.create<complex::LogOp>(divResult);
|
|
|
|
Value negativeOne = b.create<arith::ConstantOp>(
|
|
elementType, b.getFloatAttr(elementType, -1));
|
|
Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
|
|
|
|
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
template <typename ComparisonOp, arith::CmpFPredicate p>
|
|
struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
|
|
using OpConversionPattern<ComparisonOp>::OpConversionPattern;
|
|
using ResultCombiner =
|
|
std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
|
|
arith::AndIOp, arith::OrIOp>;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
auto type = adaptor.getLhs()
|
|
.getType()
|
|
.template cast<ComplexType>()
|
|
.getElementType();
|
|
|
|
Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
|
|
Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
|
|
Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
|
|
Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
|
|
Value realComparison =
|
|
rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
|
|
Value imagComparison =
|
|
rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
|
|
|
|
rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
|
|
imagComparison);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Default conversion which applies the BinaryStandardOp separately on the real
|
|
// and imaginary parts. Can for example be used for complex::AddOp and
|
|
// complex::SubOp.
|
|
template <typename BinaryComplexOp, typename BinaryStandardOp>
|
|
struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
|
|
using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto type = adaptor.getLhs().getType().template cast<ComplexType>();
|
|
auto elementType = type.getElementType().template cast<FloatType>();
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
|
|
Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
|
|
Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
|
|
Value resultReal =
|
|
b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
|
|
Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
|
|
Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
|
|
Value resultImag =
|
|
b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
template <typename TrigonometricOp>
|
|
struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
|
|
using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
|
|
|
|
using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
auto type = adaptor.getComplex().getType().template cast<ComplexType>();
|
|
auto elementType = type.getElementType().template cast<FloatType>();
|
|
|
|
Value real =
|
|
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
|
|
Value imag =
|
|
rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
|
|
|
|
// Trigonometric ops use a set of common building blocks to convert to real
|
|
// ops. Here we create these building blocks and call into an op-specific
|
|
// implementation in the subclass to combine them.
|
|
Value half = rewriter.create<arith::ConstantOp>(
|
|
loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
|
|
Value exp = rewriter.create<math::ExpOp>(loc, imag);
|
|
Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
|
|
Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
|
|
Value sin = rewriter.create<math::SinOp>(loc, real);
|
|
Value cos = rewriter.create<math::CosOp>(loc, real);
|
|
|
|
auto resultPair =
|
|
combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
|
|
resultPair.second);
|
|
return success();
|
|
}
|
|
|
|
virtual std::pair<Value, Value>
|
|
combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
|
|
Value cos, ConversionPatternRewriter &rewriter) const = 0;
|
|
};
|
|
|
|
// Lowers complex.cos using the building blocks prepared by
// TrigonometricOpConversion.
struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;

  std::pair<Value, Value>
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const override {
    // By definition
    //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
    // where
    //   exp(i(x+iy))  = exp(-y + ix)   = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // Writing t := exp(y) this gives
    //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
    //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
    mlir::ImplicitLocOpBuilder b(loc, rewriter);
    Value expSum = b.create<arith::AddFOp>(reciprocalExp, scaledExp);
    Value realPart = b.create<arith::MulFOp>(expSum, cos);
    Value expDiff = b.create<arith::SubFOp>(reciprocalExp, scaledExp);
    Value imagPart = b.create<arith::MulFOp>(expDiff, sin);
    return std::pair<Value, Value>(realPart, imagPart);
  }
};
|
|
|
|
// Lowers complex.div to scalar arith/math ops.
//
// The quotient is computed with Smith's algorithm in both of its variants and
// the one matching the relative magnitude of the divisor's parts is selected.
// If both parts of that quotient are NaN, the result is replaced by one of
// three special-case values (zero divisor, infinite dividend with finite
// divisor, finite dividend with infinite divisor), tested in that priority.
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
  using OpConversionPattern<complex::DivOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto complexType = adaptor.getLhs().getType().cast<ComplexType>();
    auto scalarType = complexType.getElementType().cast<FloatType>();
    using Pred = arith::CmpFPredicate;

    // Thin wrappers so that the arithmetic below reads like the formulas. Each
    // call creates exactly one op at the current insertion point.
    auto addf = [&](Value l, Value r) -> Value {
      return b.create<arith::AddFOp>(l, r);
    };
    auto subf = [&](Value l, Value r) -> Value {
      return b.create<arith::SubFOp>(l, r);
    };
    auto mulf = [&](Value l, Value r) -> Value {
      return b.create<arith::MulFOp>(l, r);
    };
    auto divf = [&](Value l, Value r) -> Value {
      return b.create<arith::DivFOp>(l, r);
    };
    auto absf = [&](Value v) -> Value { return b.create<math::AbsFOp>(v); };
    auto copySign = [&](Value magnitude, Value sign) -> Value {
      return b.create<math::CopySignOp>(magnitude, sign);
    };
    auto cmpf = [&](Pred pred, Value l, Value r) -> Value {
      return b.create<arith::CmpFOp>(pred, l, r);
    };
    auto andi = [&](Value l, Value r) -> Value {
      return b.create<arith::AndIOp>(l, r);
    };
    auto ori = [&](Value l, Value r) -> Value {
      return b.create<arith::OrIOp>(l, r);
    };
    auto select = [&](Value cond, Value ifTrue, Value ifFalse) -> Value {
      return b.create<arith::SelectOp>(cond, ifTrue, ifFalse);
    };

    // Dividend (num) and divisor (den), split into scalar parts.
    Value numRe = b.create<complex::ReOp>(scalarType, adaptor.getLhs());
    Value numIm = b.create<complex::ImOp>(scalarType, adaptor.getLhs());
    Value denRe = b.create<complex::ReOp>(scalarType, adaptor.getRhs());
    Value denIm = b.create<complex::ImOp>(scalarType, adaptor.getRhs());

    // Smith's algorithm for complex division. It evaluates
    //   (numRe + numIm * i) / (denRe + denIm * i)
    //     = ((numRe * denRe + numIm * denIm) +
    //        (numIm * denRe - numRe * denIm) * i) / ||den||^2
    // in a rescaled form. See https://dl.acm.org/citation.cfm?id=368661.
    //
    // Variant 1, selected below when |denRe| < |denIm|:
    //   ratio1 = denRe / denIm
    //   denom1 = denIm + denRe * ratio1
    //   quotRe1 = (numRe * ratio1 + numIm) / denom1
    //   quotIm1 = (numIm * ratio1 - numRe) / denom1
    Value ratio1 = divf(denRe, denIm);
    Value ratio1TimesDenRe = mulf(ratio1, denRe);
    Value denom1 = addf(denIm, ratio1TimesDenRe);
    Value numReTimesRatio1 = mulf(numRe, ratio1);
    Value reNumerator1 = addf(numReTimesRatio1, numIm);
    Value quotRe1 = divf(reNumerator1, denom1);
    Value numImTimesRatio1 = mulf(numIm, ratio1);
    Value imNumerator1 = subf(numImTimesRatio1, numRe);
    Value quotIm1 = divf(imNumerator1, denom1);

    // Variant 2, selected otherwise:
    //   ratio2 = denIm / denRe
    //   denom2 = denRe + denIm * ratio2
    //   quotRe2 = (numRe + numIm * ratio2) / denom2
    //   quotIm2 = (numIm - numRe * ratio2) / denom2
    Value ratio2 = divf(denIm, denRe);
    Value ratio2TimesDenIm = mulf(ratio2, denIm);
    Value denom2 = addf(denRe, ratio2TimesDenIm);
    Value numImTimesRatio2 = mulf(numIm, ratio2);
    Value reNumerator2 = addf(numRe, numImTimesRatio2);
    Value quotRe2 = divf(reNumerator2, denom2);
    Value numReTimesRatio2 = mulf(numRe, ratio2);
    Value imNumerator2 = subf(numIm, numReTimesRatio2);
    Value quotIm2 = divf(imNumerator2, denom2);

    // Special case 1: zero divisor, dividend has at most one NaN part.
    // Result: inf (with the sign of denRe) times each dividend part.
    Value zero =
        b.create<arith::ConstantOp>(scalarType, b.getZeroAttr(scalarType));
    Value denReAbs = absf(denRe);
    Value denReIsZero = cmpf(Pred::OEQ, denReAbs, zero);
    Value denImAbs = absf(denIm);
    Value denImIsZero = cmpf(Pred::OEQ, denImAbs, zero);
    Value numReIsNotNaN = cmpf(Pred::ORD, numRe, zero);
    Value numImIsNotNaN = cmpf(Pred::ORD, numIm, zero);
    Value numHasNonNaNPart = ori(numReIsNotNaN, numImIsNotNaN);
    Value denIsZero = andi(denReIsZero, denImIsZero);
    Value zeroDivisor = andi(numHasNonNaNPart, denIsZero);
    Value inf = b.create<arith::ConstantOp>(
        scalarType,
        b.getFloatAttr(scalarType,
                       APFloat::getInf(scalarType.getFloatSemantics())));
    Value signedInf = copySign(inf, denRe);
    Value zeroDivisorRe = mulf(signedInf, numRe);
    Value zeroDivisorIm = mulf(signedInf, numIm);

    // Special case 2: infinite dividend, finite divisor. Each dividend part is
    // replaced by +-1 if it is infinite and +-0 otherwise (keeping its sign),
    // and the product with the divisor is scaled by inf.
    Value denReFinite = cmpf(Pred::ONE, denReAbs, inf);
    Value denImFinite = cmpf(Pred::ONE, denImAbs, inf);
    Value denFinite = andi(denReFinite, denImFinite);
    Value numReAbs = absf(numRe);
    Value numReInfinite = cmpf(Pred::OEQ, numReAbs, inf);
    Value numImAbs = absf(numIm);
    Value numImInfinite = cmpf(Pred::OEQ, numImAbs, inf);
    Value numInfinite = ori(numReInfinite, numImInfinite);
    Value infNumFiniteDen = andi(numInfinite, denFinite);
    Value one = b.create<arith::ConstantOp>(scalarType,
                                            b.getFloatAttr(scalarType, 1));
    Value numReInfFlag = select(numReInfinite, one, zero);
    Value numReUnit = copySign(numReInfFlag, numRe);
    Value numImInfFlag = select(numImInfinite, one, zero);
    Value numImUnit = copySign(numImInfFlag, numIm);
    Value numReUnitTimesDenRe = mulf(numReUnit, denRe);
    Value numImUnitTimesDenIm = mulf(numImUnit, denIm);
    Value case2ReSum = addf(numReUnitTimesDenRe, numImUnitTimesDenIm);
    Value case2Re = mulf(inf, case2ReSum);
    Value numReUnitTimesDenIm = mulf(numReUnit, denIm);
    Value numImUnitTimesDenRe = mulf(numImUnit, denRe);
    Value case2ImDiff = subf(numImUnitTimesDenRe, numReUnitTimesDenIm);
    Value case2Im = mulf(inf, case2ImDiff);

    // Special case 3: finite dividend, infinite divisor. Mirrors case 2 with
    // the roles swapped and the product scaled by zero.
    Value numReFinite = cmpf(Pred::ONE, numReAbs, inf);
    Value numImFinite = cmpf(Pred::ONE, numImAbs, inf);
    Value numFinite = andi(numReFinite, numImFinite);
    Value denReInfinite = cmpf(Pred::OEQ, denReAbs, inf);
    Value denImInfinite = cmpf(Pred::OEQ, denImAbs, inf);
    Value denInfinite = ori(denReInfinite, denImInfinite);
    Value finiteNumInfDen = andi(numFinite, denInfinite);
    Value denReInfFlag = select(denReInfinite, one, zero);
    Value denReUnit = copySign(denReInfFlag, denRe);
    Value denImInfFlag = select(denImInfinite, one, zero);
    Value denImUnit = copySign(denImInfFlag, denIm);
    Value numReTimesDenReUnit = mulf(numRe, denReUnit);
    Value numImTimesDenImUnit = mulf(numIm, denImUnit);
    Value case3ReSum = addf(numReTimesDenReUnit, numImTimesDenImUnit);
    Value case3Re = mulf(zero, case3ReSum);
    Value numImTimesDenReUnit = mulf(numIm, denReUnit);
    Value numReTimesDenImUnit = mulf(numRe, denImUnit);
    Value case3ImDiff = subf(numImTimesDenReUnit, numReTimesDenImUnit);
    Value case3Im = mulf(zero, case3ImDiff);

    // Pick the Smith variant, then layer the special cases on top of it;
    // case 1 is tested last and therefore has the highest priority.
    Value denReAbsSmaller = cmpf(Pred::OLT, denReAbs, denImAbs);
    Value quotRe = select(denReAbsSmaller, quotRe1, quotRe2);
    Value quotIm = select(denReAbsSmaller, quotIm1, quotIm2);
    Value reAfterCase3 = select(finiteNumInfDen, case3Re, quotRe);
    Value imAfterCase3 = select(finiteNumInfDen, case3Im, quotIm);
    Value reAfterCase2 = select(infNumFiniteDen, case2Re, reAfterCase3);
    Value imAfterCase2 = select(infNumFiniteDen, case2Im, imAfterCase3);
    Value reAfterCase1 = select(zeroDivisor, zeroDivisorRe, reAfterCase2);
    Value imAfterCase1 = select(zeroDivisor, zeroDivisorIm, imAfterCase2);

    // The special-case values are used only when both parts of the regular
    // quotient are NaN.
    Value quotReIsNaN = cmpf(Pred::UNO, quotRe, zero);
    Value quotImIsNaN = cmpf(Pred::UNO, quotIm, zero);
    Value quotIsNaN = andi(quotReIsNaN, quotImIsNaN);
    Value finalRe = select(quotIsNaN, reAfterCase1, quotRe);
    Value finalIm = select(quotIsNaN, imAfterCase1, quotIm);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexType, finalRe,
                                                   finalIm);
    return success();
  }
};
|
|
|
|
// Lowers complex.exp for an input a + bi to
//   exp(a) * cos(b) + i * exp(a) * sin(b).
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
  using OpConversionPattern<complex::ExpOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto complexType = adaptor.getComplex().getType().cast<ComplexType>();
    auto scalarType = complexType.getElementType().cast<FloatType>();

    Value re = b.create<complex::ReOp>(scalarType, adaptor.getComplex());
    Value im = b.create<complex::ImOp>(scalarType, adaptor.getComplex());

    // exp(re) is the common factor of both result parts.
    Value magnitude = b.create<math::ExpOp>(re);
    Value cosIm = b.create<math::CosOp>(im);
    Value outRe = b.create<arith::MulFOp>(magnitude, cosIm);
    Value sinIm = b.create<math::SinOp>(im);
    Value outIm = b.create<arith::MulFOp>(magnitude, sinIm);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexType, outRe,
                                                   outIm);
    return success();
  }
};
|
|
|
|
// Lowers complex.expm1 by emitting complex.exp on the input and subtracting
// one from the real part of its result; the imaginary part is passed through.
struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
  using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto complexType = adaptor.getComplex().getType().cast<ComplexType>();
    auto scalarType = complexType.getElementType().cast<FloatType>();

    Value expOfInput =
        rewriter.create<complex::ExpOp>(loc, adaptor.getComplex());

    Value expRe = rewriter.create<complex::ReOp>(loc, scalarType, expOfInput);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, scalarType, rewriter.getFloatAttr(scalarType, 1));
    Value expReMinusOne = rewriter.create<arith::SubFOp>(loc, expRe, one);
    Value expIm = rewriter.create<complex::ImOp>(loc, scalarType, expOfInput);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexType,
                                                   expReMinusOne, expIm);
    return success();
  }
};
|
|
|
|
// Lowers complex.log for an input z = a + bi to
//   log(|z|) + i * atan2(b, a)
// where |z| is produced by a complex.abs op.
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
  using OpConversionPattern<complex::LogOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto complexType = adaptor.getComplex().getType().cast<ComplexType>();
    auto scalarType = complexType.getElementType().cast<FloatType>();

    // Real part: logarithm of the magnitude.
    Value magnitude =
        rewriter.create<complex::AbsOp>(loc, scalarType, adaptor.getComplex());
    Value outRe = rewriter.create<math::LogOp>(loc, scalarType, magnitude);

    // Imaginary part: the argument, atan2(imaginary, real).
    Value re =
        rewriter.create<complex::ReOp>(loc, scalarType, adaptor.getComplex());
    Value im =
        rewriter.create<complex::ImOp>(loc, scalarType, adaptor.getComplex());
    Value outIm = rewriter.create<math::Atan2Op>(loc, scalarType, im, re);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexType, outRe,
                                                   outIm);
    return success();
  }
};
|
|
|
|
// Lowers complex.log1p for an input a + bi to
//   0.5 * log1p(a*a + 2*a + b*b) + i * atan2(b, a + 1).
struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
  using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto complexType = adaptor.getComplex().getType().cast<ComplexType>();
    auto scalarType = complexType.getElementType().cast<FloatType>();

    Value re =
        rewriter.create<complex::ReOp>(loc, scalarType, adaptor.getComplex());
    Value im =
        rewriter.create<complex::ImOp>(loc, scalarType, adaptor.getComplex());

    Value half = rewriter.create<arith::ConstantOp>(
        loc, scalarType, rewriter.getFloatAttr(scalarType, 0.5));
    Value one = rewriter.create<arith::ConstantOp>(
        loc, scalarType, rewriter.getFloatAttr(scalarType, 1));
    Value two = rewriter.create<arith::ConstantOp>(
        loc, scalarType, rewriter.getFloatAttr(scalarType, 2));

    // Derivation of the real part:
    //   log((a+1)+bi) = .5*log((a+1)^2 + b^2) + i*atan2(b, a+1)
    //                 = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
    //                 = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
    Value reSquared = rewriter.create<arith::MulFOp>(loc, re, re);
    Value twiceRe = rewriter.create<arith::MulFOp>(loc, re, two);
    Value reTerms = rewriter.create<arith::AddFOp>(loc, reSquared, twiceRe);
    Value imSquared = rewriter.create<arith::MulFOp>(loc, im, im);
    Value log1pArg = rewriter.create<arith::AddFOp>(loc, reTerms, imSquared);
    Value log1pValue =
        rewriter.create<math::Log1pOp>(loc, scalarType, log1pArg);
    Value outRe = rewriter.create<arith::MulFOp>(loc, log1pValue, half);

    // Imaginary part: atan2(b, a + 1).
    Value rePlusOne = rewriter.create<arith::AddFOp>(loc, re, one);
    Value outIm =
        rewriter.create<math::Atan2Op>(loc, scalarType, im, rePlusOne);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexType, outRe,
                                                   outIm);
    return success();
  }
};
|
|
|
|
// Lowers complex.mul to scalar arith/math ops.
//
// The product is first computed with the textbook formula. If both parts of
// that product are NaN and one of three conditions on the operands holds, the
// operands are adjusted (infinite parts become +-1, finite parts of an
// infinite operand become +-0, NaN parts become +-0) and the product is
// recomputed and scaled by infinity.
//
// Local names carry a stage suffix: 0 for the original operand parts and 1, 2,
// 3 for their values after the adjustments of case 1, 2 and 3 respectively.
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
  using OpConversionPattern<complex::MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto complexType = adaptor.getLhs().getType().cast<ComplexType>();
    auto scalarType = complexType.getElementType().cast<FloatType>();
    using Pred = arith::CmpFPredicate;

    // Thin wrappers; each call creates exactly one op at the current insertion
    // point.
    auto addf = [&](Value l, Value r) -> Value {
      return rewriter.create<arith::AddFOp>(loc, l, r);
    };
    auto subf = [&](Value l, Value r) -> Value {
      return rewriter.create<arith::SubFOp>(loc, l, r);
    };
    auto mulf = [&](Value l, Value r) -> Value {
      return rewriter.create<arith::MulFOp>(loc, l, r);
    };
    auto absf = [&](Value v) -> Value {
      return rewriter.create<math::AbsFOp>(loc, v);
    };
    auto copySign = [&](Value magnitude, Value sign) -> Value {
      return rewriter.create<math::CopySignOp>(loc, magnitude, sign);
    };
    auto cmpf = [&](Pred pred, Value l, Value r) -> Value {
      return rewriter.create<arith::CmpFOp>(loc, pred, l, r);
    };
    auto isNaN = [&](Value v) -> Value { return cmpf(Pred::UNO, v, v); };
    auto andi = [&](Value l, Value r) -> Value {
      return rewriter.create<arith::AndIOp>(loc, l, r);
    };
    auto ori = [&](Value l, Value r) -> Value {
      return rewriter.create<arith::OrIOp>(loc, l, r);
    };
    auto select = [&](Value cond, Value ifTrue, Value ifFalse) -> Value {
      return rewriter.create<arith::SelectOp>(loc, cond, ifTrue, ifFalse);
    };

    // Operand parts and their absolute values.
    Value lhsRe0 =
        rewriter.create<complex::ReOp>(loc, scalarType, adaptor.getLhs());
    Value lhsRe0Abs = absf(lhsRe0);
    Value lhsIm0 =
        rewriter.create<complex::ImOp>(loc, scalarType, adaptor.getLhs());
    Value lhsIm0Abs = absf(lhsIm0);
    Value rhsRe0 =
        rewriter.create<complex::ReOp>(loc, scalarType, adaptor.getRhs());
    Value rhsRe0Abs = absf(rhsRe0);
    Value rhsIm0 =
        rewriter.create<complex::ImOp>(loc, scalarType, adaptor.getRhs());
    Value rhsIm0Abs = absf(rhsIm0);

    // Naive product: (re*re - im*im) + (im*re + re*im) i.
    Value reTimesRe = mulf(lhsRe0, rhsRe0);
    Value reTimesReAbs = absf(reTimesRe);
    Value imTimesIm = mulf(lhsIm0, rhsIm0);
    Value imTimesImAbs = absf(imTimesIm);
    Value naiveRe = subf(reTimesRe, imTimesIm);

    Value imTimesRe = mulf(lhsIm0, rhsRe0);
    Value imTimesReAbs = absf(imTimesRe);
    Value reTimesIm = mulf(lhsRe0, rhsIm0);
    Value reTimesImAbs = absf(reTimesIm);
    Value naiveIm = addf(imTimesRe, reTimesIm);

    // The fix-ups below only take effect when both naive parts are NaN.
    Value naiveReIsNaN = isNaN(naiveRe);
    Value naiveImIsNaN = isNaN(naiveIm);
    Value naiveIsNaN = andi(naiveReIsNaN, naiveImIsNaN);

    Value inf = rewriter.create<arith::ConstantOp>(
        loc, scalarType,
        rewriter.getFloatAttr(scalarType,
                              APFloat::getInf(scalarType.getFloatSemantics())));
    // Takes an absolute value and tests it against +inf.
    auto absIsInf = [&](Value absValue) -> Value {
      return cmpf(Pred::OEQ, absValue, inf);
    };

    // Case 1. The lhs has an infinite part.
    Value lhsReIsInf = absIsInf(lhsRe0Abs);
    Value lhsImIsInf = absIsInf(lhsIm0Abs);
    Value lhsIsInf = ori(lhsReIsInf, lhsImIsInf);
    Value rhsReIsNaN = isNaN(rhsRe0);
    Value rhsImIsNaN = isNaN(rhsIm0);
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, scalarType, rewriter.getZeroAttr(scalarType));
    Value one = rewriter.create<arith::ConstantOp>(
        loc, scalarType, rewriter.getFloatAttr(scalarType, 1));
    Value lhsReInfFlag = select(lhsReIsInf, one, zero);
    Value lhsReUnit = copySign(lhsReInfFlag, lhsRe0);
    Value lhsRe1 = select(lhsIsInf, lhsReUnit, lhsRe0);
    Value lhsImInfFlag = select(lhsImIsInf, one, zero);
    Value lhsImUnit = copySign(lhsImInfFlag, lhsIm0);
    Value lhsIm1 = select(lhsIsInf, lhsImUnit, lhsIm0);
    Value zeroRhsRe1 = andi(lhsIsInf, rhsReIsNaN);
    Value rhsReSignedZero1 = copySign(zero, rhsRe0);
    Value rhsRe1 = select(zeroRhsRe1, rhsReSignedZero1, rhsRe0);
    Value zeroRhsIm1 = andi(lhsIsInf, rhsImIsNaN);
    Value rhsImSignedZero1 = copySign(zero, rhsIm0);
    Value rhsIm1 = select(zeroRhsIm1, rhsImSignedZero1, rhsIm0);

    // Case 2. The rhs has an infinite part. The infinity tests use the
    // original rhs magnitudes; the NaN tests use the lhs parts as left by
    // case 1.
    Value rhsReIsInf = absIsInf(rhsRe0Abs);
    Value rhsImIsInf = absIsInf(rhsIm0Abs);
    Value rhsIsInf = ori(rhsReIsInf, rhsImIsInf);
    Value lhsReIsNaN = isNaN(lhsRe1);
    Value lhsImIsNaN = isNaN(lhsIm1);
    Value rhsReInfFlag = select(rhsReIsInf, one, zero);
    Value rhsReUnit = copySign(rhsReInfFlag, rhsRe1);
    Value rhsRe2 = select(rhsIsInf, rhsReUnit, rhsRe1);
    Value rhsImInfFlag = select(rhsImIsInf, one, zero);
    Value rhsImUnit = copySign(rhsImInfFlag, rhsIm1);
    Value rhsIm2 = select(rhsIsInf, rhsImUnit, rhsIm1);
    Value zeroLhsRe2 = andi(rhsIsInf, lhsReIsNaN);
    Value lhsReSignedZero2 = copySign(zero, lhsRe1);
    Value lhsRe2 = select(zeroLhsRe2, lhsReSignedZero2, lhsRe1);
    Value zeroLhsIm2 = andi(rhsIsInf, lhsImIsNaN);
    Value lhsImSignedZero2 = copySign(zero, lhsIm1);
    Value lhsIm2 = select(zeroLhsIm2, lhsImSignedZero2, lhsIm1);
    Value operandIsInf = ori(lhsIsInf, rhsIsInf);

    // Case 3. Neither operand has an infinite part, but one of the four
    // pairwise products of the original parts is infinite.
    Value reTimesReIsInf = absIsInf(reTimesReAbs);
    Value imTimesImIsInf = absIsInf(imTimesImAbs);
    Value productIsInf = ori(reTimesReIsInf, imTimesImIsInf);
    Value reTimesImIsInf = absIsInf(reTimesImAbs);
    productIsInf = ori(productIsInf, reTimesImIsInf);
    Value imTimesReIsInf = absIsInf(imTimesReAbs);
    productIsInf = ori(productIsInf, imTimesReIsInf);
    Type i1Type = rewriter.getI1Type();
    Value trueValue = rewriter.create<arith::ConstantOp>(
        loc, i1Type, rewriter.getIntegerAttr(i1Type, 1));
    // Logical negation via xor with true.
    Value noOperandIsInf =
        rewriter.create<arith::XOrIOp>(loc, operandIsInf, trueValue);
    Value infiniteProductOnly = andi(productIsInf, noOperandIsInf);
    Value zeroLhsRe3 = andi(infiniteProductOnly, lhsReIsNaN);
    Value lhsReSignedZero3 = copySign(zero, lhsRe2);
    Value lhsRe3 = select(zeroLhsRe3, lhsReSignedZero3, lhsRe2);
    Value zeroLhsIm3 = andi(infiniteProductOnly, lhsImIsNaN);
    Value lhsImSignedZero3 = copySign(zero, lhsIm2);
    Value lhsIm3 = select(zeroLhsIm3, lhsImSignedZero3, lhsIm2);
    Value zeroRhsRe3 = andi(infiniteProductOnly, rhsReIsNaN);
    Value rhsReSignedZero3 = copySign(zero, rhsRe2);
    Value rhsRe3 = select(zeroRhsRe3, rhsReSignedZero3, rhsRe2);
    Value zeroRhsIm3 = andi(infiniteProductOnly, rhsImIsNaN);
    Value rhsImSignedZero3 = copySign(zero, rhsIm2);
    Value rhsIm3 = select(zeroRhsIm3, rhsImSignedZero3, rhsIm2);
    Value anyCaseApplies = ori(operandIsInf, infiniteProductOnly);
    Value recalc = andi(naiveIsNaN, anyCaseApplies);

    // Recomputed real part, scaled by infinity.
    Value adjReTimesRe = mulf(lhsRe3, rhsRe3);
    Value adjImTimesIm = mulf(lhsIm3, rhsIm3);
    Value adjRe = subf(adjReTimesRe, adjImTimesIm);
    Value infTimesAdjRe = mulf(inf, adjRe);
    Value outRe = select(recalc, infTimesAdjRe, naiveRe);

    // Recomputed imaginary part, scaled by infinity.
    Value adjImTimesRe = mulf(lhsIm3, rhsRe3);
    Value adjReTimesIm = mulf(lhsRe3, rhsIm3);
    Value adjIm = addf(adjImTimesRe, adjReTimesIm);
    Value infTimesAdjIm = mulf(inf, adjIm);
    Value outIm = select(recalc, infTimesAdjIm, naiveIm);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexType, outRe,
                                                   outIm);
    return success();
  }
};
|
|
|
|
/// Lowers complex.neg: both components of the operand are negated
/// independently and recombined with complex.create.
struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
  using OpConversionPattern<complex::NegOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value input = adaptor.getComplex();
    auto complexTy = input.getType().cast<ComplexType>();
    auto scalarTy = complexTy.getElementType().cast<FloatType>();

    // Split the operand, flip the sign of each half, then rebuild.
    Value re = b.create<complex::ReOp>(scalarTy, input);
    Value im = b.create<complex::ImOp>(scalarTy, input);
    Value flippedRe = b.create<arith::NegFOp>(re);
    Value flippedIm = b.create<arith::NegFOp>(im);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexTy, flippedRe,
                                                   flippedIm);
    return success();
  }
};
|
|
|
|
/// Supplies the sine-specific combination step for the shared
/// TrigonometricOpConversion base.
struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;

  std::pair<Value, Value>
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const override {
    // The complex sine is defined as
    //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy))).
    // Substituting
    //   exp(i(x+iy))  = exp(-y + ix)   = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and writing t := exp(y) yields
    //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
    //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
    // NOTE(review): this presumes the base class passes scaledExp = 0.5*t and
    // reciprocalExp = 0.5/t -- confirm against TrigonometricOpConversion.
    mlir::ImplicitLocOpBuilder b(loc, rewriter);

    Value expSum = b.create<arith::AddFOp>(scaledExp, reciprocalExp);
    Value realPart = b.create<arith::MulFOp>(expSum, sin);

    Value expDiff = b.create<arith::SubFOp>(scaledExp, reciprocalExp);
    Value imagPart = b.create<arith::MulFOp>(expDiff, cos);

    return std::pair<Value, Value>(realPart, imagPart);
  }
};
|
|
|
|
// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
|
|
struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
|
|
using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
|
|
auto type = op.getType().cast<ComplexType>();
|
|
Type elementType = type.getElementType();
|
|
Value arg = adaptor.getComplex();
|
|
|
|
Value zero =
|
|
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
|
|
|
|
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
|
|
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
|
|
|
|
Value absLhs = b.create<math::AbsFOp>(real);
|
|
Value absArg = b.create<complex::AbsOp>(elementType, arg);
|
|
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
|
|
|
|
Value half = b.create<arith::ConstantOp>(elementType,
|
|
b.getFloatAttr(elementType, 0.5));
|
|
Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
|
|
Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
|
|
|
|
Value realIsNegative =
|
|
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
|
|
Value imagIsNegative =
|
|
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
|
|
|
|
Value resultReal = sqrtAddAbs;
|
|
|
|
Value imagDivTwoResultReal = b.create<arith::DivFOp>(
|
|
imag, b.create<arith::AddFOp>(resultReal, resultReal));
|
|
|
|
Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
|
|
|
|
Value resultImag = b.create<arith::SelectOp>(
|
|
realIsNegative,
|
|
b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
|
|
resultReal),
|
|
imagDivTwoResultReal);
|
|
|
|
resultReal = b.create<arith::SelectOp>(
|
|
realIsNegative,
|
|
b.create<arith::DivFOp>(
|
|
imag, b.create<arith::AddFOp>(resultImag, resultImag)),
|
|
resultReal);
|
|
|
|
Value realIsZero =
|
|
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
|
|
Value imagIsZero =
|
|
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
|
|
Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
|
|
|
|
resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
|
|
resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lowers complex.sign to z / |z|, returning the operand itself when both of
/// its components compare equal to zero.
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
  using OpConversionPattern<complex::SignOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getComplex();
    auto complexTy = input.getType().cast<ComplexType>();
    auto scalarTy = complexTy.getElementType().cast<FloatType>();
    mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);

    Value re = builder.create<complex::ReOp>(scalarTy, input);
    Value im = builder.create<complex::ImOp>(scalarTy, input);

    // Detect the all-zero operand, for which the division is skipped.
    Value zeroConst = builder.create<arith::ConstantOp>(
        scalarTy, builder.getZeroAttr(scalarTy));
    Value reZero =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, re, zeroConst);
    Value imZero =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, im, zeroConst);
    Value inputZero = builder.create<arith::AndIOp>(reZero, imZero);

    // Normalise each component by the magnitude of the operand.
    Value magnitude = builder.create<complex::AbsOp>(scalarTy, input);
    Value reNormalised = builder.create<arith::DivFOp>(re, magnitude);
    Value imNormalised = builder.create<arith::DivFOp>(im, magnitude);
    Value normalised =
        builder.create<complex::CreateOp>(complexTy, reNormalised, imNormalised);

    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, inputZero, input,
                                                 normalised);
    return success();
  }
};
|
|
|
|
/// Rewrites complex.tan as complex.sin / complex.cos; the resulting complex
/// ops are left for the other patterns of this conversion to lower.
struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
  using OpConversionPattern<complex::TanOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value input = adaptor.getComplex();

    // The cosine is emitted first, then the sine, then the quotient.
    Value cosine = b.create<complex::CosOp>(input);
    Value sine = b.create<complex::SinOp>(input);
    rewriter.replaceOpWithNewOp<complex::DivOp>(op, sine, cosine);
    return success();
  }
};
|
|
|
|
/// Lowers complex.tanh to a complex division of two values built from the
/// real tanh of the real part and the real tangent of the imaginary part.
struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
  using OpConversionPattern<complex::TanhOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value input = adaptor.getComplex();
    auto complexTy = input.getType().cast<ComplexType>();
    auto scalarTy = complexTy.getElementType().cast<FloatType>();

    // The hyperbolic tangent of a complex number is computed as
    //   tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y))
    // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
    Value re = b.create<complex::ReOp>(scalarTy, input);
    Value im = b.create<complex::ImOp>(scalarTy, input);

    Value tanhRe = b.create<math::TanhOp>(re);
    Value cosIm = b.create<math::CosOp>(im);
    Value sinIm = b.create<math::SinOp>(im);
    Value tanIm = b.create<arith::DivFOp>(sinIm, cosIm);

    // Numerator: tanh(x) + i * tan(y).
    Value top = b.create<complex::CreateOp>(complexTy, tanhRe, tanIm);

    // Denominator: 1 + i * tanh(x) * tan(y).
    Value oneConst =
        b.create<arith::ConstantOp>(scalarTy, b.getFloatAttr(scalarTy, 1));
    Value product = b.create<arith::MulFOp>(tanhRe, tanIm);
    Value bottom = b.create<complex::CreateOp>(complexTy, oneConst, product);

    rewriter.replaceOpWithNewOp<complex::DivOp>(op, top, bottom);
    return success();
  }
};
|
|
|
|
/// Lowers complex.conj by keeping the real part and negating the imaginary
/// part of the operand.
struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
  using OpConversionPattern<complex::ConjOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value input = adaptor.getComplex();
    auto complexTy = input.getType().cast<ComplexType>();
    auto scalarTy = complexTy.getElementType().cast<FloatType>();

    Value re = b.create<complex::ReOp>(scalarTy, input);
    Value im = b.create<complex::ImOp>(scalarTy, input);
    Value flippedIm = b.create<arith::NegFOp>(scalarTy, im);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexTy, re,
                                                   flippedIm);
    return success();
  }
};
|
|
|
|
/// Coverts x^y = (a+bi)^(c+di) to
|
|
/// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
|
|
/// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
|
|
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
|
|
ComplexType type, Value a, Value b, Value c,
|
|
Value d) {
|
|
auto elementType = type.getElementType().cast<FloatType>();
|
|
|
|
// Compute (a*a+b*b)^(0.5c).
|
|
Value aaPbb = builder.create<arith::AddFOp>(
|
|
builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
|
|
Value half = builder.create<arith::ConstantOp>(
|
|
elementType, builder.getFloatAttr(elementType, 0.5));
|
|
Value halfC = builder.create<arith::MulFOp>(half, c);
|
|
Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
|
|
|
|
// Compute exp(-d*atan2(b,a)).
|
|
Value negD = builder.create<arith::NegFOp>(d);
|
|
Value argX = builder.create<math::Atan2Op>(b, a);
|
|
Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
|
|
Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
|
|
|
|
// Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
|
|
Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
|
|
|
|
// Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
|
|
Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
|
|
Value halfD = builder.create<arith::MulFOp>(half, d);
|
|
Value q = builder.create<arith::AddFOp>(
|
|
builder.create<arith::MulFOp>(c, argX),
|
|
builder.create<arith::MulFOp>(halfD, lnAaPbb));
|
|
|
|
Value cosQ = builder.create<math::CosOp>(q);
|
|
Value sinQ = builder.create<math::SinOp>(q);
|
|
Value zero = builder.create<arith::ConstantOp>(
|
|
elementType, builder.getFloatAttr(elementType, 0));
|
|
Value one = builder.create<arith::ConstantOp>(
|
|
elementType, builder.getFloatAttr(elementType, 1));
|
|
|
|
Value xEqZero =
|
|
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
|
|
Value yGeZero = builder.create<arith::AndIOp>(
|
|
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero),
|
|
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero));
|
|
Value cEqZero =
|
|
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
|
|
Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
|
|
Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
|
|
Value complexOther = builder.create<complex::CreateOp>(
|
|
type, builder.create<arith::MulFOp>(coeff, cosQ),
|
|
builder.create<arith::MulFOp>(coeff, sinQ));
|
|
|
|
// x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
|
|
// Branch Cuts for Complex Elementary Functions or Much Ado About
|
|
// Nothing's Sign Bit, W. Kahan, Section 10.
|
|
return builder.create<arith::SelectOp>(
|
|
builder.create<arith::AndIOp>(xEqZero, yGeZero),
|
|
builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero),
|
|
complexOther);
|
|
}
|
|
|
|
/// Lowers complex.pow by splitting both operands into their scalar components
/// and delegating to powOpConversionImpl.
struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
  using OpConversionPattern<complex::PowOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value base = adaptor.getLhs();
    Value exponent = adaptor.getRhs();
    auto complexTy = base.getType().cast<ComplexType>();
    auto scalarTy = complexTy.getElementType().cast<FloatType>();

    // base = baseRe + i * baseIm, exponent = expRe + i * expIm.
    Value baseRe = b.create<complex::ReOp>(scalarTy, base);
    Value baseIm = b.create<complex::ImOp>(scalarTy, base);
    Value expRe = b.create<complex::ReOp>(scalarTy, exponent);
    Value expIm = b.create<complex::ImOp>(scalarTy, exponent);

    Value lowered =
        powOpConversionImpl(b, complexTy, baseRe, baseIm, expRe, expIm);
    rewriter.replaceOp(op, {lowered});
    return success();
  }
};
|
|
|
|
/// Lowers complex.rsqrt as a power with the constant exponent -0.5 + 0i,
/// reusing powOpConversionImpl.
struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
  using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value input = adaptor.getComplex();
    auto complexTy = input.getType().cast<ComplexType>();
    auto scalarTy = complexTy.getElementType().cast<FloatType>();

    Value inputRe = b.create<complex::ReOp>(scalarTy, input);
    Value inputIm = b.create<complex::ImOp>(scalarTy, input);

    // Exponent components: real part -0.5, imaginary part 0.
    Value expRe =
        b.create<arith::ConstantOp>(scalarTy, b.getFloatAttr(scalarTy, -0.5));
    Value expIm =
        b.create<arith::ConstantOp>(scalarTy, b.getFloatAttr(scalarTy, 0));

    Value lowered =
        powOpConversionImpl(b, complexTy, inputRe, inputIm, expRe, expIm);
    rewriter.replaceOp(op, {lowered});
    return success();
  }
};
|
|
|
|
/// Lowers complex.angle to math.atan2(imag, real) on the operand's components.
struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
  using OpConversionPattern<complex::AngleOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // The op's result type doubles as the type of the extracted components.
    auto resultTy = op.getType();
    Value input = adaptor.getComplex();

    Value re = b.create<complex::ReOp>(resultTy, input);
    Value im = b.create<complex::ImOp>(resultTy, input);

    rewriter.replaceOpWithNewOp<math::Atan2Op>(op, im, re);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers every complex-dialect lowering pattern defined in this file with
/// `patterns`, all constructed against the pattern set's own context.
void mlir::populateComplexToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // clang-format off
  patterns.add<AbsOpConversion,
               AngleOpConversion,
               Atan2OpConversion,
               BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
               BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
               ComparisonOpConversion<complex::EqualOp,
                                      arith::CmpFPredicate::OEQ>,
               ComparisonOpConversion<complex::NotEqualOp,
                                      arith::CmpFPredicate::UNE>,
               ConjOpConversion,
               CosOpConversion,
               DivOpConversion,
               ExpOpConversion,
               Expm1OpConversion,
               Log1pOpConversion,
               LogOpConversion,
               MulOpConversion,
               NegOpConversion,
               SignOpConversion,
               SinOpConversion,
               SqrtOpConversion,
               TanOpConversion,
               TanhOpConversion,
               PowOpConversion,
               RsqrtOpConversion>(context);
  // clang-format on
}
|
|
|
|
namespace {
|
|
struct ConvertComplexToStandardPass
|
|
: public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
|
|
void runOnOperation() override;
|
|
};
|
|
|
|
void ConvertComplexToStandardPass::runOnOperation() {
|
|
// Convert to the Standard dialect using the converter defined above.
|
|
RewritePatternSet patterns(&getContext());
|
|
populateComplexToStandardConversionPatterns(patterns);
|
|
|
|
ConversionTarget target(getContext());
|
|
target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
|
|
target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
|
|
if (failed(
|
|
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> mlir::createConvertComplexToStandardPass() {
|
|
return std::make_unique<ConvertComplexToStandardPass>();
|
|
}
|