diff --git a/python/akg/ops/nn/matmul.py b/python/akg/ops/nn/matmul.py index 7094ce7236750f95e552a49cc0baae2d60c41737..3198f44c5159df971788268f7261f5e3aecb47d3 100644 --- a/python/akg/ops/nn/matmul.py +++ b/python/akg/ops/nn/matmul.py @@ -77,6 +77,11 @@ matmul_set_dim_map = { str(((16, 16, 32, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(1,1),(1,1),(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}), # (16, 16, 512, 64), (16, 16, 512, 64) str(((16, 16, 4, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(1,1),(1,1),(32,4),(32,32),(16,16),(16,16),(4,4)], {"bypass" : 0}), + # (24, 16, 512, 512), (24, 16, 512, 64) + str(((24, 16, 32, 32, 16, 16), (24, 16, 4, 32, 16, 16), 0, 'zN', 'zN', 'zN', False, False, 'float16')) : ([(1,1),(1,1),(4,4),(8,1),(16,16),(16,16),(32,32)], {"bypass" : 0}), + str(((24, 16, 32, 32, 16, 16), (24, 16, 4, 32, 16, 16), 0, 'zN', 'zN', 'zN', True, False, 'float16')) : ([(1,1),(1,1),(4,4),(32,32),(16,16),(16,16),(8,1)], {"bypass" : 0}), + # (24, 16, 512, 64), (24, 16, 512, 64) + str(((24, 16, 4, 32, 16, 16), (24, 16, 4, 32, 16, 16), 0, 'zN', 'zN', 'zN', False, True, 'float16')) : ([(1,1),(1,1),(32,16),(32,16),(16,16),(16,16),(4,2)], {"bypass" : 0}), } diff --git a/tests/operators/cube/test_matmul_001.py b/tests/operators/cube/test_matmul_001.py index c9567740c8dfe2e67fa003a893c4527c4cfc5894..a85495b34c5dca456ed2115d04c95ccb57d2c443 100644 --- a/tests/operators/cube/test_matmul_001.py +++ b/tests/operators/cube/test_matmul_001.py @@ -43,6 +43,16 @@ class TestCase(TestBase): # [16, 16, 4, 32, 16, 16] * [16, 16, 4, 32, 16, 16] -> [16, 16, 32, 32, 16, 16] ("batchmatmul_run_bert_02", "matmul_run", ((16, 16, 512, 64), (16, 16, 512, 64), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "batchmatmul_cce")), + # [24, 16, 32, 32, 16, 16] * [24, 16, 4, 32, 16, 16] -> [24, 16, 4, 32, 16, 16] + ("batchmatmul_run_bert_00", "matmul_run", ((24, 16, 512, 512), (24, 16, 512, 64), 0, + "zN", "zN", "zN", False, False, "float16", None, "float16", "batchmatmul_cce")), + # [24, 16, 32, 32, 16, 16] * [24, 16, 4, 32, 16, 16] -> [24, 16, 4, 32, 16, 16] + ("batchmatmul_run_bert_01", "matmul_run", ((24, 16, 512, 512), (24, 16, 512, 64), 0, + "zN", "zN", "zN", True, False, "float16", None, "float16", "batchmatmul_cce")), + # [24, 16, 4, 32, 16, 16] * [24, 16, 4, 32, 16, 16] -> [24, 16, 32, 32, 16, 16] + ("batchmatmul_run_bert_02", "matmul_run", ((24, 16, 512, 64), (24, 16, 512, 64), 0, + "zN", "zN", "zN", False, True, "float16", None, "float16", "batchmatmul_cce")), + # bert shape ("matmul_run_bert_00", "matmul_run", ((16, 1024), (16, 1024), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "matmul_cce")), ("matmul_run_bert_01", "matmul_run", ((8192, 4096), (8192, 1024), 0, "zN", "zN", "zN", True, False, "float16", None, "float32", "matmul_cce")),