556 lines
14 KiB
Python
556 lines
14 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
import gc
|
|
|
|
from mlir.ir import *
|
|
|
|
|
|
def run(f):
|
|
print("\nTEST:", f.__name__)
|
|
f()
|
|
gc.collect()
|
|
assert Context._get_live_count() == 0
|
|
return f
|
|
|
|
|
|
# CHECK-LABEL: TEST: testParsePrint
|
|
@run
|
|
def testParsePrint():
|
|
with Context() as ctx:
|
|
t = Attribute.parse('"hello"')
|
|
assert t.context is ctx
|
|
ctx = None
|
|
gc.collect()
|
|
# CHECK: "hello"
|
|
print(str(t))
|
|
# CHECK: Attribute("hello")
|
|
print(repr(t))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testParseError
|
|
# TODO: Hook the diagnostic manager to capture a more meaningful error
|
|
# message.
|
|
@run
|
|
def testParseError():
|
|
with Context():
|
|
try:
|
|
t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
|
|
except ValueError as e:
|
|
# CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
|
|
print("testParseError:", e)
|
|
else:
|
|
print("Exception not produced")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAttrEq
|
|
@run
|
|
def testAttrEq():
|
|
with Context():
|
|
a1 = Attribute.parse('"attr1"')
|
|
a2 = Attribute.parse('"attr2"')
|
|
a3 = Attribute.parse('"attr1"')
|
|
# CHECK: a1 == a1: True
|
|
print("a1 == a1:", a1 == a1)
|
|
# CHECK: a1 == a2: False
|
|
print("a1 == a2:", a1 == a2)
|
|
# CHECK: a1 == a3: True
|
|
print("a1 == a3:", a1 == a3)
|
|
# CHECK: a1 == None: False
|
|
print("a1 == None:", a1 == None)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAttrHash
|
|
@run
|
|
def testAttrHash():
|
|
with Context():
|
|
a1 = Attribute.parse('"attr1"')
|
|
a2 = Attribute.parse('"attr2"')
|
|
a3 = Attribute.parse('"attr1"')
|
|
# CHECK: hash(a1) == hash(a3): True
|
|
print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
|
|
|
|
s = set()
|
|
s.add(a1)
|
|
s.add(a2)
|
|
s.add(a3)
|
|
# CHECK: len(s): 2
|
|
print("len(s): ", len(s))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAttrCast
|
|
@run
|
|
def testAttrCast():
|
|
with Context():
|
|
a1 = Attribute.parse('"attr1"')
|
|
a2 = Attribute(a1)
|
|
# CHECK: a1 == a2: True
|
|
print("a1 == a2:", a1 == a2)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAttrIsInstance
|
|
@run
|
|
def testAttrIsInstance():
|
|
with Context():
|
|
a1 = Attribute.parse("42")
|
|
a2 = Attribute.parse("[42]")
|
|
assert IntegerAttr.isinstance(a1)
|
|
assert not IntegerAttr.isinstance(a2)
|
|
assert not ArrayAttr.isinstance(a1)
|
|
assert ArrayAttr.isinstance(a2)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
|
|
@run
|
|
def testAttrEqDoesNotRaise():
|
|
with Context():
|
|
a1 = Attribute.parse('"attr1"')
|
|
not_an_attr = "foo"
|
|
# CHECK: False
|
|
print(a1 == not_an_attr)
|
|
# CHECK: False
|
|
print(a1 == None)
|
|
# CHECK: True
|
|
print(a1 != None)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAttrCapsule
|
|
@run
|
|
def testAttrCapsule():
|
|
with Context() as ctx:
|
|
a1 = Attribute.parse('"attr1"')
|
|
# CHECK: mlir.ir.Attribute._CAPIPtr
|
|
attr_capsule = a1._CAPIPtr
|
|
print(attr_capsule)
|
|
a2 = Attribute._CAPICreate(attr_capsule)
|
|
assert a2 == a1
|
|
assert a2.context is ctx
|
|
|
|
|
|
# CHECK-LABEL: TEST: testStandardAttrCasts
|
|
@run
|
|
def testStandardAttrCasts():
|
|
with Context():
|
|
a1 = Attribute.parse('"attr1"')
|
|
astr = StringAttr(a1)
|
|
aself = StringAttr(astr)
|
|
# CHECK: Attribute("attr1")
|
|
print(repr(astr))
|
|
try:
|
|
tillegal = StringAttr(Attribute.parse("1.0"))
|
|
except ValueError as e:
|
|
# CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
|
|
print("ValueError:", e)
|
|
else:
|
|
print("Exception not produced")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAffineMapAttr
|
|
@run
|
|
def testAffineMapAttr():
|
|
with Context() as ctx:
|
|
d0 = AffineDimExpr.get(0)
|
|
d1 = AffineDimExpr.get(1)
|
|
c2 = AffineConstantExpr.get(2)
|
|
map0 = AffineMap.get(2, 3, [])
|
|
|
|
# CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
|
|
attr_built = AffineMapAttr.get(map0)
|
|
print(str(attr_built))
|
|
|
|
attr_parsed = Attribute.parse(str(attr_built))
|
|
assert attr_built == attr_parsed
|
|
|
|
|
|
# CHECK-LABEL: TEST: testFloatAttr
|
|
@run
|
|
def testFloatAttr():
|
|
with Context(), Location.unknown():
|
|
fattr = FloatAttr(Attribute.parse("42.0 : f32"))
|
|
# CHECK: fattr value: 42.0
|
|
print("fattr value:", fattr.value)
|
|
|
|
# Test factory methods.
|
|
# CHECK: default_get: 4.200000e+01 : f32
|
|
print("default_get:", FloatAttr.get(
|
|
F32Type.get(), 42.0))
|
|
# CHECK: f32_get: 4.200000e+01 : f32
|
|
print("f32_get:", FloatAttr.get_f32(42.0))
|
|
# CHECK: f64_get: 4.200000e+01 : f64
|
|
print("f64_get:", FloatAttr.get_f64(42.0))
|
|
try:
|
|
fattr_invalid = FloatAttr.get(
|
|
IntegerType.get_signless(32), 42)
|
|
except ValueError as e:
|
|
# CHECK: invalid 'Type(i32)' and expected floating point type.
|
|
print(e)
|
|
else:
|
|
print("Exception not produced")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testIntegerAttr
|
|
@run
|
|
def testIntegerAttr():
|
|
with Context() as ctx:
|
|
i_attr = IntegerAttr(Attribute.parse("42"))
|
|
# CHECK: i_attr value: 42
|
|
print("i_attr value:", i_attr.value)
|
|
# CHECK: i_attr type: i64
|
|
print("i_attr type:", i_attr.type)
|
|
si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
|
|
# CHECK: si_attr value: -1
|
|
print("si_attr value:", si_attr.value)
|
|
ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
|
|
# CHECK: ui_attr value: 255
|
|
print("ui_attr value:", ui_attr.value)
|
|
idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
|
|
# CHECK: idx_attr value: -1
|
|
print("idx_attr value:", idx_attr.value)
|
|
|
|
# Test factory methods.
|
|
# CHECK: default_get: 42 : i32
|
|
print("default_get:", IntegerAttr.get(
|
|
IntegerType.get_signless(32), 42))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testBoolAttr
|
|
@run
|
|
def testBoolAttr():
|
|
with Context() as ctx:
|
|
battr = BoolAttr(Attribute.parse("true"))
|
|
# CHECK: iattr value: True
|
|
print("iattr value:", battr.value)
|
|
|
|
# Test factory methods.
|
|
# CHECK: default_get: true
|
|
print("default_get:", BoolAttr.get(True))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testFlatSymbolRefAttr
|
|
@run
|
|
def testFlatSymbolRefAttr():
|
|
with Context() as ctx:
|
|
sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
|
|
# CHECK: symattr value: symbol
|
|
print("symattr value:", sattr.value)
|
|
|
|
# Test factory methods.
|
|
# CHECK: default_get: @foobar
|
|
print("default_get:", FlatSymbolRefAttr.get("foobar"))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOpaqueAttr
|
|
@run
|
|
def testOpaqueAttr():
|
|
with Context() as ctx:
|
|
ctx.allow_unregistered_dialects = True
|
|
oattr = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>"))
|
|
# CHECK: oattr value: pytest_dummy
|
|
print("oattr value:", oattr.dialect_namespace)
|
|
# CHECK: oattr value: dummyattr<>
|
|
print("oattr value:", oattr.data)
|
|
|
|
# Test factory methods.
|
|
# CHECK: default_get: #foobar<123>
|
|
print(
|
|
"default_get:",
|
|
OpaqueAttr.get("foobar", bytes("123", "utf-8"), NoneType.get()))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testStringAttr
|
|
@run
|
|
def testStringAttr():
|
|
with Context() as ctx:
|
|
sattr = StringAttr(Attribute.parse('"stringattr"'))
|
|
# CHECK: sattr value: stringattr
|
|
print("sattr value:", sattr.value)
|
|
|
|
# Test factory methods.
|
|
# CHECK: default_get: "foobar"
|
|
print("default_get:", StringAttr.get("foobar"))
|
|
# CHECK: typed_get: "12345" : i32
|
|
print("typed_get:", StringAttr.get_typed(
|
|
IntegerType.get_signless(32), "12345"))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testNamedAttr
|
|
@run
|
|
def testNamedAttr():
|
|
with Context():
|
|
a = Attribute.parse('"stringattr"')
|
|
named = a.get_named("foobar") # Note: under the small object threshold
|
|
# CHECK: attr: "stringattr"
|
|
print("attr:", named.attr)
|
|
# CHECK: name: foobar
|
|
print("name:", named.name)
|
|
# CHECK: named: NamedAttribute(foobar="stringattr")
|
|
print("named:", named)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testDenseIntAttr
|
|
@run
|
|
def testDenseIntAttr():
|
|
with Context():
|
|
raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
|
|
# CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
|
|
print("attr:", raw)
|
|
|
|
a = DenseIntElementsAttr(raw)
|
|
assert len(a) == 6
|
|
|
|
# CHECK: 0 1 2 3 4 5
|
|
for value in a:
|
|
print(value, end=" ")
|
|
print()
|
|
|
|
# CHECK: i32
|
|
print(ShapedType(a.type).element_type)
|
|
|
|
raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
|
|
# CHECK: attr: dense<[true, false, true, false]>
|
|
print("attr:", raw)
|
|
|
|
a = DenseIntElementsAttr(raw)
|
|
assert len(a) == 4
|
|
|
|
# CHECK: 1 0 1 0
|
|
for value in a:
|
|
print(value, end=" ")
|
|
print()
|
|
|
|
# CHECK: i1
|
|
print(ShapedType(a.type).element_type)
|
|
|
|
|
|
@run
|
|
def testDenseArrayGetItem():
|
|
def print_item(AttrClass, attr_asm):
|
|
attr = AttrClass(Attribute.parse(attr_asm))
|
|
print(f"{len(attr)}: {attr[0]}, {attr[1]}")
|
|
|
|
with Context():
|
|
# CHECK: 2: 0, 1
|
|
print_item(DenseBoolArrayAttr, "array<i1: false, true>")
|
|
# CHECK: 2: 2, 3
|
|
print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
|
|
# CHECK: 2: 4, 5
|
|
print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
|
|
# CHECK: 2: 6, 7
|
|
print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
|
|
# CHECK: 2: 8, 9
|
|
print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
|
|
# CHECK: 2: 1.{{0+}}, 2.{{0+}}
|
|
print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
|
|
# CHECK: 2: 3.{{0+}}, 4.{{0+}}
|
|
print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
|
|
@run
|
|
def testDenseIntAttrGetItem():
|
|
def print_item(attr_asm):
|
|
attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
|
|
dtype = ShapedType(attr.type).element_type
|
|
try:
|
|
item = attr[0]
|
|
print(f"{dtype}:", item)
|
|
except TypeError as e:
|
|
print(f"{dtype}:", e)
|
|
|
|
with Context():
|
|
# CHECK: i1: 1
|
|
print_item("dense<true> : tensor<i1>")
|
|
# CHECK: i8: 123
|
|
print_item("dense<123> : tensor<i8>")
|
|
# CHECK: i16: 123
|
|
print_item("dense<123> : tensor<i16>")
|
|
# CHECK: i32: 123
|
|
print_item("dense<123> : tensor<i32>")
|
|
# CHECK: i64: 123
|
|
print_item("dense<123> : tensor<i64>")
|
|
# CHECK: ui8: 123
|
|
print_item("dense<123> : tensor<ui8>")
|
|
# CHECK: ui16: 123
|
|
print_item("dense<123> : tensor<ui16>")
|
|
# CHECK: ui32: 123
|
|
print_item("dense<123> : tensor<ui32>")
|
|
# CHECK: ui64: 123
|
|
print_item("dense<123> : tensor<ui64>")
|
|
# CHECK: si8: -123
|
|
print_item("dense<-123> : tensor<si8>")
|
|
# CHECK: si16: -123
|
|
print_item("dense<-123> : tensor<si16>")
|
|
# CHECK: si32: -123
|
|
print_item("dense<-123> : tensor<si32>")
|
|
# CHECK: si64: -123
|
|
print_item("dense<-123> : tensor<si64>")
|
|
|
|
# CHECK: i7: Unsupported integer type
|
|
print_item("dense<123> : tensor<i7>")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testDenseFPAttr
|
|
@run
|
|
def testDenseFPAttr():
|
|
with Context():
|
|
raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
|
|
# CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
|
|
|
|
print("attr:", raw)
|
|
|
|
a = DenseFPElementsAttr(raw)
|
|
assert len(a) == 4
|
|
|
|
# CHECK: 0.0 1.0 2.0 3.0
|
|
for value in a:
|
|
print(value, end=" ")
|
|
print()
|
|
|
|
# CHECK: f32
|
|
print(ShapedType(a.type).element_type)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testDictAttr
|
|
@run
|
|
def testDictAttr():
|
|
with Context():
|
|
dict_attr = {
|
|
'stringattr': StringAttr.get('string'),
|
|
'integerattr' : IntegerAttr.get(
|
|
IntegerType.get_signless(32), 42)
|
|
}
|
|
|
|
a = DictAttr.get(dict_attr)
|
|
|
|
# CHECK attr: {integerattr = 42 : i32, stringattr = "string"}
|
|
print("attr:", a)
|
|
|
|
assert len(a) == 2
|
|
|
|
# CHECK: 42 : i32
|
|
print(a['integerattr'])
|
|
|
|
# CHECK: "string"
|
|
print(a['stringattr'])
|
|
|
|
# CHECK: True
|
|
print('stringattr' in a)
|
|
|
|
# CHECK: False
|
|
print('not_in_dict' in a)
|
|
|
|
# Check that exceptions are raised as expected.
|
|
try:
|
|
_ = a['does_not_exist']
|
|
except KeyError:
|
|
pass
|
|
else:
|
|
assert False, "Exception not produced"
|
|
|
|
try:
|
|
_ = a[42]
|
|
except IndexError:
|
|
pass
|
|
else:
|
|
assert False, "expected IndexError on accessing an out-of-bounds attribute"
|
|
|
|
# CHECK "empty: {}"
|
|
print("empty: ", DictAttr.get())
|
|
|
|
|
|
# CHECK-LABEL: TEST: testTypeAttr
|
|
@run
|
|
def testTypeAttr():
|
|
with Context():
|
|
raw = Attribute.parse("vector<4xf32>")
|
|
# CHECK: attr: vector<4xf32>
|
|
print("attr:", raw)
|
|
type_attr = TypeAttr(raw)
|
|
# CHECK: f32
|
|
print(ShapedType(type_attr.value).element_type)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testArrayAttr
|
|
@run
|
|
def testArrayAttr():
|
|
with Context():
|
|
raw = Attribute.parse("[42, true, vector<4xf32>]")
|
|
# CHECK: attr: [42, true, vector<4xf32>]
|
|
print("raw attr:", raw)
|
|
# CHECK: - 42
|
|
# CHECK: - true
|
|
# CHECK: - vector<4xf32>
|
|
for attr in ArrayAttr(raw):
|
|
print("- ", attr)
|
|
|
|
with Context():
|
|
intAttr = Attribute.parse("42")
|
|
vecAttr = Attribute.parse("vector<4xf32>")
|
|
boolAttr = BoolAttr.get(True)
|
|
raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
|
|
# CHECK: attr: [vector<4xf32>, true, 42]
|
|
print("raw attr:", raw)
|
|
# CHECK: - vector<4xf32>
|
|
# CHECK: - true
|
|
# CHECK: - 42
|
|
arr = ArrayAttr(raw)
|
|
for attr in arr:
|
|
print("- ", attr)
|
|
# CHECK: attr[0]: vector<4xf32>
|
|
print("attr[0]:", arr[0])
|
|
# CHECK: attr[1]: true
|
|
print("attr[1]:", arr[1])
|
|
# CHECK: attr[2]: 42
|
|
print("attr[2]:", arr[2])
|
|
try:
|
|
print("attr[3]:", arr[3])
|
|
except IndexError as e:
|
|
# CHECK: Error: ArrayAttribute index out of range
|
|
print("Error: ", e)
|
|
with Context():
|
|
try:
|
|
ArrayAttr.get([None])
|
|
except RuntimeError as e:
|
|
# CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
|
|
print("Error: ", e)
|
|
try:
|
|
ArrayAttr.get([42])
|
|
except RuntimeError as e:
|
|
# CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
|
|
print("Error: ", e)
|
|
|
|
with Context():
|
|
array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
|
|
array = array + [StringAttr.get("c")]
|
|
# CHECK: concat: ["a", "b", "c"]
|
|
print("concat: ", array)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testStridedLayoutAttr
|
|
@run
|
|
def testStridedLayoutAttr():
|
|
with Context():
|
|
attr = StridedLayoutAttr.get(42, [5, 7, 13])
|
|
# CHECK: strided<[5, 7, 13], offset: 42>
|
|
print(attr)
|
|
# CHECK: 42
|
|
print(attr.offset)
|
|
# CHECK: 3
|
|
print(len(attr.strides))
|
|
# CHECK: 5
|
|
print(attr.strides[0])
|
|
# CHECK: 7
|
|
print(attr.strides[1])
|
|
# CHECK: 13
|
|
print(attr.strides[2])
|
|
|
|
attr = StridedLayoutAttr.get_fully_dynamic(3)
|
|
dynamic = ShapedType.get_dynamic_stride_or_offset()
|
|
# CHECK: strided<[?, ?, ?], offset: ?>
|
|
print(attr)
|
|
# CHECK: offset is dynamic: True
|
|
print(f"offset is dynamic: {attr.offset == dynamic}")
|
|
# CHECK: rank: 3
|
|
print(f"rank: {len(attr.strides)}")
|
|
# CHECK: strides are dynamic: [True, True, True]
|
|
print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
|