diff --git a/python/akg/ops/nn/matmul.py b/python/akg/ops/nn/matmul.py index b2e23697fbc67c0415b09331d7fc78971dd3b882..4d74d278ef53b5135afc3aa26e9fbe3329c348be 100644 --- a/python/akg/ops/nn/matmul.py +++ b/python/akg/ops/nn/matmul.py @@ -299,8 +299,6 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out return out - -@ct_util.reg_set_dim_func(matmul_set_dim) def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="zN", transpose_x=False, transpose_y=False, attrs=None): """ Computes matrix multiplication x * y + b. @@ -337,4 +335,8 @@ def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format=" out = matmul4D_compute(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y, attrs) attr_map = {"pragma_rmselfdep": False} + + dims_info, _ = matmul_set_dim(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y, attrs) + attr_map["dim"] = dims_info + return out, attr_map