diff --git a/tests/common/test_op/matmul.py b/tests/common/test_op/matmul.py index 0874d2aa4b814d93513930936b97d317e62280e6..7f4bc144d39ed57212627e75622340911747bb77 100644 --- a/tests/common/test_op/matmul.py +++ b/tests/common/test_op/matmul.py @@ -29,7 +29,49 @@ matmul_set_dim_map = { str(((1, 16, 49, 16, 16), (1, 49, 49, 16, 16), 0, 'zZ', 'zZ', 'zZ', False, False, 'float16', 'float32')) : ([(1,1),(2,2),(16,16),(16,16),(49,49)], {"bypass" : 0}), str(((1, 16, 49, 16, 16), (1, 49, 16, 16, 16), 0, 'zZ', 'zZ', 'zZ', False, False, 'float16', 'float32')) : ([(2,2),(2,2),(16,16),(16,16),(49,49)], {"bypass" : 0}), str(((1, 2, 64, 16, 16), (1, 2, 64, 16, 16), 0, 'zZ', 'zZ', 'zZ', True, False, 'float16', 'float32')) : ([(2,2),(64,64),(16,16),(16,16),(2,2)], {"bypass" : 0}), - str(((1, 2, 128, 16, 16), (1, 2, 128, 16, 16), 0, 'zZ', 'zZ', 'zZ', True, False, 'float16', 'float32')) : ([(2,2),(64,64),(16,16),(16,16),(2,2)], {"bypass" : 0}) + str(((1, 2, 128, 16, 16), (1, 2, 128, 16, 16), 0, 'zZ', 'zZ', 'zZ', True, False, 'float16', 'float32')) : ([(2,2),(64,64),(16,16),(16,16),(2,2)], {"bypass" : 0}), + + # bert best tile + # (16, 1024), (16, 1024) + str(((1, 64, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(16,16),(16,16),(32,32)], {"bypass" : 2}), + # (8192, 4096), (8192, 1024) + str(((1, 256, 512, 16, 16), (1, 64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16),(8,1)], {"bypass" : 0}), + # (8192, 1024), (1024, 4096) + str(((1, 64, 512, 16, 16), (1, 256, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(16,16),(8,4),(16,16),(16,16),(64,8)], {"bypass" : 0}), + # (16, 16), (16, 1024) + str(((1, 1, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}), + # (1216, 1024), (1024, 1024) + str(((1, 64, 76, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")) : ([(4,4),(19,19),(16,16),(16,16),(4,1)], {"bypass" : 0}), + # (8192, 4096), (4096, 1024) + str(((1, 256, 512, 16 ,16), (1, 64, 256, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}), + # (8192, 1024), (4096, 1024) + str(((1, 64, 512, 16, 16), (1, 64, 256, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}), + # (8192, 1024), (8192, 4096) + str(((1, 64, 512, 16, 16), (1, 256, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")) : ([[8, 8], [32, 32], [16, 16], [16, 16], [16, 2]], {"bypass": 0}), + # (1216, 1024), (1024, 1024) + str(((1, 64, 76, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(19,19),(16,16),(16,16),(16,1)], {"bypass" : 2}), + # (8192, 1024), (1024, 1024) + str(((1, 64, 512, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(16,4),(16,8),(16,16),(16,16),(64,16)], {"bypass" : 0}), + # (1216, 30522), (30522, 1024) + str(((1, 1908, 76, 16, 16), (1, 64, 1908, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(8,8),(19,19),(16,16),(16,16),(6,1)], {"bypass" : 0}), + # (1216, 30522), (1216, 1024) + str(((1, 1908, 76, 16, 16), (1, 64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(4,4),(18,18),(16,16),(16,16),(2,2)], {"bypass" : 0}), + # (1216, 1024), (30522, 1024) + str(((1, 64, 76, 16, 16), (1, 64, 1908, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")) : ([(9,9),(19,19),(16,16),(16,16),(64,1)], {"bypass" : 0}), + # (8192, 1024), (8192, 1024) + str(((1, 64, 512, 16, 16), (1, 64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}), + # (1216, 1024), (1216, 1024) + str(((1, 64, 76, 16, 16), (1, 64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")) : ([(16,16),(8,8),(16,16),(16,16),(4,2)], {"bypass" : 0}), + # (16, 1024), (16, 1024) + str(((1, 64, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(2,2),(16,16),(16,16),(16,16)], {"bypass" : 0}), + # (16, 1024), (1024, 1024) + str(((1, 64, 1, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")) : ([(8,8),(16,16),(16,16),(32,8)], {"bypass" : 2}), + # (16, 16), (16, 1024) + str(((1, 1, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}), + # (8192, 1024), (1024, 1024) + str(((1, 64, 512, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(16,8),(8,8),(16,16),(16,16),(64,8)], {"bypass" : 1}), + # (8192, 4096), (1024, 4096) + str(((1, 256, 512, 16, 16), (1, 256, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(16,16),(16,16),(16,16),(128,8)], {"bypass" : 1}), } @@ -38,7 +80,7 @@ def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, shape_B = B.shape bias = 0 if b is None else 1 key = () - + key += (tuple(shape_A), tuple(shape_B), bias, left_format, right_format, output_format, adj_x, adj_y, A.dtype, out_dtype) hash_key = str(key) if hash_key in matmul_set_dim_map: @@ -121,7 +163,7 @@ def matmul4D_compute(x, y, bias_value, out_dtype, left_format, right_format, out shape_A = x.shape shape_B = y.shape key = () - + key += (tuple(shape_A), tuple(shape_B), bias, left_format, right_format, out_format, transpose_x, transpose_y, x.dtype, out_dtype) hash_key = str(key) # bypass 2 left matrix ddr -> l0 @@ -271,7 +313,7 @@ def matmul(x, y, b, out_dtype, left_format="zZ", right_format="nZ", out_format=" left_format: str. Data format of left matrix. Supported data format list ["zZ", "nZ", "zN"]. right_format: str. Data format of right matrix. Supported data format list ["zZ", "nZ", "zN"]. out_format: str. Data format of output tensor. Supported data format list ["zZ", "nZ", "zN"]. - transpose_x: Boolean. Specifies whether x is transposed or not. + transpose_x: Boolean. Specifies whether x is transposed or not. transpose_y: Boolean. Specifies whether y is transposed or not. has_bias: Boolean. Specifies whether bias tensor exists or not. attrs: Dict. Used in matmul computation. diff --git a/tests/operators/cube/test_matmul_001.py b/tests/operators/cube/test_matmul_001.py index d4118d41afcfd6473c3ba94f6f9f2a7381ed5bdb..9b2b64450bfb917f3206ade2223c505c766b26a3 100644 --- a/tests/operators/cube/test_matmul_001.py +++ b/tests/operators/cube/test_matmul_001.py @@ -32,6 +32,29 @@ class TestCase(TestBase): self.testarg = [ # caseflag,opfuncname,testRunArgs, dimArgs # shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, out_dtype, kernel_name, attrs + + # bert shape + ("matmul_run_bert_00", "matmul_run", ((16, 1024), (16, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_01", "matmul_run", ((8192, 4096), (8192, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")), + ("matmul_run_bert_02", "matmul_run", ((8192, 1024), (1024, 4096), 0, "zN", "zN", "zN", False, False, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_03", "matmul_run", ((16, 16), (16, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")), + ("matmul_run_bert_04", "matmul_run", ((1216, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float32", "matmul_cce")), + ("matmul_run_bert_05", "matmul_run", ((8192, 4096), (4096, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_06", "matmul_run", ((8192, 1024), (4096, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_07", "matmul_run", ((8192, 1024), (8192, 4096), 0, "zN", "zN", "zN", True, False, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_08", "matmul_run", ((1216, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_09", "matmul_run", ((8192, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_10", "matmul_run", ((1216, 30522), (30522, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_11", "matmul_run", ((1216, 30522), (1216, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")), + ("matmul_run_bert_12", "matmul_run", ((1216, 1024), (30522, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float32", "matmul_cce")), + ("matmul_run_bert_13", "matmul_run", ((8192, 1024), (8192, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")), + ("matmul_run_bert_14", "matmul_run", ((1216, 1024), (1216, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_15", "matmul_run", ((16, 1024), (16, 1024), 0, "zN", "zN", "zN", True, False, "float16", "float32", "matmul_cce")), + ("matmul_run_bert_16", "matmul_run", ((16, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float32", "matmul_cce")), + ("matmul_run_bert_17", "matmul_run", ((16, 16), (16, 1024), 0, "zN", "zN", "zN", False, False, "float16", "float32", "matmul_cce")), + ("matmul_run_bert_18", "matmul_run", ((8192, 1024), (1024, 1024), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")), + ("matmul_run_bert_19", "matmul_run", ((8192, 4096), (1024, 4096), 0, "zN", "zN", "zN", False, True, "float16", "float16", "matmul_cce")), + # matmul_cast ("matmul_run1", "matmul_run", ((64, 1024), (16, 1024), 0, "zZ", "nZ", "zN", False, True, "float16", "float32", "matmul_cast_cce")),