// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --sparsification --sparse-tensor-codegen \
// RUN: --canonicalize --cse | FileCheck %s

#CSR = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (i,j)>
}>

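//
// With this encoding the row dimension is dense and the column dimension is
// compressed, i.e. ordinary CSR: a positions array delimits each row's
// segment of the coordinates array, and the nonzero values are stored
// alongside.
//
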
//
// Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
//
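//
// Informal roadmap of the checked output: codegen first emits a private
// helper that appends one (coordinate, value) pair to the output's CSR
// buffers; the matmul itself becomes a row-wise loop nest that scatters
// partial products into dense scratch buffers and then compresses each
// row back into sparse form.
//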
// CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_2:[^ ]+]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_3:.*]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_4:.*]]: memref<?xf64>,
// CHECK-SAME: %[[VAL_5:[^ ]+]]: index,
// CHECK-SAME: %[[VAL_6:.*]]: index,
// CHECK-SAME: %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex>
// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index
// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
// CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<?xindex>
// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index
// CHECK: scf.yield %[[VAL_18]] : i1
// CHECK: } else {
// CHECK: memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: scf.yield %[[VAL_8]] : i1
// CHECK: }
// CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref<?xindex>) {
// CHECK: scf.yield %[[VAL_3]] : memref<?xindex>
// CHECK: } else {
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index
// CHECK: memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
// CHECK: scf.yield %[[VAL_22]] : memref<?xindex>
// CHECK: }
// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
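//
// In effect, the helper above appends the value unconditionally, appends
// the column coordinate only if it does not repeat the row's last stored
// coordinate, and updates the positions array for the current row
// accordingly.
//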
// CHECK-LABEL: func.func @matmul(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_3:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_7:.*7]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_8:.*8]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_9:.*9]]: memref<?xf64>) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true
// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
// CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>)
// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex>
// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex>
// CHECK: %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_12]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
// CHECK: %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_24]], %[[VAL_12]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index, index
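//
// Scratch buffers for the expanded access pattern of one output row:
// the dense values (f64), a "filled" bitmap (i1), and the list of added
// coordinates, with a running entry count threaded through the loops.
//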
// CHECK: %[[VAL_26:.*]] = memref.alloc() : memref<4xf64>
// CHECK: %[[VAL_27:.*]] = memref.alloc() : memref<4xi1>
// CHECK: %[[VAL_28:.*]] = memref.alloc() : memref<4xindex>
// CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref<?xindex>
// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>)
// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>)
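//
// Outer loop over the rows of A: for each stored entry a(i,k), walk row k
// of B and accumulate a(i,k) * b(k,j) into the scratch row, recording each
// column coordinate the first time it is filled.
//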
// CHECK: %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref<?xindex>
// CHECK: %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> (index) {
// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref<?xindex>
// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref<?xf64>
// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index
// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref<?xindex>
// CHECK: %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) {
// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref<?xindex>
// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref<?xf64>
// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64
// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64
// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
// CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1
// CHECK: %[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) {
// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
// CHECK: memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex>
// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index
// CHECK: scf.yield %[[VAL_59]] : index
// CHECK: } else {
// CHECK: scf.yield %[[VAL_50]] : index
// CHECK: }
// CHECK: memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
// CHECK: scf.yield %[[VAL_60:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
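//
// Per-row compression: sort the added coordinates so insertion happens in
// column order, call the insertion helper once per filled entry, and reset
// the scratch buffers for the next row.
//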
// CHECK: sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref<?xindex>
// CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex>
// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>)
// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1>
// CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
// CHECK: scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: } {"Emitted from" = "linalg.generic"}
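//
// Finalization: free the scratch buffers, then patch any zero entries left
// in the positions array (rows that received no entries) before returning
// the five buffers that make up the CSR output.
//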
// CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64>
// CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1>
// CHECK: memref.dealloc %[[VAL_28]] : memref<4xindex>
// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex>
// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref<?xindex>
// CHECK: %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) -> (index) {
// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index
// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index
// CHECK: scf.if %[[VAL_81]] {
// CHECK: memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
// CHECK: }
// CHECK: scf.yield %[[VAL_82]] : index
// CHECK: }
// CHECK: return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }

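//
// Input IR: a linalg.matmul over CSR operands whose output is materialized
// with bufferization.alloc_tensor; the passes in the RUN lines rewrite it
// into the code checked above.
//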
func.func @matmul(%A: tensor<4x8xf64, #CSR>,
                  %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
  %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
  %D = linalg.matmul
    ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
    outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
  return %D : tensor<4x4xf64, #CSR>
}