diff --git a/python/akg/ops/nn/matmul.py b/python/akg/ops/nn/matmul.py index 7e5cfd977efca7b4baa3ea49692247464996eee8..e4cab53b7c4ce2ec90ad33a2509670c3af6eb14a 100644 --- a/python/akg/ops/nn/matmul.py +++ b/python/akg/ops/nn/matmul.py @@ -26,57 +26,57 @@ from akg.ops.math import cast logging.basicConfig(level=logging.DEBUG) matmul_set_dim_map = { - str(((1, 16, 49, 16, 16), (1, 49, 49, 16, 16), 0, 'zZ', 'zZ', 'zZ', False, False, 'float16', 'float32')) : ([(1,1),(2,2),(16,16),(16,16),(49,49)], {"bypass" : 0}), - str(((1, 16, 49, 16, 16), (1, 49, 16, 16, 16), 0, 'zZ', 'zZ', 'zZ', False, False, 'float16', 'float32')) : ([(2,2),(2,2),(16,16),(16,16),(49,49)], {"bypass" : 0}), - str(((1, 2, 64, 16, 16), (1, 2, 64, 16, 16), 0, 'zZ', 'zZ', 'zZ', True, False, 'float16', 'float32')) : ([(2,2),(64,64),(16,16),(16,16),(2,2)], {"bypass" : 0}), - str(((1, 2, 128, 16, 16), (1, 2, 128, 16, 16), 0, 'zZ', 'zZ', 'zZ', True, False, 'float16', 'float32')) : ([(2,2),(64,64),(16,16),(16,16),(2,2)], {"bypass" : 0}), + str(((1, 16, 49, 16, 16), (1, 49, 49, 16, 16), 0, 'zZ', 'zZ', 'zZ', False, False, 'float16')) : ([(1,1),(2,2),(16,16),(16,16),(49,49)], {"bypass" : 0}), + str(((1, 16, 49, 16, 16), (1, 49, 16, 16, 16), 0, 'zZ', 'zZ', 'zZ', False, False, 'float16')) : ([(2,2),(2,2),(16,16),(16,16),(49,49)], {"bypass" : 0}), + str(((1, 2, 64, 16, 16), (1, 2, 64, 16, 16), 0, 'zZ', 'zZ', 'zZ', True, False, 'float16')) : ([(2,2),(64,64),(16,16),(16,16),(2,2)], {"bypass" : 0}), + str(((1, 2, 128, 16, 16), (1, 2, 128, 16, 16), 0, 'zZ', 'zZ', 'zZ', True, False, 'float16')) : ([(2,2),(64,64),(16,16),(16,16),(2,2)], {"bypass" : 0}), # bert best tile # (16, 1024), (16, 1024) - str(((64, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(16,16),(16,16),(32,32)], {"bypass" : 2}), + str(((64, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(16,16),(16,16),(32,32)], {"bypass" : 2}), # (8192, 4096), (8192, 1024) - str(((256, 512, 16, 16), (64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16),(8,1)], {"bypass" : 0}), + str(((256, 512, 16, 16), (64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(8,8),(16,16),(16,16),(16,16),(8,1)], {"bypass" : 0}), # (8192, 1024), (1024, 4096) - str(((64, 512, 16, 16), (256, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(16,16),(8,4),(16,16),(16,16),(64,8)], {"bypass" : 0}), + str(((64, 512, 16, 16), (256, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16")) : ([(16,16),(8,4),(16,16),(16,16),(64,8)], {"bypass" : 0}), # (16, 16), (16, 1024) - str(((1, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}), + str(((1, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}), # (1216, 1024), (1024, 1024) - str(((64, 76, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")) : ([(4,4),(19,19),(16,16),(16,16),(4,1)], {"bypass" : 0}), + str(((64, 76, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16")) : ([(4,4),(19,19),(16,16),(16,16),(4,1)], {"bypass" : 0}), # (8192, 4096), (4096, 1024) - str(((256, 512, 16 ,16), (64, 256, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}), + str(((256, 512, 16 ,16), (64, 256, 16, 16), 0, "zN", "zN", "zN", False, False, "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}), # (8192, 1024), (4096, 1024) - str(((64, 512, 16, 16), (64, 256, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}), + str(((64, 512, 16, 16), (64, 256, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}), # (8192, 1024), (8192, 4096) - str(((64, 512, 16, 16), (256, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")) : ([[8, 8], [32, 32], [16, 16], [16, 16], [16, 2]], {"bypass": 0}), + str(((64, 512, 16, 16), (256, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([[8, 8], [32, 32], [16, 16], [16, 16], [16, 2]], {"bypass": 0}), # (1216, 1024), (1024, 1024) - str(((64, 76, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(19,19),(16,16),(16,16),(16,1)], {"bypass" : 2}), + str(((64, 76, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(8,8),(19,19),(16,16),(16,16),(16,1)], {"bypass" : 2}), # (8192, 1024), (1024, 1024) - str(((64, 512, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(16,4),(16,8),(16,16),(16,16),(64,16)], {"bypass" : 0}), + str(((64, 512, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16")) : ([(16,4),(16,8),(16,16),(16,16),(64,16)], {"bypass" : 0}), # (1216, 30522), (30522, 1024) - str(((1908, 76, 16, 16), (64, 1908, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(8,8),(19,19),(16,16),(16,16),(6,1)], {"bypass" : 0}), + str(((1908, 76, 16, 16), (64, 1908, 16, 16), 0, "zN", "zN", "zN", False, False, "float16")) : ([(8,8),(19,19),(16,16),(16,16),(6,1)], {"bypass" : 0}), # (1216, 30522), (1216, 1024) - str(((1908, 76, 16, 16), (64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(4,4),(18,18),(16,16),(16,16),(2,2)], {"bypass" : 0}), + str(((1908, 76, 16, 16), (64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(4,4),(18,18),(16,16),(16,16),(2,2)], {"bypass" : 0}), # (1216, 1024), (30522, 1024) - str(((64, 76, 16, 16), (64, 1908, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")) : ([(9,9),(19,19),(16,16),(16,16),(64,1)], {"bypass" : 0}), + str(((64, 76, 16, 16), (64, 1908, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(9,9),(19,19),(16,16),(16,16),(64,1)], {"bypass" : 0}), # (8192, 1024), (8192, 1024) - str(((64, 512, 16, 16), (64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}), + str(((64, 512, 16, 16), (64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}), # (1216, 1024), (1216, 1024) - str(((64, 76, 16, 16), (64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")) : ([(16,16),(8,8),(16,16),(16,16),(4,2)], {"bypass" : 0}), + str(((64, 76, 16, 16), (64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(16,16),(8,8),(16,16),(16,16),(4,2)], {"bypass" : 0}), # (16, 1024), (16, 1024) - str(((64, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(2,2),(16,16),(16,16),(16,16)], {"bypass" : 0}), + str(((64, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(8,8),(2,2),(16,16),(16,16),(16,16)], {"bypass" : 0}), # (16, 1024), (1024, 1024) - str(((64, 1, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")) : ([(8,8),(16,16),(16,16),(32,8)], {"bypass" : 2}), + str(((64, 1, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(8,8),(16,16),(16,16),(32,8)], {"bypass" : 2}), # (16, 16), (16, 1024) - str(((1, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}), + str(((1, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", False, False, "float16")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}), # (8192, 1024), (1024, 1024) - str(((64, 512, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(16,8),(8,8),(16,16),(16,16),(64,8)], {"bypass" : 1}), + str(((64, 512, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(16,8),(8,8),(16,16),(16,16),(64,8)], {"bypass" : 1}), # (8192, 4096), (1024, 4096) - str(((256, 512, 16, 16), (256, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(1,1)], {"bypass" : 0}), + str(((256, 512, 16, 16), (256, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(8,8),(32,32),(16,16),(16,16),(1,1)], {"bypass" : 0}), # (16, 16, 512, 512), (16, 16, 512, 64) - str(((16, 16, 32, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(1,1),(1,1),(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}), - str(((16, 16, 32, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")) : ([(1,1),(1,1),(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}), + str(((16, 16, 32, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", False, False, "float16")) : ([(1,1),(1,1),(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}), + str(((16, 16, 32, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(1,1),(1,1),(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}), # (16, 16, 512, 64), (16, 16, 512, 64) - str(((16, 16, 4, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(1,1),(1,1),(32,4),(32,32),(16,16),(16,16),(4,4)], {"bypass" : 0}), + str(((16, 16, 4, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(1,1),(1,1),(32,4),(32,32),(16,16),(16,16),(4,4)], {"bypass" : 0}), } @@ -86,7 +86,7 @@ def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, bias = 0 if b is None else 1 key = () - key += (tuple(shape_A), tuple(shape_B), bias, left_format, right_format, output_format, adj_x, adj_y, A.dtype, out_dtype) + key += (tuple(shape_A), tuple(shape_B), bias, left_format, right_format, output_format, adj_x, adj_y, A.dtype) hash_key = str(key) if hash_key in matmul_set_dim_map: configs = matmul_set_dim_map[hash_key] @@ -169,7 +169,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out shape_B = y.shape key = () - key += (tuple(shape_A), tuple(shape_B), bias, left_format, right_format, out_format, transpose_x, transpose_y, x.dtype, out_dtype) + key += (tuple(shape_A), tuple(shape_B), bias, left_format, right_format, out_format, transpose_x, transpose_y, x.dtype) hash_key = str(key) # bypass 2 left matrix ddr -> l0 # bypass 1 right matrix ddr -> l0 @@ -199,7 +199,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out elif left_format == "zN": x_indices = indices[:(N - 4)] + (ko,) + indices[(N - 3):(N - 2)] + indices[(N-2):(N-1)] + (ki,) if adj_x: - x_indices = indices[:(N - 4)] + indices[(N - 3):(N - 2)] + (ko,) + (ki,) + indices[(N-2):(N-1)] + x_indices = indices[:(N - 4)] + indices[(N - 3):(N - 2)] + (ko,) + (ki,) + indices[(N-2):(N-1)] if right_format == "nZ": y_indices = indices[:(N - 4)] + (ko, ) + indices[(N - 4):(N - 3)] + indices[(N - 1):] + (ki,)