提交 64cc6350 编写于 作者: W wangzhuo325

matmul code refactor

上级 19e5b2ac
...@@ -33,7 +33,7 @@ matmul_set_dim_map = { ...@@ -33,7 +33,7 @@ matmul_set_dim_map = {
} }
def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y, has_bias, attrs): def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs):
shape_A = A.shape shape_A = A.shape
shape_B = B.shape shape_B = B.shape
bias = 0 if b is None else 1 bias = 0 if b is None else 1
...@@ -259,7 +259,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out ...@@ -259,7 +259,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
@ct_util.reg_set_dim_func(matmul_set_dim) @ct_util.reg_set_dim_func(matmul_set_dim)
def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="zN", transpose_x=False, transpose_y=False, has_bias=False, attrs=None): def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="zN", transpose_x=False, transpose_y=False, attrs=None):
""" """
Computes matrix multiplication x * y + b. Computes matrix multiplication x * y + b.
...@@ -273,7 +273,6 @@ def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format=" ...@@ -273,7 +273,6 @@ def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="
out_format: str. Data format of output tensor. Supported data format list ["zZ", "nZ", "zN"]. out_format: str. Data format of output tensor. Supported data format list ["zZ", "nZ", "zN"].
transpose_x: Boolean. Specifies whether x is transposed or not. transpose_x: Boolean. Specifies whether x is transposed or not.
transpose_y: Boolean. Specifies whether y is transposed or not. transpose_y: Boolean. Specifies whether y is transposed or not.
has_bias: Boolean. Specifies whether bias tensor exists or not.
attrs: Dict. Used in matmul computation. attrs: Dict. Used in matmul computation.
Note: Note:
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""operator dsl function: matmul4d_ad""" """operator dsl function: matmul4d_ad"""
import akg.tvm import akg.tvm
import akg import akg
from test_op import matmul from akg.ops.nn import matmul
from akg.utils import custom_tiling as ct_util from akg.utils import custom_tiling as ct_util
......
...@@ -19,7 +19,7 @@ from gen_random import random_gaussian ...@@ -19,7 +19,7 @@ from gen_random import random_gaussian
import numpy as np import numpy as np
import akg.backend as cce import akg.backend as cce
from akg.utils import kernel_exec as utils from akg.utils import kernel_exec as utils
from test_op import matmul from akg.ops.nn import matmul
from base import get_rtol_atol from base import get_rtol_atol
from tensorio import compare_tensor from tensorio import compare_tensor
...@@ -353,9 +353,9 @@ def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_for ...@@ -353,9 +353,9 @@ def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_for
has_bias = False has_bias = False
if bias == 1: if bias == 1:
has_bias = True has_bias = True
op_attrs = [out_dtype, left_format, right_format, output_format, adj_x, adj_y, has_bias, attrs] op_attrs = [out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs]
if has_bias == False: if has_bias == False:
input_shapes = [shape_xx, shape_yy] input_shapes = [shape_xx, shape_yy]
input_types = [dtype, dtype] input_types = [dtype, dtype]
op_attrs = [None, out_dtype, left_format, right_format, output_format, adj_x, adj_y, has_bias, attrs] op_attrs = [None, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs]
return utils.op_build_test(matmul.matmul, input_shapes, input_types, op_attrs, kernel_name, attrs) return utils.op_build_test(matmul.matmul, input_shapes, input_types, op_attrs, kernel_name, attrs)
...@@ -21,7 +21,7 @@ from akg.utils import custom_tiling as ct_util ...@@ -21,7 +21,7 @@ from akg.utils import custom_tiling as ct_util
from akg.ops.nn import conv_bn1 from akg.ops.nn import conv_bn1
from akg.ops.nn import conv, conv_backprop_input, conv_backprop_filter, batchmatmul from akg.ops.nn import conv, conv_backprop_input, conv_backprop_filter, batchmatmul
from akg.backend import build_module from akg.backend import build_module
from test_op import matmul from akg.ops.nn import matmul
from test_run import batchmatmul_run, matmul_run from test_run import batchmatmul_run, matmul_run
from .type_definitions import ConvDesc, ConvBackpropDesc, MatmulCubeDesc, ConvConfig, ConvBackpropInputConfig, ConvBackpropFilterConfig, MatmulCubeConfig from .type_definitions import ConvDesc, ConvBackpropDesc, MatmulCubeDesc, ConvConfig, ConvBackpropInputConfig, ConvBackpropFilterConfig, MatmulCubeConfig
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册