From 68b3ccfbcf17e3e86bd3c4c670224f67c7206d40 Mon Sep 17 00:00:00 2001 From: looop5 Date: Tue, 25 Aug 2020 11:51:58 +0800 Subject: [PATCH] matmul add bert testcases and tunning setdim --- python/akg/ops/nn/matmul.py | 5 +++++ tests/operators/cube/test_matmul_001.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/python/akg/ops/nn/matmul.py b/python/akg/ops/nn/matmul.py index 7094ce7..3198f44 100644 --- a/python/akg/ops/nn/matmul.py +++ b/python/akg/ops/nn/matmul.py @@ -77,6 +77,11 @@ matmul_set_dim_map = { str(((16, 16, 32, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(1,1),(1,1),(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}), # (16, 16, 512, 64), (16, 16, 512, 64) str(((16, 16, 4, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(1,1),(1,1),(32,4),(32,32),(16,16),(16,16),(4,4)], {"bypass" : 0}), + # (24, 16, 512, 512), (24, 16, 512, 64) + str(((24, 16, 32, 32, 16, 16), (24, 16, 4, 32, 16, 16), 0, 'zN', 'zN', 'zN', False, False, 'float16')) : ([(1,1),(1,1),(4,4),(8,1),(16,16),(16,16),(32,32)], {"bypass" : 0}), + str(((24, 16, 32, 32, 16, 16), (24, 16, 4, 32, 16, 16), 0, 'zN', 'zN', 'zN', True, False, 'float16')) : ([(1,1),(1,1),(4,4),(32,32),(16,16),(16,16),(8,1)], {"bypass" : 0}), + # (24, 16, 512, 64), (24, 16, 512, 64) + str(((24, 16, 4, 32, 16, 16), (24, 16, 4, 32, 16, 16), 0, 'zN', 'zN', 'zN', False, True, 'float16')) : ([(1,1),(1,1),(32,16),(32,16),(16,16),(16,16),(4,2)], {"bypass" : 0}), } diff --git a/tests/operators/cube/test_matmul_001.py b/tests/operators/cube/test_matmul_001.py index c956774..a85495b 100644 --- a/tests/operators/cube/test_matmul_001.py +++ b/tests/operators/cube/test_matmul_001.py @@ -43,6 +43,16 @@ class TestCase(TestBase): # [16, 16, 4, 32, 16, 16] * [16, 16, 4, 32, 16, 16] -> [16, 16, 32, 32, 16, 16] ("batchmatmul_run_bert_02", "matmul_run", ((16, 16, 512, 64), (16, 16, 512, 64), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "batchmatmul_cce")), + # [24, 16, 32, 32, 16, 16] * [24, 16, 4, 32, 16, 16] -> [24, 16, 4, 32, 16, 16] + ("batchmatmul_run_bert_00", "matmul_run", ((24, 16, 512, 512), (24, 16, 512, 64), 0, + "zN", "zN", "zN", False, False, "float16", None, "float16", "batchmatmul_cce")), + # [24, 16, 32, 32, 16, 16] * [24, 16, 4, 32, 16, 16] -> [24, 16, 4, 32, 16, 16] + ("batchmatmul_run_bert_01", "matmul_run", ((24, 16, 512, 512), (24, 16, 512, 64), 0, + "zN", "zN", "zN", True, False, "float16", None, "float16", "batchmatmul_cce")), + # [24, 16, 4, 32, 16, 16] * [24, 16, 4, 32, 16, 16] -> [24, 16, 32, 32, 16, 16] + ("batchmatmul_run_bert_02", "matmul_run", ((24, 16, 512, 64), (24, 16, 512, 64), 0, + "zN", "zN", "zN", False, True, "float16", None, "float16", "batchmatmul_cce")), + # bert shape ("matmul_run_bert_00", "matmul_run", ((16, 1024), (16, 1024), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "matmul_cce")), ("matmul_run_bert_01", "matmul_run", ((8192, 4096), (8192, 1024), 0, "zN", "zN", "zN", True, False, "float16", None, "float32", "matmul_cce")), -- GitLab