提交 2ad9c1bd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!117 matmul add bert testcases and tunning setdim

Merge pull request !117 from looop5/matmul_bert
......@@ -77,6 +77,11 @@ matmul_set_dim_map = {
str(((16, 16, 32, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", True, False, "float16")) : ([(1,1),(1,1),(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}),
# (16, 16, 512, 64), (16, 16, 512, 64)
str(((16, 16, 4, 32, 16, 16), (16, 16, 4, 32, 16, 16), 0, "zN", "zN", "zN", False, True, "float16")) : ([(1,1),(1,1),(32,4),(32,32),(16,16),(16,16),(4,4)], {"bypass" : 0}),
# (24, 16, 512, 512), (24, 16, 512, 64)
str(((24, 16, 32, 32, 16, 16), (24, 16, 4, 32, 16, 16), 0, 'zN', 'zN', 'zN', False, False, 'float16')) : ([(1,1),(1,1),(4,4),(8,1),(16,16),(16,16),(32,32)], {"bypass" : 0}),
str(((24, 16, 32, 32, 16, 16), (24, 16, 4, 32, 16, 16), 0, 'zN', 'zN', 'zN', True, False, 'float16')) : ([(1,1),(1,1),(4,4),(32,32),(16,16),(16,16),(8,1)], {"bypass" : 0}),
# (24, 16, 512, 64), (24, 16, 512, 64)
str(((24, 16, 4, 32, 16, 16), (24, 16, 4, 32, 16, 16), 0, 'zN', 'zN', 'zN', False, True, 'float16')) : ([(1,1),(1,1),(32,16),(32,16),(16,16),(16,16),(4,2)], {"bypass" : 0}),
}
......
......@@ -43,6 +43,16 @@ class TestCase(TestBase):
# [16, 16, 4, 32, 16, 16] * [16, 16, 4, 32, 16, 16] -> [16, 16, 32, 32, 16, 16]
("batchmatmul_run_bert_02", "matmul_run", ((16, 16, 512, 64), (16, 16, 512, 64), 0,
"zN", "zN", "zN", False, True, "float16", None, "float16", "batchmatmul_cce")),
# [24, 16, 32, 32, 16, 16] * [24, 16, 4, 32, 16, 16] -> [24, 16, 4, 32, 16, 16]
("batchmatmul_run_bert_00", "matmul_run", ((24, 16, 512, 512), (24, 16, 512, 64), 0,
"zN", "zN", "zN", False, False, "float16", None, "float16", "batchmatmul_cce")),
# [24, 16, 32, 32, 16, 16] * [24, 16, 4, 32, 16, 16] -> [24, 16, 4, 32, 16, 16]
("batchmatmul_run_bert_01", "matmul_run", ((24, 16, 512, 512), (24, 16, 512, 64), 0,
"zN", "zN", "zN", True, False, "float16", None, "float16", "batchmatmul_cce")),
# [24, 16, 4, 32, 16, 16] * [24, 16, 4, 32, 16, 16] -> [24, 16, 32, 32, 16, 16]
("batchmatmul_run_bert_02", "matmul_run", ((24, 16, 512, 64), (24, 16, 512, 64), 0,
"zN", "zN", "zN", False, True, "float16", None, "float16", "batchmatmul_cce")),
# bert shape
("matmul_run_bert_00", "matmul_run", ((16, 1024), (16, 1024), 0, "zN", "zN", "zN", False, True, "float16", None, "float16", "matmul_cce")),
("matmul_run_bert_01", "matmul_run", ((8192, 4096), (8192, 1024), 0, "zN", "zN", "zN", True, False, "float16", None, "float32", "matmul_cce")),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册