diff --git a/python/akg/ops/nn/matmul.py b/python/akg/ops/nn/matmul.py index 2dd3474e2c07ed1eaa1297de3ed3c96e094701c1..4830cd51747067e01943a0ecea61df40c2d1b854 100644 --- a/python/akg/ops/nn/matmul.py +++ b/python/akg/ops/nn/matmul.py @@ -299,8 +299,6 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out return out - -@ct_util.reg_set_dim_func(matmul_set_dim) def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="zN", transpose_x=False, transpose_y=False, attrs=None): """ Computes matrix multiplication x * y + b. @@ -337,4 +335,8 @@ def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format=" out = matmul4D_compute(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y, attrs) attr_map = {"pragma_rmselfdep": False} + + dims_info, _ = matmul_set_dim(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y, attrs) + attr_map["dim"] = dims_info + return out, attr_map