提交 7decc565 编写于 作者: C chenlei_autodiff

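Modify matmul tuning: drop the unused `attrs` argument from `matmul_set_dim` and the unused `k` argument from `getMatmulType`, rename locals and parameters that shadowed the builtins `type` and `format`, add a `tuning` flag to `matmul_compile`, and make `gen_kernel_matmul_cube` build its kernel through `matmul_run.matmul_compile` instead of duplicating the shape-conversion logic.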

上级 2be38102
......@@ -80,7 +80,7 @@ matmul_set_dim_map = {
}
def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs):
def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y):
shape_A = A.shape[1:5] if len(A.shape) == 5 else A.shape
shape_B = B.shape[1:5] if len(B.shape) == 5 else B.shape
bias = 0 if b is None else 1
......@@ -199,7 +199,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out
elif left_format == "zN":
x_indices = indices[:(N - 4)] + (ko,) + indices[(N - 3):(N - 2)] + indices[(N-2):(N-1)] + (ki,)
if adj_x:
x_indices = indices[:(N - 4)] + indices[(N - 3):(N - 2)] + (ko,) + (ki,) + indices[(N-2):(N-1)]
x_indices = indices[:(N - 4)] + indices[(N - 3):(N - 2)] + (ko,) + (ki,) + indices[(N-2):(N-1)]
if right_format == "nZ":
y_indices = indices[:(N - 4)] + (ko, ) + indices[(N - 4):(N - 3)] + indices[(N - 1):] + (ki,)
......@@ -341,7 +341,7 @@ def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="
out = matmul4D_compute(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y, attrs)
attr_map = {"pragma_rmselfdep": False}
dims_info, _ = matmul_set_dim(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y, attrs)
dims_info, _ = matmul_set_dim(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y)
attr_map["dim"] = dims_info
return out, attr_map
......@@ -67,13 +67,13 @@ def get_shapes(batch_tuple, M, K, N, trans_data=False, trans_weight=False):
return shape_x, shape_y
def getMatmulType(m, n, k=None):
    """Classify a matmul as gemm, gevm or gemv from its M and N extents.

    Args:
        m: M extent of the product.
        n: N extent of the product.
        k: Ignored. Accepted only so that call sites which still pass the
            reduction extent as a third argument keep working.

    Returns:
        ``MatmulType.gevm`` when ``m // cce.BLOCK_IN`` is 0,
        ``MatmulType.gemv`` when that is not the case and ``n == 1``,
        otherwise ``MatmulType.gemm``.
    """
    # The gevm test takes precedence over the gemv test, so an input with a
    # small m and n == 1 is reported as gevm.
    if m // cce.BLOCK_IN == 0:
        return MatmulType.gevm
    if n == 1:
        return MatmulType.gemv
    return MatmulType.gemm
def np_matmul(matrix_a, matrix_b, batch_tuple, M, K, N, trans_data=False, trans_weight=False, output_format=None):
......@@ -100,7 +100,7 @@ def np_matmul(matrix_a, matrix_b, batch_tuple, M, K, N, trans_data=False, trans_
for b in range(mul):
out[b, :] = np.dot(reshape_x[b, :], reshape_y[b, :])
#out[b,:] = np.matmul(reshape_x[b,:], reshape_y[b,:])
matmul_type = getMatmulType(M, N, K)
matmul_type = getMatmulType(M, N)
out_shape = ()
if matmul_type == MatmulType.gemm:
out_shape = batch_tuple + (M // cce.BLOCK_IN, cce.BLOCK_IN, N // cce.BLOCK_OUT, cce.BLOCK_OUT)
......@@ -132,7 +132,7 @@ def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
matrix_a_for_np = matrix_a.astype(np.float32)
matrix_b_for_np = matrix_b.astype(np.float32)
matmul_type = getMatmulType(M, N, K)
matmul_type = getMatmulType(M, N)
out = np_matmul(matrix_a_for_np, matrix_b_for_np, batch_tuple, M, K, N, trans_data, trans_weight, output_format).astype(out_dtype)
if dtype == "float16":
out.astype(np.float16)
......@@ -226,24 +226,24 @@ def reduce_data(reduce_type):
return res
def get_fractal_shape(dim1, dim2, reduce1="in", reduce2="reduce", matrix_format="zZ"):
    """Return the 4-D blocked shape of a (dim1, dim2) matrix for a given layout.

    Args:
        dim1: First matrix extent.
        dim2: Second matrix extent.
        reduce1: Key passed to ``reduce_data`` to obtain the block size used
            for ``dim1``.
        reduce2: Key passed to ``reduce_data`` to obtain the block size used
            for ``dim2``.
        matrix_format: Layout name; one of ``"zZ"``, ``"nZ"``, ``"nN"`` or
            ``"zN"``.

    Returns:
        A 4-tuple built from ``dim1 // block1``, ``dim2 // block2``, ``block1``
        and ``block2``, ordered according to ``matrix_format``; an empty tuple
        when ``matrix_format`` is not one of the recognised names.
    """
    block1 = reduce_data(reduce1)
    block2 = reduce_data(reduce2)

    # Each layout is a permutation of
    # (dim1 // block1, dim2 // block2, block1, block2).
    layout_orders = {
        "zZ": (0, 1, 2, 3),
        "nZ": (0, 1, 3, 2),
        "nN": (1, 0, 3, 2),
        "zN": (1, 0, 2, 3),
    }
    order = layout_orders.get(matrix_format)
    if order is None:
        # Unknown layouts yield an empty shape rather than raising.
        return ()

    parts = (dim1 // block1, dim2 // block2, block1, block2)
    return tuple(parts[i] for i in order)
def get_converted_shapes(m, n, k, batch_tuple, adj_x, adj_y, bias, left_format="zZ", right_format="nZ", out_format="zN"):
matmul_type = getMatmulType(m, n, k)
matmul_type = getMatmulType(m, n)
if matmul_type == MatmulType.gemm:
# left_format zZ process
if left_format == "zZ":
......@@ -341,7 +341,7 @@ def matmul_execute(shape_x, shape_y, bias, left_format, right_format, out_format
return (m_x, m_y), output, bench_mark, compare_result
def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs):
def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs, tuning=False):
batch_tuple, m, k, n = extract_dim(shape_x, shape_y, adj_x, adj_y)
m = (m + 15) // 16 * 16
n = (n + 15) // 16 * 16
......@@ -358,4 +358,4 @@ def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_for
input_shapes = [shape_xx, shape_yy]
input_types = [dtype, dtype]
op_attrs = [None, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs]
return utils.op_build_test(matmul.matmul, input_shapes, input_types, op_attrs, kernel_name, attrs)
return utils.op_build_test(matmul.matmul, input_shapes, input_types, op_attrs, kernel_name, attrs, tuning=tuning)
......@@ -100,16 +100,6 @@ def gen_kernel_matmul_cube(op_desc: MatmulCubeDesc, _, index_table,
kernel_name = "matmul_cube_poly"
if idx is not None:
kernel_name += str(idx)
batch_tuple, m, k, n = matmul_run.extract_dim(op_desc.x_shape, op_desc.y_shape, op_desc.adj_x, op_desc.adj_y)
m = (m + 15) // 16 * 16
n = (n + 15) // 16 * 16
k = (k + 15) // 16 * 16
shape_xx, shape_yy, bias_shape, _, _ = matmul_run.get_converted_shapes(m, n, k, batch_tuple, op_desc.adj_x,
op_desc.adj_y, op_desc.bias,
op_desc.left_format, op_desc.right_format,
op_desc.out_format)
input_shapes = [shape_xx, shape_yy, bias_shape]
input_types = [op_desc.dtype, op_desc.dtype, op_desc.dtype]
if config is None:
attrs = {'dim': ""}
else:
......@@ -123,18 +113,9 @@ def gen_kernel_matmul_cube(op_desc: MatmulCubeDesc, _, index_table,
tiling_param.extend([(16, 16), (16, 16), (config.k_l1, config.k_l0)])
dim_info = ct_util.set_dims(tuple(tiling_param))
attrs = {'dim': dim_info, 'bypass': config.bypass}
has_bias = False
if op_desc.bias == 1:
has_bias = True
op_attrs = [op_desc.out_dtype, op_desc.left_format, op_desc.right_format, op_desc.out_format,
op_desc.adj_x, op_desc.adj_y, has_bias, attrs]
if has_bias == False:
input_shapes = [shape_xx, shape_yy]
input_types = [op_desc.dtype, op_desc.dtype]
op_attrs = [None, op_desc.out_dtype, op_desc.left_format, op_desc.right_format, op_desc.out_format,
op_desc.adj_x, op_desc.adj_y, has_bias, attrs]
return utils.op_build(matmul.matmul, input_shapes, input_types, op_attrs,
kernel_name=kernel_name, attrs=attrs, polyhedral=True, tuning=gen_tiling_spaces)
return matmul_run.matmul_compile(op_desc.x_shape, op_desc.y_shape, op_desc.bias, op_desc.left_format,
op_desc.right_format, op_desc.out_format, op_desc.adj_x, op_desc.adj_y,
op_desc.dtype, op_desc.out_dtype, kernel_name, attrs, gen_tiling_spaces)
def gen_kernel_conv_backprop_input(op_desc: ConvBackpropDesc, _, index_table, config: ConvBackpropInputConfig = None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册