提交 2fb61524 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!12 matmul op input tensor fixed

Merge pull request !12 from hiranoaya/format
......@@ -33,51 +33,51 @@ matmul_set_dim_map = {
# bert best tile
# (16, 1024), (16, 1024)
str(((1, 64, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(16,16),(16,16),(32,32)], {"bypass" : 2}),
str(((64, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(16,16),(16,16),(32,32)], {"bypass" : 2}),
# (8192, 4096), (8192, 1024)
str(((1, 256, 512, 16, 16), (1, 64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16),(8,1)], {"bypass" : 0}),
str(((256, 512, 16, 16), (64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16),(8,1)], {"bypass" : 0}),
# (8192, 1024), (1024, 4096)
str(((1, 64, 512, 16, 16), (1, 256, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(16,16),(8,4),(16,16),(16,16),(64,8)], {"bypass" : 0}),
str(((64, 512, 16, 16), (256, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(16,16),(8,4),(16,16),(16,16),(64,8)], {"bypass" : 0}),
# (16, 16), (16, 1024)
str(((1, 1, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}),
str(((1, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}),
# (1216, 1024), (1024, 1024)
str(((1, 64, 76, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")) : ([(4,4),(19,19),(16,16),(16,16),(4,1)], {"bypass" : 0}),
str(((64, 76, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")) : ([(4,4),(19,19),(16,16),(16,16),(4,1)], {"bypass" : 0}),
# (8192, 4096), (4096, 1024)
str(((1, 256, 512, 16 ,16), (1, 64, 256, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}),
str(((256, 512, 16 ,16), (64, 256, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}),
# (8192, 1024), (4096, 1024)
str(((1, 64, 512, 16, 16), (1, 64, 256, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}),
str(((64, 512, 16, 16), (64, 256, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(2,1)], {"bypass" : 0}),
# (8192, 1024), (8192, 4096)
str(((1, 64, 512, 16, 16), (1, 256, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")) : ([[8, 8], [32, 32], [16, 16], [16, 16], [16, 2]], {"bypass": 0}),
str(((64, 512, 16, 16), (256, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")) : ([[8, 8], [32, 32], [16, 16], [16, 16], [16, 2]], {"bypass": 0}),
# (1216, 1024), (1024, 1024)
str(((1, 64, 76, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(19,19),(16,16),(16,16),(16,1)], {"bypass" : 2}),
str(((64, 76, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(19,19),(16,16),(16,16),(16,1)], {"bypass" : 2}),
# (8192, 1024), (1024, 1024)
str(((1, 64, 512, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(16,4),(16,8),(16,16),(16,16),(64,16)], {"bypass" : 0}),
str(((64, 512, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(16,4),(16,8),(16,16),(16,16),(64,16)], {"bypass" : 0}),
# (1216, 30522), (30522, 1024)
str(((1, 1908, 76, 16, 16), (1, 64, 1908, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(8,8),(19,19),(16,16),(16,16),(6,1)], {"bypass" : 0}),
str(((1908, 76, 16, 16), (64, 1908, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float16")) : ([(8,8),(19,19),(16,16),(16,16),(6,1)], {"bypass" : 0}),
# (1216, 30522), (1216, 1024)
str(((1, 1908, 76, 16, 16), (1, 64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(4,4),(18,18),(16,16),(16,16),(2,2)], {"bypass" : 0}),
str(((1908, 76, 16, 16), (64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(4,4),(18,18),(16,16),(16,16),(2,2)], {"bypass" : 0}),
# (1216, 1024), (30522, 1024)
str(((1, 64, 76, 16, 16), (1, 64, 1908, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")) : ([(9,9),(19,19),(16,16),(16,16),(64,1)], {"bypass" : 0}),
str(((64, 76, 16, 16), (64, 1908, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")) : ([(9,9),(19,19),(16,16),(16,16),(64,1)], {"bypass" : 0}),
# (8192, 1024), (8192, 1024)
str(((1, 64, 512, 16, 16), (1, 64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}),
str(((64, 512, 16, 16), (64, 512, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(4,4),(16,16),(16,16),(16,16),(16,4)], {"bypass" : 0}),
# (1216, 1024), (1216, 1024)
str(((1, 64, 76, 16, 16), (1, 64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")) : ([(16,16),(8,8),(16,16),(16,16),(4,2)], {"bypass" : 0}),
str(((64, 76, 16, 16), (64, 76, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float16")) : ([(16,16),(8,8),(16,16),(16,16),(4,2)], {"bypass" : 0}),
# (16, 1024), (16, 1024)
str(((1, 64, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(2,2),(16,16),(16,16),(16,16)], {"bypass" : 0}),
str(((64, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", True, False, "float16", "float32")) : ([(8,8),(2,2),(16,16),(16,16),(16,16)], {"bypass" : 0}),
# (16, 1024), (1024, 1024)
str(((1, 64, 1, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")) : ([(8,8),(16,16),(16,16),(32,8)], {"bypass" : 2}),
str(((64, 1, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float32")) : ([(8,8),(16,16),(16,16),(32,8)], {"bypass" : 2}),
# (16, 16), (16, 1024)
str(((1, 1, 1, 16, 16), (1, 64, 1, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}),
str(((1, 1, 16, 16), (64, 1, 16, 16), 0, "zN", "zN", "zN", False, False, "float16", "float32")) : ([(8,8),(16,16),(16,16),(16,16)], {"bypass" : 0}),
# (8192, 1024), (1024, 1024)
str(((1, 64, 512, 16, 16), (1, 64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(16,8),(8,8),(16,16),(16,16),(64,8)], {"bypass" : 1}),
str(((64, 512, 16, 16), (64, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(16,8),(8,8),(16,16),(16,16),(64,8)], {"bypass" : 1}),
# (8192, 4096), (1024, 4096)
str(((1, 256, 512, 16, 16), (1, 256, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(16,16),(16,16),(16,16),(128,8)], {"bypass" : 1}),
str(((256, 512, 16, 16), (256, 64, 16, 16), 0, "zN", "zN", "zN", False, True, "float16", "float16")) : ([(8,8),(32,32),(16,16),(16,16),(1,1)], {"bypass" : 0}),
}
def matmul_set_dim(A, B, b, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs):
shape_A = A.shape
shape_B = B.shape
shape_A = A.shape[1:5] if len(A.shape) == 5 else A.shape
shape_B = B.shape[1:5] if len(B.shape) == 5 else B.shape
bias = 0 if b is None else 1
key = ()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册