242 lines
6.2 KiB
Python
242 lines
6.2 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
import gc
|
|
from mlir.ir import *
|
|
|
|
|
|
def run(f):
    """Execute test function ``f`` once and verify that no MLIR context leaks.

    Used as a decorator. The banner printed first is what the FileCheck
    label directives in this file anchor on. After ``f`` returns, a full
    garbage collection runs so unreachable Python wrappers are reclaimed
    before the live ``Context`` count is asserted to be zero.

    Returns:
        ``f`` itself, so the decorated name still refers to the function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect first so only genuinely leaked contexts are counted.
    gc.collect()
    assert Context._get_live_count() == 0
    return f
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAffineMapCapsule
|
|
@run
def testAffineMapCapsule():
    """Round-trip an empty AffineMap through its C-API capsule.

    The map rebuilt from the capsule must compare equal to the source map
    and belong to the very same context object.
    """
    with Context() as ctx:
        empty_map = AffineMap.get_empty(ctx)
        capsule = empty_map._CAPIPtr
        # CHECK: mlir.ir.AffineMap._CAPIPtr
        print(capsule)

        rebuilt = AffineMap._CAPICreate(capsule)
        assert rebuilt == empty_map
        assert rebuilt.context is ctx
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAffineMapGet
|
|
@run
def testAffineMapGet():
    """Build AffineMaps through the various factory methods.

    Each constructed map is printed. The convenience factories are then
    cross-checked against equivalent explicit ``AffineMap.get`` calls.
    Finally, the message of each expected failure (invalid construction or
    out-of-range submap request) is printed.
    """
    with Context():
        d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)

        # CHECK: (d0, d1)[s0, s1, s2] -> ()
        print(AffineMap.get(2, 3, []))

        # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
        with_symbols = AffineMap.get(2, 3, [d1, c2])
        print(with_symbols)

        # CHECK: () -> (2)
        constant = AffineMap.get(0, 0, [c2])
        print(constant)

        # CHECK: (d0, d1) -> (d0, d1)
        identity = AffineMap.get(2, 0, [d0, d1])
        print(identity)

        # CHECK: (d0, d1) -> (d1)
        minor_identity = AffineMap.get(2, 0, [d1])
        print(minor_identity)

        # CHECK: (d0, d1, d2) -> (d2, d0, d1)
        print(AffineMap.get_permutation([2, 0, 1]))

        # Each convenience factory must agree with its explicit equivalent.
        assert with_symbols == AffineMap.get(2, 3, [d1, c2])
        assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
        assert constant == AffineMap.get_constant(2)
        assert identity == AffineMap.get_identity(2)
        assert minor_identity == AffineMap.get_minor_identity(2, 1)

        # (expected exception type, invalid operation), attempted in order.
        # Only the listed exception type is caught; its message is printed.
        invalid_ops = [
            (RuntimeError, lambda: AffineMap.get(1, 1, [1])),
            (RuntimeError, lambda: AffineMap.get(1, 1, [None])),
            (RuntimeError, lambda: AffineMap.get_permutation([1, 0, 1])),
            (ValueError, lambda: identity.get_submap([42])),
            (ValueError, lambda: identity.get_minor_submap(42)),
            (ValueError, lambda: identity.get_major_submap(42)),
        ]
        # CHECK: Invalid expression when attempting to create an AffineMap
        # CHECK: Invalid expression (None?) when attempting to create an AffineMap
        # CHECK: Invalid permutation when attempting to create an AffineMap
        # CHECK: result position out of bounds
        # CHECK: number of results out of bounds
        # CHECK: number of results out of bounds
        for expected_error, invalid_op in invalid_ops:
            try:
                invalid_op()
            except expected_error as err:
                print(err)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAffineMapDerive
|
|
@run
def testAffineMapDerive():
    """Derive arbitrary, major and minor submaps from a 5-D identity map."""
    with Context():
        identity5 = AffineMap.get_identity(5)

        # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
        print(identity5.get_submap([1, 2, 3]))

        # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
        print(identity5.get_major_submap(2))

        # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
        print(identity5.get_minor_submap(2))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAffineMapProperties
|
|
@run
def testAffineMapProperties():
    """Print the permutation predicates of three maps over three dims."""
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
        # Omits d1 from the results.
        dropped_dim = AffineMap.get(3, 0, [d2, d0])
        # Reorders all three dims, no symbols.
        full_perm = AffineMap.get(3, 0, [d2, d0, d1])
        # Same results as full_perm, but declares one symbol.
        with_symbol = AffineMap.get(3, 1, [d2, d0, d1])

        # Expected (is_permutation, is_projected_permutation), in loop order:
        # dropped_dim -> (False, True), full_perm -> (True, True),
        # with_symbol -> (False, False).
        # CHECK: False
        # CHECK: True
        # CHECK: True
        # CHECK: True
        # CHECK: False
        # CHECK: False
        for affine_map in (dropped_dim, full_perm, with_symbol):
            print(affine_map.is_permutation)
            print(affine_map.is_projected_permutation)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAffineMapExprs
|
|
@run
def testAffineMapExprs():
    """Inspect the input counts and result expressions of one map."""
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
        affine_map = AffineMap.get(3, 1, [d2, d0, d1])

        # CHECK: 3
        print(affine_map.n_dims)
        # CHECK: 4
        print(affine_map.n_inputs)
        # CHECK: 1
        print(affine_map.n_symbols)
        # Inputs are the dims followed by the symbols.
        assert affine_map.n_inputs == affine_map.n_dims + affine_map.n_symbols

        # CHECK: 3
        print(len(affine_map.results))

        # Forward iteration yields results in declaration order.
        # CHECK: d2
        # CHECK: d0
        # CHECK: d1
        for result_expr in affine_map.results:
            print(result_expr)

        # A negative-step slice walks the results back to front.
        # CHECK: d1
        # CHECK: d0
        # CHECK: d2
        for result_expr in affine_map.results[-1:-4:-1]:
            print(result_expr)

        assert list(affine_map.results) == [d2, d0, d1]
|
|
|
|
|
|
# CHECK-LABEL: TEST: testCompressUnusedSymbols
|
|
@run
def testCompressUnusedSymbols():
    """Compress away symbols that no map in a list references.

    Only s2 is used (by the second map), so the compressed maps are expected
    to keep a single symbol, renumbered to s0.
    """
    with Context() as ctx:
        d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
        symbols = [AffineSymbolExpr.get(i) for i in range(3)]
        s2 = symbols[2]
        maps = [
            AffineMap.get(3, 3, results)
            for results in ([d2, d0, d1], [d2, d0 + s2, d1], [d1, d2, d0])
        ]

        compressed = AffineMap.compress_unused_symbols(maps, ctx)

        # CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
        print(maps)

        # CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
        print(compressed)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testReplace
|
|
@run
def testReplace():
    """Substitute the constant 42 for each symbol of a map in turn.

    The last two ``replace`` arguments give the dim and symbol counts of the
    result. With a symbol count of 2, the expected output declares only
    [s0, s1].
    """
    with Context():
        d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
        s0, s1, s2 = [AffineSymbolExpr.get(i) for i in range(3)]
        source = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
        forty_two = AffineConstantExpr.get(42)

        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
        # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
        for symbol, n_result_syms in ((s0, 3), (s1, 3), (s2, 2)):
            print(source.replace(symbol, forty_two, 3, n_result_syms))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testHash
|
|
@run
def testHash():
    """AffineMap hashing must be consistent with AffineMap equality.

    Structurally equal maps built separately must hash alike and be
    interchangeable as dictionary keys. Maps with different results must
    stay distinct keys.
    """
    with Context():
        d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
        m1 = AffineMap.get(2, 0, [d0, d1])
        m2 = AffineMap.get(2, 0, [d1, d0])
        assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))

        dictionary = {}
        dictionary[m1] = 1
        dictionary[m2] = 2
        assert m1 in dictionary
        # m1 and m2 differ only in result order, so neither insertion may
        # overwrite the other.
        assert m2 in dictionary
        assert len(dictionary) == 2
        # A separately constructed map equal to m1 must find m1's entry.
        assert dictionary[AffineMap.get(2, 0, [d0, d1])] == 1
        assert dictionary[m2] == 2
|