llvm-project/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp

300 lines
11 KiB
C++

//===- QuantDialectBytecode.cpp - Quant Bytecode Implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "QuantDialectBytecode.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::quant;
//===----------------------------------------------------------------------===//
// Encoding
//===----------------------------------------------------------------------===//
namespace {
namespace quant_encoding {
/// This enum contains marker codes used to indicate which type is currently
/// being decoded, and how it should be decoded. The order of these codes should
/// generally be unchanged, as any changes will inevitably break compatibility
/// with older bytecode.
///
/// Each code is emitted as a varint ahead of the fields listed in the
/// corresponding comment; the fields appear in the stream in the order shown.
enum TypeCode {
/// AnyQuantizedType (no expressed type) {
/// flags: varint
/// storageType: Type
/// storageTypeMin: svarint
/// storageTypeMax: svarint
/// }
///
kAnyQuantizedType = 1,
/// AnyQuantizedType (with an expressed type) {
/// flags: varint
/// storageType: Type
/// expressedType: Type
/// storageTypeMin: svarint
/// storageTypeMax: svarint
/// }
///
kAnyQuantizedTypeWithExpressedType = 2,
/// CalibratedQuantizedType {
/// expressedType: Type
/// min: APFloat
/// max: APFloat
/// }
///
kCalibratedQuantizedType = 3,
/// UniformQuantizedType {
/// flags: varint
/// storageType: Type
/// expressedType: Type
/// scale: APFloat
/// zeroPoint: svarint
/// storageTypeMin: svarint
/// storageTypeMax: svarint
/// }
///
kUniformQuantizedType = 4,
/// UniformQuantizedPerAxisType {
/// flags: varint
/// storageType: Type
/// expressedType: Type
/// scale: APFloat[]
/// zeroPoint: svarint[]
/// quantizedDimension: varint
/// storageTypeMin: svarint
/// storageTypeMax: svarint
/// }
///
kUniformQuantizedPerAxisType = 5,
};
} // namespace quant_encoding
} // namespace
//===----------------------------------------------------------------------===//
// QuantDialectBytecodeInterface
//===----------------------------------------------------------------------===//
namespace {
/// Bytecode (de)serialization hooks for the types of the Quant dialect.
///
/// `readType`/`writeType` are the interface entry points; the remaining
/// members are per-type helpers they dispatch to.
struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
  QuantDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Types

  /// Interface entry points.
  Type readType(DialectBytecodeReader &reader) const override;
  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override;

  /// Per-type decoders. Each one consumes the fields that follow the type
  /// code in the stream.
  AnyQuantizedType readAnyQuantizedType(bool withExpressedType,
                                        DialectBytecodeReader &reader) const;
  CalibratedQuantizedType
  readCalibratedQuantizedType(DialectBytecodeReader &reader) const;
  UniformQuantizedType
  readUniformQuantizedType(DialectBytecodeReader &reader) const;
  UniformQuantizedPerAxisType
  readUniformQuantizedPerAxisType(DialectBytecodeReader &reader) const;

  /// Per-type encoders. Each one emits the type code followed by the type's
  /// fields.
  void write(AnyQuantizedType type, DialectBytecodeWriter &writer) const;
  void write(CalibratedQuantizedType type, DialectBytecodeWriter &writer) const;
  void write(UniformQuantizedType type, DialectBytecodeWriter &writer) const;
  void write(UniformQuantizedPerAxisType type,
             DialectBytecodeWriter &writer) const;
};
} // namespace
/// Registers the Quant bytecode interface (QuantDialectBytecodeInterface) on
/// the given dialect.
void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) {
dialect->addInterfaces<QuantDialectBytecodeInterface>();
}
//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//
/// Decodes one Quant dialect type from `reader`.
///
/// The stream is expected to start with a varint holding a
/// `quant_encoding::TypeCode`; the matching per-type reader then consumes the
/// remaining fields.
///
/// Returns a null `Type` if the code cannot be read or is not a known
/// `TypeCode` (an error naming the offending code is emitted in the latter
/// case), or whatever the per-type reader returns otherwise.
Type QuantDialectBytecodeInterface::readType(
    DialectBytecodeReader &reader) const {
  uint64_t code;
  if (failed(reader.readVarInt(code)))
    return Type();

  switch (code) {
  case quant_encoding::kAnyQuantizedType:
    return readAnyQuantizedType(/*withExpressedType=*/false, reader);
  case quant_encoding::kAnyQuantizedTypeWithExpressedType:
    return readAnyQuantizedType(/*withExpressedType=*/true, reader);
  case quant_encoding::kCalibratedQuantizedType:
    return readCalibratedQuantizedType(reader);
  case quant_encoding::kUniformQuantizedType:
    return readUniformQuantizedType(reader);
  case quant_encoding::kUniformQuantizedPerAxisType:
    return readUniformQuantizedPerAxisType(reader);
  default:
    // Name the right dialect: this reader only handles quant type codes.
    reader.emitError() << "unknown quant type code: " << code;
    return Type();
  }
}
/// Encodes `type` into `writer` if it is one of the Quant types that has a
/// bytecode encoding, by forwarding to the matching `write` overload.
///
/// Returns success() after writing, or failure() (writing nothing) for any
/// other type.
LogicalResult
QuantDialectBytecodeInterface::writeType(Type type,
                                         DialectBytecodeWriter &writer) const {
  return TypeSwitch<Type, LogicalResult>(type)
      .Case<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
            UniformQuantizedPerAxisType>(
          [&](auto quantType) { return write(quantType, writer), success(); })
      .Default([&](Type) { return failure(); });
}

/// Decodes the fields of an AnyQuantizedType.
///
/// `withExpressedType` selects between the two encodings: when false, no
/// expressed type is present in the stream and a null expressed type is passed
/// to `AnyQuantizedType::get`.
///
/// Emits "invalid AnyQuantizedType" and returns a null type if any field fails
/// to decode.
AnyQuantizedType QuantDialectBytecodeInterface::readAnyQuantizedType(
    bool withExpressedType, DialectBytecodeReader &reader) const {
  uint64_t flags;
  Type storageType, expressedType;
  int64_t storageTypeMin, storageTypeMax;
  if (failed(reader.readVarInt(flags)) ||
      failed(reader.readType(storageType)) ||
      (withExpressedType && failed(reader.readType(expressedType))) ||
      failed(reader.readSignedVarInt(storageTypeMin)) ||
      failed(reader.readSignedVarInt(storageTypeMax)))
    return reader.emitError("invalid AnyQuantizedType"), AnyQuantizedType();
  return AnyQuantizedType::get(flags, storageType, expressedType,
                               storageTypeMin, storageTypeMax);
}

/// Encodes an AnyQuantizedType. The type code records whether an expressed
/// type follows the storage type, so the reader knows whether to expect it.
void QuantDialectBytecodeInterface::write(AnyQuantizedType type,
                                          DialectBytecodeWriter &writer) const {
  if (type.getExpressedType())
    writer.writeVarInt(quant_encoding::kAnyQuantizedTypeWithExpressedType);
  else
    writer.writeVarInt(quant_encoding::kAnyQuantizedType);

  writer.writeVarInt(type.getFlags());
  writer.writeType(type.getStorageType());
  if (type.getExpressedType())
    writer.writeType(type.getExpressedType());
  writer.writeSignedVarInt(type.getStorageTypeMin());
  writer.writeSignedVarInt(type.getStorageTypeMax());
}

/// Decodes the fields of a CalibratedQuantizedType: the expressed type followed
/// by `min` and `max`, each read as an IEEE double APFloat.
///
/// Emits "invalid CalibratedQuantizedType" and returns a null type if any field
/// fails to decode.
CalibratedQuantizedType
QuantDialectBytecodeInterface::readCalibratedQuantizedType(
    DialectBytecodeReader &reader) const {
  Type expressedType;
  FailureOr<APFloat> min, max;
  if (failed(reader.readType(expressedType)) ||
      failed(min = reader.readAPFloatWithKnownSemantics(
                 llvm::APFloat::IEEEdouble())) ||
      failed(max = reader.readAPFloatWithKnownSemantics(
                 llvm::APFloat::IEEEdouble())))
    return reader.emitError("invalid CalibratedQuantizedType"),
           CalibratedQuantizedType();
  return CalibratedQuantizedType::get(expressedType,
                                      min.value().convertToDouble(),
                                      max.value().convertToDouble());
}

/// Encodes a CalibratedQuantizedType: type code, expressed type, then `min` and
/// `max` as APFloats built from the type's double values.
void QuantDialectBytecodeInterface::write(CalibratedQuantizedType type,
                                          DialectBytecodeWriter &writer) const {
  writer.writeVarInt(quant_encoding::kCalibratedQuantizedType);
  writer.writeType(type.getExpressedType());
  writer.writeAPFloatWithKnownSemantics(APFloat(type.getMin()));
  writer.writeAPFloatWithKnownSemantics(APFloat(type.getMax()));
}

/// Decodes the fields of a UniformQuantizedType, in the order flags, storage
/// type, expressed type, scale (IEEE double APFloat), zero point, storage min,
/// storage max.
///
/// Emits "invalid UniformQuantizedType" and returns a null type if any field
/// fails to decode.
UniformQuantizedType QuantDialectBytecodeInterface::readUniformQuantizedType(
    DialectBytecodeReader &reader) const {
  uint64_t flags;
  Type storageType, expressedType;
  FailureOr<APFloat> scale;
  int64_t zeroPoint, storageTypeMin, storageTypeMax;
  if (failed(reader.readVarInt(flags)) ||
      failed(reader.readType(storageType)) ||
      failed(reader.readType(expressedType)) ||
      failed(scale = reader.readAPFloatWithKnownSemantics(
                 llvm::APFloat::IEEEdouble())) ||
      failed(reader.readSignedVarInt(zeroPoint)) ||
      failed(reader.readSignedVarInt(storageTypeMin)) ||
      failed(reader.readSignedVarInt(storageTypeMax)))
    return reader.emitError("invalid UniformQuantizedType"),
           UniformQuantizedType();
  return UniformQuantizedType::get(flags, storageType, expressedType,
                                   scale.value().convertToDouble(), zeroPoint,
                                   storageTypeMin, storageTypeMax);
}

/// Encodes a UniformQuantizedType; the field order mirrors
/// readUniformQuantizedType.
void QuantDialectBytecodeInterface::write(UniformQuantizedType type,
                                          DialectBytecodeWriter &writer) const {
  writer.writeVarInt(quant_encoding::kUniformQuantizedType);
  writer.writeVarInt(type.getFlags());
  writer.writeType(type.getStorageType());
  writer.writeType(type.getExpressedType());
  writer.writeAPFloatWithKnownSemantics(APFloat(type.getScale()));
  writer.writeSignedVarInt(type.getZeroPoint());
  writer.writeSignedVarInt(type.getStorageTypeMin());
  writer.writeSignedVarInt(type.getStorageTypeMax());
}

/// Decodes the fields of a UniformQuantizedPerAxisType, in the order flags,
/// storage type, expressed type, list of scales (IEEE double APFloats), list
/// of zero points, quantized dimension, storage min, storage max.
///
/// Emits "invalid UniformQuantizedPerAxisType" and returns a null type if any
/// field fails to decode.
UniformQuantizedPerAxisType
QuantDialectBytecodeInterface::readUniformQuantizedPerAxisType(
    DialectBytecodeReader &reader) const {
  uint64_t flags;
  Type storageType, expressedType;
  SmallVector<double> scales;
  SmallVector<int64_t> zeroPoints;
  uint64_t quantizedDimension;
  int64_t storageTypeMin, storageTypeMax;

  // Element callback for the scales list: read one APFloat and narrow it to
  // the double stored by the type.
  auto scalesRead = [&](double &val) -> LogicalResult {
    FailureOr<APFloat> fl =
        reader.readAPFloatWithKnownSemantics(llvm::APFloat::IEEEdouble());
    if (failed(fl))
      return failure();
    val = fl.value().convertToDouble();
    return success();
  };

  if (failed(reader.readVarInt(flags)) ||
      failed(reader.readType(storageType)) ||
      failed(reader.readType(expressedType)) ||
      failed(reader.readList(scales, scalesRead)) ||
      failed(reader.readSignedVarInts(zeroPoints)) ||
      failed(reader.readVarInt(quantizedDimension)) ||
      failed(reader.readSignedVarInt(storageTypeMin)) ||
      failed(reader.readSignedVarInt(storageTypeMax)))
    return reader.emitError("invalid UniformQuantizedPerAxisType"),
           UniformQuantizedPerAxisType();

  // The dimension travels as an unsigned varint; convert it back to the
  // 32-bit signed value passed to the type's builder.
  return UniformQuantizedPerAxisType::get(
      flags, storageType, expressedType, scales, zeroPoints,
      static_cast<int32_t>(quantizedDimension), storageTypeMin,
      storageTypeMax);
}

/// Encodes a UniformQuantizedPerAxisType; the field order mirrors
/// readUniformQuantizedPerAxisType.
void QuantDialectBytecodeInterface::write(UniformQuantizedPerAxisType type,
                                          DialectBytecodeWriter &writer) const {
  // Must be the per-axis code: readType uses it to pick
  // readUniformQuantizedPerAxisType, whose layout differs from the per-layer
  // UniformQuantizedType encoding.
  writer.writeVarInt(quant_encoding::kUniformQuantizedPerAxisType);
  writer.writeVarInt(type.getFlags());
  writer.writeType(type.getStorageType());
  writer.writeType(type.getExpressedType());
  writer.writeList(type.getScales(), [&](double val) {
    writer.writeAPFloatWithKnownSemantics(APFloat(val));
  });
  writer.writeSignedVarInts(type.getZeroPoints());
  writer.writeVarInt(type.getQuantizedDimension());
  writer.writeSignedVarInt(type.getStorageTypeMin());
  writer.writeSignedVarInt(type.getStorageTypeMax());
}