diff --git a/python/akg/ops/nn/matmul.py b/python/akg/ops/nn/matmul.py
index e4cab53b7c4ce2ec90ad33a2509670c3af6eb14a..31a2bab6229023a3298a3b84d05eb204a452a2ff 100644
--- a/python/akg/ops/nn/matmul.py
+++ b/python/akg/ops/nn/matmul.py
@@ -80,7 +80,7 @@ matmul_set_dim_map = {
 }


-def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs):
+def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y):
     shape_A = A.shape[1:5] if len(A.shape) == 5 else A.shape
     shape_B = B.shape[1:5] if len(B.shape) == 5 else B.shape
     bias = 0 if b is None else 1
@@ -341,7 +341,7 @@ def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="
     out = matmul4D_compute(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y, attrs)
     attr_map = {"pragma_rmselfdep": False}

-    dims_info, _ = matmul_set_dim(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y, attrs)
+    dims_info, _ = matmul_set_dim(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y)
     attr_map["dim"] = dims_info

     return out, attr_map
diff --git a/tests/common/test_run/matmul_run.py b/tests/common/test_run/matmul_run.py
index 924a66c5d2fb88418ea8bc2ca610d5882df07c93..3e88148c2502fcda7f320a69d5accda8bc5e1392 100644
--- a/tests/common/test_run/matmul_run.py
+++ b/tests/common/test_run/matmul_run.py
@@ -67,13 +67,13 @@ def get_shapes(batch_tuple, M, K, N, trans_data=False, trans_weight=False):
     return shape_x, shape_y


-def getMatmulType(m, n, k):
-    type = MatmulType.gemm
+def getMatmulType(m, n):
+    matmul_type = MatmulType.gemm
     if m // cce.BLOCK_IN == 0:
-        type = MatmulType.gevm
+        matmul_type = MatmulType.gevm
     elif n == 1:
-        type = MatmulType.gemv
-    return type
+        matmul_type = MatmulType.gemv
+    return matmul_type


 def np_matmul(matrix_a, matrix_b, batch_tuple, M, K, N, trans_data=False, trans_weight=False, output_format=None):
@@ -100,7 +100,7 @@ def np_matmul(matrix_a, matrix_b, batch_tuple, M, K, N, trans_data=False, trans_
     for b in range(mul):
         out[b, :] = np.dot(reshape_x[b, :], reshape_y[b, :])
         #out[b,:] = np.matmul(reshape_x[b,:], reshape_y[b,:])
-    matmul_type = getMatmulType(M, N, K)
+    matmul_type = getMatmulType(M, N)
     out_shape = ()
     if matmul_type == MatmulType.gemm:
         out_shape = batch_tuple + (M // cce.BLOCK_IN, cce.BLOCK_IN, N // cce.BLOCK_OUT, cce.BLOCK_OUT)
@@ -132,7 +132,7 @@ def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
     matrix_a_for_np = matrix_a.astype(np.float32)
     matrix_b_for_np = matrix_b.astype(np.float32)

-    matmul_type = getMatmulType(M, N, K)
+    matmul_type = getMatmulType(M, N)
     out = np_matmul(matrix_a_for_np, matrix_b_for_np, batch_tuple, M, K, N, trans_data, trans_weight, output_format).astype(out_dtype)
     if dtype == "float16":
         out.astype(np.float16)
@@ -226,24 +226,24 @@ def reduce_data(reduce_type):
     return res


-def get_fractal_shape(dim1, dim2, reduce1="in", reduce2="reduce", format="zZ"):
+def get_fractal_shape(dim1, dim2, reduce1="in", reduce2="reduce", matrix_format="zZ"):
     result = ()
     dim1_reduce = reduce_data(reduce1)
     dim2_reduce = reduce_data(reduce2)
-    if format == "zZ":
+    if matrix_format == "zZ":
         result = (dim1 // dim1_reduce, dim2 // dim2_reduce, dim1_reduce, dim2_reduce)
-    elif format == "nZ":
+    elif matrix_format == "nZ":
         result = (dim1 // dim1_reduce, dim2 // dim2_reduce, dim2_reduce, dim1_reduce)
-    elif format == "nN":
+    elif matrix_format == "nN":
         result = (dim2 // dim2_reduce, dim1 // dim1_reduce, dim2_reduce, dim1_reduce)
-    elif format == "zN":
+    elif matrix_format == "zN":
         result = (dim2 // dim2_reduce, dim1 // dim1_reduce, dim1_reduce, dim2_reduce)
     return result


 def get_converted_shapes(m, n, k, batch_tuple, adj_x, adj_y, bias, left_format="zZ", right_format="nZ", out_format="zN"):
-    matmul_type = getMatmulType(m, n, k)
+    matmul_type = getMatmulType(m, n)
     if matmul_type == MatmulType.gemm:
         # left_format zZ process
         if left_format == "zZ":
@@ -341,7 +341,7 @@ def matmul_execute(shape_x, shape_y, bias, left_format, right_format, out_format
     return (m_x, m_y), output, bench_mark, compare_result


-def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs):
+def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs, tuning=False):
     batch_tuple, m, k, n = extract_dim(shape_x, shape_y, adj_x, adj_y)
     m = (m + 15) // 16 * 16
     n = (n + 15) // 16 * 16
@@ -358,4 +358,4 @@ def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_for
     input_shapes = [shape_xx, shape_yy]
     input_types = [dtype, dtype]
     op_attrs = [None, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs]
-    return utils.op_build_test(matmul.matmul, input_shapes, input_types, op_attrs, kernel_name, attrs)
+    return utils.op_build_test(matmul.matmul, input_shapes, input_types, op_attrs, kernel_name, attrs, tuning=tuning)
diff --git a/tests/fuzz/tune/autotuning/kernel_compiler.py b/tests/fuzz/tune/autotuning/kernel_compiler.py
index ca0bd31f869bebd6678755f5a26df6068aefcd51..5c56a43287965823dde2772ba888ec23048f8f3a 100644
--- a/tests/fuzz/tune/autotuning/kernel_compiler.py
+++ b/tests/fuzz/tune/autotuning/kernel_compiler.py
@@ -100,16 +100,6 @@ def gen_kernel_matmul_cube(op_desc: MatmulCubeDesc, _, index_table,
     kernel_name = "matmul_cube_poly"
     if idx is not None:
         kernel_name += str(idx)
-    batch_tuple, m, k, n = matmul_run.extract_dim(op_desc.x_shape, op_desc.y_shape, op_desc.adj_x, op_desc.adj_y)
-    m = (m + 15) // 16 * 16
-    n = (n + 15) // 16 * 16
-    k = (k + 15) // 16 * 16
-    shape_xx, shape_yy, bias_shape, _, _ = matmul_run.get_converted_shapes(m, n, k, batch_tuple, op_desc.adj_x,
-                                                                           op_desc.adj_y, op_desc.bias,
-                                                                           op_desc.left_format, op_desc.right_format,
-                                                                           op_desc.out_format)
-    input_shapes = [shape_xx, shape_yy, bias_shape]
-    input_types = [op_desc.dtype, op_desc.dtype, op_desc.dtype]
     if config is None:
         attrs = {'dim': ""}
     else:
@@ -123,18 +113,9 @@ def gen_kernel_matmul_cube(op_desc: MatmulCubeDesc, _, index_table,
         tiling_param.extend([(16, 16), (16, 16), (config.k_l1, config.k_l0)])
         dim_info = ct_util.set_dims(tuple(tiling_param))
         attrs = {'dim': dim_info, 'bypass': config.bypass}
-    has_bias = False
-    if op_desc.bias == 1:
-        has_bias = True
-    op_attrs = [op_desc.out_dtype, op_desc.left_format, op_desc.right_format, op_desc.out_format,
-                op_desc.adj_x, op_desc.adj_y, has_bias, attrs]
-    if has_bias == False:
-        input_shapes = [shape_xx, shape_yy]
-        input_types = [op_desc.dtype, op_desc.dtype]
-        op_attrs = [None, op_desc.out_dtype, op_desc.left_format, op_desc.right_format, op_desc.out_format,
-                    op_desc.adj_x, op_desc.adj_y, has_bias, attrs]
-    return utils.op_build(matmul.matmul, input_shapes, input_types, op_attrs,
-                          kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
+    return matmul_run.matmul_compile(op_desc.x_shape, op_desc.y_shape, op_desc.bias, op_desc.left_format,
+                                     op_desc.right_format, op_desc.out_format, op_desc.adj_x, op_desc.adj_y,
+                                     op_desc.dtype, op_desc.out_dtype, kernel_name, attrs, gen_tiling_spaces)


 def gen_kernel_conv_backprop_input(op_desc: ConvBackpropDesc, _, index_table, config: ConvBackpropInputConfig = None,
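
Usage sketch (illustrative only, not part of the patch): after this change both the unit tests and the autotuner funnel through matmul_run.matmul_compile, with gen_kernel_matmul_cube forwarding gen_tiling_spaces as the new trailing tuning argument instead of rebuilding shapes and op_attrs itself. The shapes, dtypes, formats, kernel names, and import path below are assumed placeholder values, not taken from this patch:

    # Sketch only: placeholder arguments; the import path is an assumption.
    from tests.common.test_run import matmul_run

    attrs = {'dim': ""}

    # Regular test build: tuning keeps its default of False, so behaviour is unchanged.
    mod = matmul_run.matmul_compile((1, 64, 64), (1, 64, 64), 0, "zZ", "nZ", "zN",
                                    False, False, "float16", "float16",
                                    "matmul_demo", attrs)

    # Autotuning build: the final positional argument takes the tuning slot,
    # mirroring how gen_kernel_matmul_cube now passes gen_tiling_spaces through.
    spaces = matmul_run.matmul_compile((1, 64, 64), (1, 64, 64), 0, "zZ", "nZ", "zN",
                                       False, False, "float16", "float16",
                                       "matmul_cube_poly", attrs, True)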