# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: matmul"""
import logging
import akg.tvm
import akg
from akg import backend as cce
from akg.utils import kernel_exec as utils
from akg.utils import custom_tiling as ct_util
from akg.utils import validation_check as vc_util
from akg.ops.math import cast

logging.basicConfig(level=logging.DEBUG)

# Hand-written tiling table.
# Key:   str() of the tuple
#        (left shape, right shape, bias flag, left format, right format, output format,
#         transpose left, transpose right, input dtype, output dtype)
#        -- see _matmul_config_key, which builds the lookup key in exactly this order.
# Value: (tile list handed to ct_util.set_dims, {"bypass": <0|1|2>}).
matmul_set_dim_map = {
    str(((1, 16, 49, 16, 16), (1, 49, 49, 16, 16), 0, 'zZ', 'zZ', 'zZ', False, False, 'float16', 'float32')):
        ([(1, 1), (2, 2), (16, 16), (16, 16), (49, 49)], {"bypass": 0}),
    str(((1, 16, 49, 16, 16), (1, 49, 16, 16, 16), 0, 'zZ', 'zZ', 'zZ', False, False, 'float16', 'float32')):
        ([(2, 2), (2, 2), (16, 16), (16, 16), (49, 49)], {"bypass": 0}),
    str(((1, 2, 64, 16, 16), (1, 2, 64, 16, 16), 0, 'zZ', 'zZ', 'zZ', True, False, 'float16', 'float32')):
        ([(2, 2), (64, 64), (16, 16), (16, 16), (2, 2)], {"bypass": 0}),
    str(((1, 2, 128, 16, 16), (1, 2, 128, 16, 16), 0, 'zZ', 'zZ', 'zZ', True, False, 'float16', 'float32')):
        ([(2, 2), (64, 64), (16, 16), (16, 16), (2, 2)], {"bypass": 0}),

    # bert best tile
    # (16, 1024), (16, 1024)
    str(((1, 64, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")):
        ([(16, 16), (16, 16), (32, 32)], {"bypass": 2}),
    # (8192, 4096), (8192, 1024)
    str(((1, 256, 512, 16, 16), (1, 64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")):
        ([(8, 8), (16, 16), (16, 16), (16, 16), (8, 1)], {"bypass": 0}),
    # (8192, 1024), (1024, 4096)
    str(((1, 64, 512, 16, 16), (1, 256, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")):
        ([(16, 16), (8, 4), (16, 16), (16, 16), (64, 8)], {"bypass": 0}),
    # (16, 16), (16, 1024)
    str(((1, 1, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")):
        ([(8, 8), (16, 16), (16, 16), (16, 16)], {"bypass": 0}),
    # (1216, 1024), (1024, 1024)
    str(((1, 64, 76, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")):
        ([(4, 4), (19, 19), (16, 16), (16, 16), (4, 1)], {"bypass": 0}),
    # (8192, 4096), (4096, 1024)
    str(((1, 256, 512, 16, 16), (1, 64, 256, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")):
        ([(8, 8), (32, 32), (16, 16), (16, 16), (2, 1)], {"bypass": 0}),
    # (8192, 1024), (4096, 1024)
    str(((1, 64, 512, 16, 16), (1, 64, 256, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")):
        ([(8, 8), (32, 32), (16, 16), (16, 16), (2, 1)], {"bypass": 0}),
    # (8192, 1024), (8192, 4096)
    str(((1, 64, 512, 16, 16), (1, 256, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")):
        ([[8, 8], [32, 32], [16, 16], [16, 16], [16, 2]], {"bypass": 0}),
    # (1216, 1024), (1024, 1024)
    str(((1, 64, 76, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")):
        ([(8, 8), (19, 19), (16, 16), (16, 16), (16, 1)], {"bypass": 2}),
    # (8192, 1024), (1024, 1024)
    str(((1, 64, 512, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")):
        ([(16, 4), (16, 8), (16, 16), (16, 16), (64, 16)], {"bypass": 0}),
    # (1216, 30522), (30522, 1024)
    str(((1, 1908, 76, 16, 16), (1, 64, 1908, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")):
        ([(8, 8), (19, 19), (16, 16), (16, 16), (6, 1)], {"bypass": 0}),
    # (1216, 30522), (1216, 1024)
    str(((1, 1908, 76, 16, 16), (1, 64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")):
        ([(4, 4), (18, 18), (16, 16), (16, 16), (2, 2)], {"bypass": 0}),
    # (1216, 1024), (30522, 1024)
    str(((1, 64, 76, 16, 16), (1, 64, 1908, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")):
        ([(9, 9), (19, 19), (16, 16), (16, 16), (64, 1)], {"bypass": 0}),
    # (8192, 1024), (8192, 1024)
    str(((1, 64, 512, 16, 16), (1, 64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")):
        ([(4, 4), (16, 16), (16, 16), (16, 16), (16, 4)], {"bypass": 0}),
    # (1216, 1024), (1216, 1024)
    str(((1, 64, 76, 16, 16), (1, 64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")):
        ([(16, 16), (8, 8), (16, 16), (16, 16), (4, 2)], {"bypass": 0}),
    # (16, 1024), (16, 1024)
    str(((1, 64, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")):
        ([(8, 8), (2, 2), (16, 16), (16, 16), (16, 16)], {"bypass": 0}),
    # (16, 1024), (1024, 1024)
    str(((1, 64, 1, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")):
        ([(8, 8), (16, 16), (16, 16), (32, 8)], {"bypass": 2}),
    # (16, 16), (16, 1024)
    str(((1, 1, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")):
        ([(8, 8), (16, 16), (16, 16), (16, 16)], {"bypass": 0}),
    # (8192, 1024), (1024, 1024)
    str(((1, 64, 512, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")):
        ([(16, 8), (8, 8), (16, 16), (16, 16), (64, 8)], {"bypass": 1}),
    # (8192, 4096), (1024, 4096)
    str(((1, 256, 512, 16, 16), (1, 256, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")):
        ([(8, 8), (16, 16), (16, 16), (16, 16), (128, 8)], {"bypass": 1}),
}

# (left_format, transposed) -> (pragma_data_transpose,
#                               pragma_data_transpose_block,
#                               pragma_data_transpose_block_inner)
_LEFT_TRANSPOSE_PRAGMAS = {
    ("zZ", False): ("N", "N", "N"),
    ("zZ", True): ("Y", "N", "N"),
    ("zN", False): ("Y", "Y", "N"),
    ("zN", True): ("Y", "N", "Y"),
}

# (right_format, transposed) -> (pragma_weight_transpose,
#                                pragma_weight_transpose_block,
#                                pragma_weight_transpose_block_inner)
_RIGHT_TRANSPOSE_PRAGMAS = {
    ("nZ", False): ("N", "N", "N"),
    ("nZ", True): ("Y", "N", "N"),
    ("zZ", False): ("Y", "N", "Y"),
    ("zZ", True): ("Y", "Y", "N"),
    ("zN", False): ("Y", "N", "N"),
    ("zN", True): ("N", "N", "N"),
}


def _matmul_config_key(left, right, bias, left_format, right_format, out_format, trans_left, trans_right,
                       out_dtype):
    """Build the string key used to look a configuration up in matmul_set_dim_map.

    Both the tiling lookup (matmul_set_dim) and the bypass lookup
    (matmul4D_compute) must produce byte-identical keys, so the tuple layout
    lives in this single place.
    """
    key = (tuple(left.shape), tuple(right.shape), bias, left_format, right_format, out_format,
           trans_left, trans_right, left.dtype, out_dtype)
    return str(key)


def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y, has_bias, attrs):
    """Look up the predefined tiling for a matmul configuration.

    The parameter list mirrors matmul(); has_bias and attrs are accepted but
    not used here. Whether a bias is present is derived from ``b is None``.

    Returns:
        (dims, hash_key): dims is the result of ct_util.set_dims for the
        matching table entry, or "" when the configuration is not in
        matmul_set_dim_map; hash_key is the lookup key that was used.
    """
    bias = 0 if b is None else 1
    hash_key = _matmul_config_key(A, B, bias, left_format, right_format, output_format, adj_x, adj_y, out_dtype)

    if hash_key not in matmul_set_dim_map:
        return "", hash_key

    configs = matmul_set_dim_map[hash_key]
    # Table entries are (tiles, options); a bare tile list is also accepted.
    tiles = configs[0] if isinstance(configs, tuple) else configs
    return ct_util.set_dims(tiles), hash_key


def output_shape_compute(left_shape, right_shape, left_format, right_format, out_format, trans_a, trans_b):
    """Derive the fractal output shape and the reduction length of a matmul.

    Args:
        left_shape: shape of the left matrix; the last four dims are the fractal dims.
        right_shape: shape of the right matrix; the last four dims are the fractal dims.
        left_format: layout of the left matrix, one of "zZ", "nZ", "zN".
        right_format: layout of the right matrix, one of "zZ", "nZ", "zN".
        out_format: layout of the output, one of "zZ", "nZ", "zN".
        trans_a: whether the left matrix is transposed.
        trans_b: whether the right matrix is transposed.

    Returns:
        (output_shape, k): output_shape is a tuple made of the left matrix's
        leading (batch) dims followed by four fractal dims ordered according to
        out_format; k is ko * ki as read from the right matrix.

    Raises:
        RuntimeError: if either shape has fewer than four dims or a format is
            not one of "zZ", "nZ", "zN".
    """
    if len(left_shape) < 4 or len(right_shape) < 4:
        raise RuntimeError("matmul left matrix and right matrix should be fractal shape")

    format_list = ["zZ", "nZ", "zN"]
    if left_format not in format_list:
        raise RuntimeError("matmul left_format only support %s" % (",".join(format_list)))
    if right_format not in format_list:
        raise RuntimeError("matmul right_format only support %s" % (",".join(format_list)))
    if out_format not in format_list:
        raise RuntimeError("matmul out_format only support %s" % (",".join(format_list)))

    batch = left_shape[:-4]

    # left matrix default is zZ
    mo, ko, mi, ki = left_shape[-4:]
    if trans_a:
        ko, mo, ki, mi = left_shape[-4:]
    if left_format == "nZ":
        mo, ko, ki, mi = left_shape[-4:]
        if trans_a:
            ko, mo, mi, ki = left_shape[-4:]
    elif left_format == "zN":
        ko, mo, mi, ki = left_shape[-4:]
        if trans_a:
            mo, ko, ki, mi = left_shape[-4:]

    # right matrix default is zZ
    # NOTE: ko/ki are re-bound here, so the returned k comes from the right matrix.
    ko, no, ki, ni = right_shape[-4:]
    if trans_b:
        no, ko, ni, ki = right_shape[-4:]
    if right_format == "nZ":
        ko, no, ni, ki = right_shape[-4:]
        if trans_b:
            no, ko, ki, ni = right_shape[-4:]
    elif right_format == "zN":
        no, ko, ki, ni = right_shape[-4:]
        if trans_b:
            ko, no, ni, ki = right_shape[-4:]

    output_shape = tuple(batch) + (mo, no, mi, ni)
    if out_format == "nZ":
        output_shape = tuple(batch) + (mo, no, ni, mi)
    elif out_format == "zN":
        output_shape = tuple(batch) + (no, mo, mi, ni)

    k = ko * ki
    return output_shape, k


def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out_format, transpose_x=False,
                     transpose_y=False, attrs=None):
    """Build the compute for a fractal matmul, with optional cast, re-layout and bias add.

    Args:
        x: left matrix tensor; left_format must be "zZ" or "zN".
        y: right matrix tensor; right_format must be "nZ", "zZ" or "zN".
        bias_value: bias tensor, or None for no bias.
        out_dtype: output dtype; the float32 mmad result is cast only when this is "float16".
        left_format: layout of x.
        right_format: layout of y.
        out_format: layout of the result. "zZ" re-lays the zN result out; any
            other value returns the zN result as is (a bias add additionally
            requires "zN" or "zZ").
        transpose_x: whether x is transposed.
        transpose_y: whether y is transposed.
        attrs: optional dict; attrs["bypass"], when present, overrides the
            bypass value from matmul_set_dim_map.

    Returns:
        The output tensor.

    Raises:
        RuntimeError: on unsupported dtype, format or bypass value.
    """
    # for gemv use transpose of AB --> gevm trans(trans(B) * trans(A))

    data_dtype = x.dtype.lower()
    check_list = ["int8", "uint8", "float16", "float32", "int32"]
    if data_dtype not in check_list:
        raise RuntimeError("matmul_cce ony supports %s while dtype is %s" % (",".join(check_list), x.dtype))

    if bias_value is None:
        bias_name = ''
        bias = 0
    else:
        bias_name = bias_value.name
        bias = 1

    output_shape_zN, k = output_shape_compute(x.shape, y.shape, left_format, right_format, "zN", transpose_x,
                                              transpose_y)
    output_shape_zZ, k = output_shape_compute(x.shape, y.shape, left_format, right_format, "zZ", transpose_x,
                                              transpose_y)

    hash_key = _matmul_config_key(x, y, bias, left_format, right_format, out_format, transpose_x, transpose_y,
                                  out_dtype)

    # bypass 2 left matrix ddr -> l0
    # bypass 1 right matrix ddr -> l0
    bypass_list = [0, 1, 2]
    bypass = 0
    if attrs is not None and 'bypass' in attrs:
        bypass = attrs['bypass']
    elif hash_key in matmul_set_dim_map:
        configs = matmul_set_dim_map[hash_key]
        if isinstance(configs, tuple) and len(configs) > 1 and "bypass" in configs[1]:
            bypass = configs[1]["bypass"]

    if bypass not in bypass_list:
        # Join the individual values (not the characters of str(list)) and use
        # %s so that a non-numeric bypass still yields this RuntimeError.
        raise RuntimeError("matmul_cce only supports %s while bypass is %s"
                           % (",".join(str(item) for item in bypass_list), bypass))

    # output_shape_compute accepts "nZ" for the left matrix, but neither the
    # index mapping nor the pragmas below cover it: fail with a clear message.
    left_pragmas = _LEFT_TRANSPOSE_PRAGMAS.get((left_format, bool(transpose_x)))
    if left_pragmas is None:
        raise RuntimeError("matmul compute left_format only support zZ,zN while left_format is %s" % left_format)
    right_pragmas = _RIGHT_TRANSPOSE_PRAGMAS.get((right_format, bool(transpose_y)))
    if right_pragmas is None:
        raise RuntimeError("matmul compute right_format only support nZ,zZ,zN while right_format is %s"
                           % right_format)
    data_trans, data_trans_block, data_trans_block_in = left_pragmas
    weight_trans, weight_trans_block, weight_trans_block_in = right_pragmas

    def matmul_compute(*indices):
        """mmad for one element of the zN result: indices = batch + (no, mo, mi, ni)."""
        rank = len(output_shape_zN)
        # reduce axis
        ko = akg.tvm.reduce_axis((0, k // cce.BLOCK_REDUCE), name='ko')
        ki = akg.tvm.reduce_axis((0, cce.BLOCK_REDUCE), name='ki')

        # Each of these is a tuple slice so they can be concatenated directly.
        batch = indices[:(rank - 4)]
        no = indices[(rank - 4):(rank - 3)]
        mo = indices[(rank - 3):(rank - 2)]
        mi = indices[(rank - 2):(rank - 1)]
        ni = indices[(rank - 1):]

        if left_format == "zZ":
            if transpose_x:
                x_indices = batch + (ko,) + mo + (ki,) + mi
            else:
                x_indices = batch + mo + (ko,) + mi + (ki,)
        elif left_format == "zN":
            if transpose_x:
                x_indices = batch + mo + (ko,) + (ki,) + mi
            else:
                x_indices = batch + (ko,) + mo + mi + (ki,)

        if right_format == "nZ":
            if transpose_y:
                y_indices = batch + no + (ko, ki) + ni
            else:
                y_indices = batch + (ko,) + no + ni + (ki,)
        elif right_format == "zZ":
            if transpose_y:
                y_indices = batch + no + (ko,) + ni + (ki,)
            else:
                y_indices = batch + (ko,) + no + (ki,) + ni
        elif right_format == "zN":
            if transpose_y:
                y_indices = batch + (ko,) + no + ni + (ki,)
            else:
                y_indices = batch + no + (ko,) + (ki,) + ni

        return akg.lang.cce.mmad((x(*x_indices) * y(*y_indices)).astype("float32"), axis=[ko, ki])

    result_matmul = akg.tvm.compute(output_shape_zN, lambda *indices: matmul_compute(*indices),
                                    name="resMatmul",
                                    attrs={
                                        "pragma_gemm_data": x.name,
                                        "pragma_data_transpose": data_trans,
                                        "pragma_data_transpose_block": data_trans_block,
                                        "pragma_data_transpose_block_inner": data_trans_block_in,
                                        "pragma_gemm_weight": y.name,
                                        "pragma_weight_transpose": weight_trans,
                                        "pragma_weight_transpose_block": weight_trans_block,
                                        "pragma_weight_transpose_block_inner": weight_trans_block_in,
                                        "pragma_conv_bypass_l1": bypass,
                                        "bias": bias_name,
                                    })

    if out_dtype == "float16":
        result_matmul = cast.cast(result_matmul, out_dtype)

    def matmul_reshape(shape, source, *indices):
        """Read the zN result with the two outer fractal dims swapped (zN -> zZ)."""
        rank = len(shape)
        new_indices = (indices[:(rank - 4)] + indices[(rank - 3):(rank - 2)]
                       + indices[(rank - 4):(rank - 3)] + indices[(rank - 2):])
        return source(*new_indices)

    if out_format == "zZ":
        result = akg.tvm.compute(output_shape_zZ,
                                 lambda *indices: matmul_reshape(output_shape_zZ, result_matmul, *indices),
                                 name="result")
    else:
        result = result_matmul

    def bias_compute(output_shape, source, bias_tensor, output_format, *indices):
        """Add the bias element selected by the output's N indices to one result element."""
        rank = len(output_shape)
        if output_format == "zN":
            bias_indices = indices[:(rank - 4)] + indices[(rank - 4):(rank - 3)] + (0, 0) + indices[(rank - 1):]
        elif output_format == "zZ":
            bias_indices = (indices[:(rank - 4)] + (0,) + indices[(rank - 3):(rank - 2)] + (0,)
                            + indices[(rank - 1):])
        return source(*indices) + bias_tensor(*bias_indices)

    if bias == 1:
        if out_format == "zN":
            biased_shape = output_shape_zN
        elif out_format == "zZ":
            biased_shape = output_shape_zZ
        else:
            # Previously this fell through and crashed on an unbound `out`.
            raise RuntimeError("matmul with bias only support out_format zN,zZ while out_format is %s"
                               % out_format)
        out = akg.tvm.compute(biased_shape,
                              lambda *indices: bias_compute(biased_shape, result, bias_value, out_format,
                                                            *indices),
                              name="output")
    else:
        out = result

    return out


@ct_util.reg_set_dim_func(matmul_set_dim)
def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format="zN", transpose_x=False,
           transpose_y=False, has_bias=False, attrs=None):
    """
    Computes matrix multiplication x * y + b.

    Args:
        x: akg.tvm.Tensor. Left matrix. Its dtype must pass the
           vc_util.DtypeForDavinci.ALL_FLOAT check.
        y: akg.tvm.Tensor. Right matrix. Its dtype must pass the same check.
        b: akg.tvm.Tensor or None. Bias tensor; None means no bias is added.
        out_dtype: str. Data type of output tensor.
        left_format: str. Data format of left matrix. Supported data format list ["zZ", "zN"].
        right_format: str. Data format of right matrix. Supported data format list ["nZ", "zZ", "zN"].
        out_format: str. Data format of output tensor. Supported data format list ["zN", "zZ"].
        transpose_x: Boolean. Specifies whether x is transposed or not.
        transpose_y: Boolean. Specifies whether y is transposed or not.
        has_bias: Boolean. Accepted for interface compatibility; it is not read
                  here, the presence of a bias is decided by ``b is None``.
        attrs: Dict. Used in matmul computation (e.g. the "bypass" entry).

    Note:
        before call matmul, 2d to Fractal is needed.

    Returns:
        Tuple (out, attr_map): out is the akg.tvm.Tensor result and attr_map is
        the dict {"pragma_rmselfdep": False}.

    Raises:
        ValueError: if left_format, right_format or out_format is not supported.
    """
    vc_util.ops_dtype_check([x.dtype, y.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
    shape_x = [shape_element.value for shape_element in x.shape]
    vc_util.check_shape(shape_x)
    shape_y = [shape_element.value for shape_element in y.shape]
    vc_util.check_shape(shape_y)

    if left_format not in ["zZ", "zN"]:
        raise ValueError("unsupport left_format now: %s" % left_format)
    if right_format not in ["nZ", "zZ", "zN"]:
        raise ValueError("unsupport right_format now: %s" % right_format)
    if out_format not in ["zN", "zZ"]:
        raise ValueError("unsupport out_format now: %s" % out_format)

    out = matmul4D_compute(x, y, b, out_dtype, left_format, right_format, out_format, transpose_x, transpose_y,
                           attrs)
    attr_map = {"pragma_rmselfdep": False}

    return out, attr_map