提交 64cc6350 编写于 作者: W wangzhuo325

matmul code refactor

上级 19e5b2ac
......@@ -33,7 +33,7 @@ matmul_set_dim_map = {
}
def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y, has_bias, attrs):
def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs):
shape_A = A.shape
shape_B = B.shape
bias = 0 if b is None else 1
......@@ -259,7 +259,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
@ct_util.reg_set_dim_func(matmul_set_dim)
def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="zN", transpose_x=False, transpose_y=False, has_bias=False, attrs=None):
def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="zN", transpose_x=False, transpose_y=False, attrs=None):
"""
Computes matrix multiplication x * y + b.
......@@ -273,7 +273,6 @@ def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="
out_format: str. Data format of output tensor. Supported data format list ["zZ", "nZ", "zN"].
transpose_x: Boolean. Specifies whether x is transposed or not.
transpose_y: Boolean. Specifies whether y is transposed or not.
has_bias: Boolean. Specifies whether bias tensor exists or not.
attrs: Dict. Used in matmul computation.
Note:
......
......@@ -15,7 +15,7 @@
"""operator dsl function: matmul4d_ad"""
import akg.tvm
import akg
from test_op import matmul
from akg.ops.nn import matmul
from akg.utils import custom_tiling as ct_util
......
......@@ -19,7 +19,7 @@ from gen_random import random_gaussian
import numpy as np
import akg.backend as cce
from akg.utils import kernel_exec as utils
from test_op import matmul
from akg.ops.nn import matmul
from base import get_rtol_atol
from tensorio import compare_tensor
......@@ -353,9 +353,9 @@ def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_for
has_bias = False
if bias == 1:
has_bias = True
op_attrs = [out_dtype, left_format, right_format, output_format, adj_x, adj_y, has_bias, attrs]
op_attrs = [out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs]
if has_bias == False:
input_shapes = [shape_xx, shape_yy]
input_types = [dtype, dtype]
op_attrs = [None, out_dtype, left_format, right_format, output_format, adj_x, adj_y, has_bias, attrs]
op_attrs = [None, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs]
return utils.op_build_test(matmul.matmul, input_shapes, input_types, op_attrs, kernel_name, attrs)
......@@ -21,7 +21,7 @@ from akg.utils import custom_tiling as ct_util
from akg.ops.nn import conv_bn1
from akg.ops.nn import conv, conv_backprop_input, conv_backprop_filter, batchmatmul
from akg.backend import build_module
from test_op import matmul
from akg.ops.nn import matmul
from test_run import batchmatmul_run, matmul_run
from .type_definitions import ConvDesc, ConvBackpropDesc, MatmulCubeDesc, ConvConfig, ConvBackpropInputConfig, ConvBackpropFilterConfig, MatmulCubeConfig
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册