diff --git a/python/akg/ms/cce/matmul.py b/python/akg/ms/cce/matmul.py index 9e4624a55f499aef79164b8cc41de5312d7c85f7..ab090abe2c26767f0a6a8561083b78aa8162c842 100644 --- a/python/akg/ms/cce/matmul.py +++ b/python/akg/ms/cce/matmul.py @@ -15,8 +15,10 @@ # limitations under the License. """matmul""" -from akg.ops.nn import batchmatmul +from akg.ops.nn import matmul -def MatMul(x1, x2, transpose_a=False, transpose_b=False): +def MatMul(x1, x2, out_dtype, transpose_a=False, transpose_b=False): """matmul""" - return batchmatmul.batchmatmul(x1, x2, transpose_a, transpose_b) + return matmul.matmul(x=x1, y=x2, b=None, out_dtype=out_dtype, + left_format="zN", right_format="zN", out_format="zN", + transpose_x=transpose_a, transpose_y=transpose_b)