Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4a9895b1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4a9895b1
编写于
9月 05, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
9月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] dist_matmul trans_x or trans_y (#45678)
* dist_matmul trans * update unittest * update cmakelist
上级
a12c806f
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
478 addition
and
47 deletion
+478
-47
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...paddle/distributed/auto_parallel/operators/dist_matmul.py
+103
-47
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py
...e/fluid/tests/unittests/auto_parallel/test_dist_matmul.py
+374
-0
未找到文件。
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
4a9895b1
...
@@ -44,6 +44,15 @@ from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
...
@@ -44,6 +44,15 @@ from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
from
paddle.distributed.auto_parallel.cost.comm_op_cost
import
AllreduceSumOpCost
,
IdentityOpCost
from
paddle.distributed.auto_parallel.cost.comm_op_cost
import
AllreduceSumOpCost
,
IdentityOpCost
def
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
):
if
trans_x
:
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
]
if
trans_y
:
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
]
def
copy_op_with_new_input_output
(
ctx
,
block
,
src_op
,
**
kwargs
):
def
copy_op_with_new_input_output
(
ctx
,
block
,
src_op
,
**
kwargs
):
dist_op_desc
=
block
.
append_op
(
type
=
'nop'
).
desc
dist_op_desc
=
block
.
append_op
(
type
=
'nop'
).
desc
dist_op_desc
.
copy_from
(
src_op
.
desc
)
dist_op_desc
.
copy_from
(
src_op
.
desc
)
...
@@ -90,6 +99,8 @@ def _update_dims_mapping_for_matmul(dist_op):
...
@@ -90,6 +99,8 @@ def _update_dims_mapping_for_matmul(dist_op):
y_dims_mapping
.
insert
(
1
,
-
1
)
y_dims_mapping
.
insert
(
1
,
-
1
)
out_dims_mapping
.
insert
(
out_dims_mapping_len
,
0
)
out_dims_mapping
.
insert
(
out_dims_mapping_len
,
0
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
new_x_dims_mapping_len
=
len
(
x_dims_mapping
)
new_x_dims_mapping_len
=
len
(
x_dims_mapping
)
new_y_dims_mapping_len
=
len
(
y_dims_mapping
)
new_y_dims_mapping_len
=
len
(
y_dims_mapping
)
new_out_dims_mapping_len
=
len
(
out_dims_mapping
)
new_out_dims_mapping_len
=
len
(
out_dims_mapping
)
...
@@ -117,6 +128,8 @@ def _update_dims_mapping_for_matmul(dist_op):
...
@@ -117,6 +128,8 @@ def _update_dims_mapping_for_matmul(dist_op):
broadcast_out_dims_mapping
broadcast_out_dims_mapping
])
])
if
compatible_dims_mapping
is
None
:
if
compatible_dims_mapping
is
None
:
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
return
False
return
False
for
i
in
range
(
new_x_dims_mapping_len
-
2
):
for
i
in
range
(
new_x_dims_mapping_len
-
2
):
...
@@ -136,13 +149,6 @@ def _update_dims_mapping_for_matmul(dist_op):
...
@@ -136,13 +149,6 @@ def _update_dims_mapping_for_matmul(dist_op):
out_dims_mapping
[
i
]
=
compatible_dims_mapping
[
i
]
out_dims_mapping
[
i
]
=
compatible_dims_mapping
[
i
]
changed
=
True
changed
=
True
if
trans_x
:
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
]
if
trans_y
:
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
]
# The following which uses negative index can be work
# The following which uses negative index can be work
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
dim_changed
=
compute_compatible_and_update_dim_mapping
(
dim_changed
=
compute_compatible_and_update_dim_mapping
(
...
@@ -160,12 +166,7 @@ def _update_dims_mapping_for_matmul(dist_op):
...
@@ -160,12 +166,7 @@ def _update_dims_mapping_for_matmul(dist_op):
if
dim_changed
:
if
dim_changed
:
changed
=
True
changed
=
True
if
trans_x
:
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
]
if
trans_y
:
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
]
# Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor
# Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor
if
x_dims_mapping_len
==
1
:
if
x_dims_mapping_len
==
1
:
...
@@ -188,6 +189,15 @@ def _is_auto_compatible_for_matmul(dist_op):
...
@@ -188,6 +189,15 @@ def _is_auto_compatible_for_matmul(dist_op):
x_name
=
op_desc
.
input
(
'X'
)[
0
]
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
out_name
=
op_desc
.
output
(
'Out'
)[
0
]
out_name
=
op_desc
.
output
(
'Out'
)[
0
]
trans_x
=
None
trans_y
=
None
if
op_desc
.
type
()
==
"matmul_v2"
:
trans_x
=
op_desc
.
attr
(
'trans_x'
)
trans_y
=
op_desc
.
attr
(
'trans_y'
)
elif
op_desc
.
type
()
==
"matmul"
:
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
# Deep copy these dims_mappings for keeping them unchanged.
# Deep copy these dims_mappings for keeping them unchanged.
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
...
@@ -203,17 +213,7 @@ def _is_auto_compatible_for_matmul(dist_op):
...
@@ -203,17 +213,7 @@ def _is_auto_compatible_for_matmul(dist_op):
if
y_dims_mapping_len
==
1
:
if
y_dims_mapping_len
==
1
:
y_dims_mapping
.
insert
(
1
,
-
1
)
y_dims_mapping
.
insert
(
1
,
-
1
)
# NOTE: Partition is not supported if matmul op has trans.
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
op_desc
.
type
()
==
"matmul_v2"
:
if
op_desc
.
attr
(
'trans_x'
)
or
op_desc
.
attr
(
'trans_y'
):
if
x_dims_mapping
[
-
2
:]
!=
[
-
1
,
-
1
]
or
y_dims_mapping
[
-
2
:]
!=
[
-
1
,
-
1
]:
return
False
elif
op_desc
.
type
()
==
"matmul"
:
if
op_desc
.
attr
(
'transpose_X'
)
or
op_desc
.
attr
(
'transpose_Y'
):
if
x_dims_mapping
[
-
2
:]
!=
[
-
1
,
-
1
]
or
y_dims_mapping
[
-
2
:]
!=
[
-
1
,
-
1
]:
return
False
# Deal with dim > 2 and take care of broadcasting
# Deal with dim > 2 and take care of broadcasting
if
out_dims_mapping_len
>
2
:
if
out_dims_mapping_len
>
2
:
...
@@ -304,9 +304,23 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
...
@@ -304,9 +304,23 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
),
"left operand(X) [{}] of dist matmul should not be parameter"
.
format
(
),
"left operand(X) [{}] of dist matmul should not be parameter"
.
format
(
X_var
.
name
)
X_var
.
name
)
X_var_dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
X_var
.
name
)
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var
.
name
)
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var
.
name
)
process_mesh_shape
=
dist_attr
.
process_mesh
.
topology
process_mesh_shape
=
dist_attr
.
process_mesh
.
topology
process_mesh_group
=
dist_attr
.
process_mesh
.
processes
process_mesh_group
=
dist_attr
.
process_mesh
.
processes
trans_x
=
None
trans_y
=
None
if
backward_op
.
desc
.
type
()
==
"matmul_v2_grad"
:
trans_x
=
backward_op
.
desc
.
attr
(
'trans_x'
)
trans_y
=
backward_op
.
desc
.
attr
(
'trans_y'
)
elif
backward_op
.
desc
.
type
()
==
"matmul_grad"
:
trans_x
=
backward_op
.
desc
.
attr
(
'transpose_X'
)
trans_y
=
backward_op
.
desc
.
attr
(
'transpose_Y'
)
if
trans_y
:
trans_x_y_dims_mapping
(
False
,
True
,
None
,
Y_var_dim_mapping
)
# assert len(
# assert len(
# Y_var_dim_mapping
# Y_var_dim_mapping
# ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format(
# ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format(
...
@@ -431,9 +445,17 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
...
@@ -431,9 +445,17 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
if
is_parameter_related
(
Y_var
.
name
,
main_block
):
if
is_parameter_related
(
Y_var
.
name
,
main_block
):
out_grad_names
=
[
kwargs
[
'Y@GRAD'
][
0
]]
out_grad_names
=
[
kwargs
[
'Y@GRAD'
][
0
]]
if
trans_x
:
trans_x_y_dims_mapping
(
True
,
False
,
X_var_dims_mapping
,
None
)
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
out_grad_names
,
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
out_grad_names
,
rank_id
)
rank_id
)
if
trans_x
:
trans_x_y_dims_mapping
(
True
,
False
,
X_var_dims_mapping
,
None
)
if
trans_y
:
trans_x_y_dims_mapping
(
False
,
True
,
None
,
Y_var_dim_mapping
)
def
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
):
def
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
):
...
@@ -583,8 +605,13 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -583,8 +605,13 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
op_dist_attr
=
dist_op
.
dist_attr
op_dist_attr
=
dist_op
.
dist_attr
x_name
=
op_desc
.
input
(
'X'
)[
0
]
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
x_dims_mapping
=
copy
.
deepcopy
(
y_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
return
False
return
False
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
...
@@ -660,10 +687,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -660,10 +687,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
"transpose_X"
)
trans_y
=
src_op
.
attr
(
"transpose_Y"
)
# TODO infer logic comm presentation
# TODO infer logic comm presentation
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
Weight_var
.
name
)[
-
1
]
if
trans_y
:
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
assert
matmul_col_dim_mapping
>=
0
,
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
assert
matmul_col_dim_mapping
>=
0
,
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_col_dim_mapping
)
matmul_col_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
...
@@ -723,10 +755,10 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -723,10 +755,10 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
attrs
=
{
attrs
=
{
'transpose_X'
:
False
,
'transpose_X'
:
trans_x
,
'transpose_Y'
:
False
,
'transpose_Y'
:
trans_y
,
'alpha'
:
1
,
'alpha'
:
1
,
OP_ROLE_KEY
:
src_op
(
'op_role'
)
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
}
}
inputs
=
{
'X'
:
[
intermediate_var_0
],
'Y'
:
[
Weight_var
]}
inputs
=
{
'X'
:
[
intermediate_var_0
],
'Y'
:
[
Weight_var
]}
matmul_op
=
main_block
.
append_op
(
type
=
'matmul'
,
matmul_op
=
main_block
.
append_op
(
type
=
'matmul'
,
...
@@ -902,8 +934,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -902,8 +934,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
op_dist_attr
=
dist_op
.
dist_attr
op_dist_attr
=
dist_op
.
dist_attr
x_name
=
op_desc
.
input
(
'X'
)[
0
]
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
x_dims_mapping
=
copy
.
deepcopy
(
y_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_replicate
(
x_dims_mapping
[
-
1
]):
if
is_dim_replicate
(
x_dims_mapping
[
-
1
]):
return
False
return
False
if
is_dim_replicate
(
y_dims_mapping
[
-
2
])
or
is_dim_shard
(
if
is_dim_replicate
(
y_dims_mapping
[
-
2
])
or
is_dim_shard
(
...
@@ -932,10 +969,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -932,10 +969,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
(
not
self
.
is_output_compatible
(
dist_op
)):
(
not
self
.
is_output_compatible
(
dist_op
)):
return
False
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
return
False
return
False
return
True
return
True
def
update_dims_mapping
(
self
,
dist_op
):
def
update_dims_mapping
(
self
,
dist_op
):
...
@@ -983,10 +1018,15 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -983,10 +1018,15 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
'transpose_X'
)
trans_y
=
src_op
.
attr
(
'transpose_Y'
)
# TODO infer logic comm presentation
# TODO infer logic comm presentation
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
Weight_var
.
name
)[
-
2
]
if
trans_y
:
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
assert
matmul_row_dim_mapping
>=
0
,
"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
assert
matmul_row_dim_mapping
>=
0
,
"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_row_dim_mapping
)
matmul_row_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
...
@@ -1002,8 +1042,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -1002,8 +1042,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
'linear'
)
attrs
=
{
attrs
=
{
'transpose_X'
:
False
,
'transpose_X'
:
trans_x
,
'transpose_Y'
:
False
,
'transpose_Y'
:
trans_y
,
'alpha'
:
1
,
'alpha'
:
1
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
}
}
...
@@ -1354,8 +1394,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1354,8 +1394,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
op_dist_attr
=
dist_op
.
dist_attr
op_dist_attr
=
dist_op
.
dist_attr
x_name
=
op_desc
.
input
(
'X'
)[
0
]
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
x_dims_mapping
=
copy
.
deepcopy
(
y_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
trans_x
=
op_desc
.
attr
(
'trans_x'
)
trans_y
=
op_desc
.
attr
(
'trans_y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
return
False
return
False
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
...
@@ -1382,10 +1427,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1382,10 +1427,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
(
not
self
.
is_output_compatible
(
dist_op
)):
(
not
self
.
is_output_compatible
(
dist_op
)):
return
False
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
return
False
return
False
return
True
return
True
def
update_dims_mapping
(
self
,
dist_op
):
def
update_dims_mapping
(
self
,
dist_op
):
...
@@ -1433,10 +1476,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1433,10 +1476,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
'trans_x'
)
trans_y
=
src_op
.
attr
(
'trans_y'
)
# TODO infer logic comm presentation
# TODO infer logic comm presentation
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
Weight_var
.
name
)[
-
1
]
if
trans_y
:
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
assert
matmul_col_dim_mapping
>=
0
,
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
assert
matmul_col_dim_mapping
>=
0
,
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_col_dim_mapping
)
matmul_col_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
...
@@ -1495,8 +1543,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1495,8 +1543,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
attrs
=
{
attrs
=
{
'trans_x'
:
False
,
'trans_x'
:
trans_x
,
'trans_y'
:
False
,
'trans_y'
:
trans_y
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
}
}
inputs
=
{
'X'
:
[
intermediate_var_0
],
'Y'
:
[
Weight_var
]}
inputs
=
{
'X'
:
[
intermediate_var_0
],
'Y'
:
[
Weight_var
]}
...
@@ -1670,8 +1718,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
...
@@ -1670,8 +1718,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
op_dist_attr
=
dist_op
.
dist_attr
op_dist_attr
=
dist_op
.
dist_attr
x_name
=
op_desc
.
input
(
'X'
)[
0
]
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
x_dims_mapping
=
copy
.
deepcopy
(
y_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
trans_x
=
op_desc
.
attr
(
'trans_x'
)
trans_y
=
op_desc
.
attr
(
'trans_y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_replicate
(
x_dims_mapping
[
-
1
]):
if
is_dim_replicate
(
x_dims_mapping
[
-
1
]):
return
False
return
False
if
is_dim_replicate
(
y_dims_mapping
[
-
2
])
or
is_dim_shard
(
if
is_dim_replicate
(
y_dims_mapping
[
-
2
])
or
is_dim_shard
(
...
@@ -1700,10 +1753,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
...
@@ -1700,10 +1753,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
(
not
self
.
is_output_compatible
(
dist_op
)):
(
not
self
.
is_output_compatible
(
dist_op
)):
return
False
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
return
False
return
False
return
True
return
True
def
update_dims_mapping
(
self
,
dist_op
):
def
update_dims_mapping
(
self
,
dist_op
):
...
@@ -1751,10 +1802,15 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
...
@@ -1751,10 +1802,15 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
'trans_x'
)
trans_y
=
src_op
.
attr
(
'trans_y'
)
# TODO infer logic comm presentation
# TODO infer logic comm presentation
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
Weight_var
.
name
)[
-
2
]
if
trans_y
:
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
assert
matmul_row_dim_mapping
>=
0
,
"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
assert
matmul_row_dim_mapping
>=
0
,
"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_row_dim_mapping
)
matmul_row_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
...
@@ -1770,8 +1826,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
...
@@ -1770,8 +1826,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
'linear'
)
attrs
=
{
attrs
=
{
'trans_x'
:
False
,
'trans_x'
:
trans_x
,
'trans_y'
:
False
,
'trans_y'
:
trans_y
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
}
}
inputs
=
{
'X'
:
X_var
,
'Y'
:
Weight_var
}
inputs
=
{
'X'
:
X_var
,
'Y'
:
Weight_var
}
...
...
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
4a9895b1
...
@@ -71,4 +71,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
...
@@ -71,4 +71,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_dist_attr_v2 MODULES test_dist_attr_v2
)
py_test_modules
(
test_dist_attr_v2 MODULES test_dist_attr_v2
)
py_test_modules
(
test_lr_grad_clip MODULES test_lr_grad_clip
)
py_test_modules
(
test_lr_grad_clip MODULES test_lr_grad_clip
)
py_test_modules
(
test_quantization MODULES test_quantization
)
py_test_modules
(
test_quantization MODULES test_quantization
)
py_test_modules
(
test_dist_matmul MODULES test_dist_matmul
)
endif
()
endif
()
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py
0 → 100644
浏览文件 @
4a9895b1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
paddle.distributed.auto_parallel
as
auto
from
paddle.fluid
import
program_guard
from
paddle.fluid.backward
import
append_backward
from
paddle.distributed.auto_parallel.utils
import
print_program_with_dist_attr
paddle
.
enable_static
()
mesh
=
[[
0
,
1
],
[
2
,
3
]]
def
init_x_row
(
trans_x
):
if
trans_x
:
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
10
,
6
,
8
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
x
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
0
,
1
,
-
1
]
})
return
x
else
:
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
10
,
8
,
6
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
x
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
0
,
-
1
,
1
]
})
return
x
def
init_x_col
(
trans_x
):
if
trans_x
:
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
6
,
8
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
x
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
-
1
,
0
]
})
return
x
else
:
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
8
,
6
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
x
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
0
,
-
1
]
})
return
x
def
init_y_row
(
trans_y
):
if
trans_y
:
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
[
4
,
6
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
y
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
-
1
,
1
]
})
return
y
else
:
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
[
6
,
4
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
y
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
1
,
-
1
]
})
return
y
def
init_y_col
(
trans_y
):
if
trans_y
:
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
[
4
,
6
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
y
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
1
,
-
1
]
})
return
y
else
:
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
[
6
,
4
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
y
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
-
1
,
1
]
})
return
y
def
matmul_dp2mp2
(
init_x
,
init_y
,
trans_x
,
trans_y
):
main_program
=
paddle
.
fluid
.
Program
()
start_program
=
paddle
.
fluid
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
start_program
):
x
=
init_x
(
trans_x
)
y
=
init_y
(
trans_y
)
x
.
stop_gradient
=
False
y
.
stop_gradient
=
False
out
=
paddle
.
fluid
.
layers
.
matmul
(
x
,
y
,
transpose_x
=
trans_x
,
transpose_y
=
trans_y
)
loss
=
paddle
.
mean
(
out
)
return
main_program
,
start_program
,
loss
def
matmulv2_dp2mp2
(
init_x
,
init_y
,
trans_x
,
trans_y
):
main_program
=
paddle
.
fluid
.
Program
()
start_program
=
paddle
.
fluid
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
start_program
):
x
=
init_x
(
trans_x
)
y
=
init_y
(
trans_y
)
x
.
stop_gradient
=
False
y
.
stop_gradient
=
False
out
=
paddle
.
matmul
(
x
,
y
,
transpose_x
=
trans_x
,
transpose_y
=
trans_y
)
loss
=
paddle
.
mean
(
out
)
return
main_program
,
start_program
,
loss
def
parallelizer
(
program_func
,
*
args
,
**
kwargs
):
from
paddle.distributed.auto_parallel.completion
import
Completer
from
paddle.distributed.auto_parallel.partitioner
import
Partitioner
from
paddle.distributed.auto_parallel.dist_context
import
DistributedContext
main_program
,
start_program
,
loss
=
program_func
(
*
args
,
**
kwargs
)
dist_context
=
DistributedContext
()
completer
=
Completer
(
dist_context
)
completer
.
complete_forward_annotation
(
main_program
)
dist_context
.
block_state
.
parse_forward_blocks
(
main_program
)
with
program_guard
(
main_program
,
start_program
):
append_backward
(
loss
,
distop_context
=
dist_context
.
dist_op_context
)
completer
.
complete_backward_annotation
(
main_program
)
dist_context
.
block_state
.
parse_backward_blocks
(
main_program
)
partitioner
=
Partitioner
(
dist_context
,
0
)
dist_main_prog
,
_
,
_
=
partitioner
.
partition
(
main_program
,
start_program
,
[])
return
dist_main_prog
,
dist_context
class
TestDistMatmul
(
unittest
.
TestCase
):
def
check_col_program
(
self
,
main_program
,
dist_ctx
):
# [0, -1] * [-1, 1] --> [0, 1]
ref_ops
=
[
"c_identity"
,
"matmul"
,
"reduce_mean"
,
"fill_constant"
,
"reduce_mean_grad"
,
"matmul_grad"
]
ops
=
[]
block
=
main_program
.
global_block
()
for
op
in
block
.
ops
:
ops
.
append
(
op
.
type
)
if
op
.
type
==
"matmul"
:
out_name
=
op
.
output
(
'Out'
)[
0
]
out_var
=
block
.
vars
[
out_name
]
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
0
assert
op_dist_attr
.
impl_type
==
"matmul"
out_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
assert
out_dims_mapping
==
[
0
,
1
]
tensor_dist_attr
=
dist_ctx
.
get_tensor_dist_attr_for_program
(
out_var
)
assert
tensor_dist_attr
.
dims_mapping
==
[
0
,
1
]
if
op
.
type
==
"matmul_grad"
:
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
0
assert
op_dist_attr
.
impl_type
==
"matmul"
assert
ops
==
ref_ops
def
check_row_program
(
self
,
main_program
,
dist_ctx
):
# [0, -1, 1] * [1, -1] --> [0, -1, -1]
ref_ops
=
[
"matmul"
,
"c_allreduce_sum"
,
"reduce_mean"
,
"fill_constant"
,
"reduce_mean_grad"
,
"matmul_grad"
]
ops
=
[]
block
=
main_program
.
global_block
()
for
op
in
block
.
ops
:
ops
.
append
(
op
.
type
)
if
op
.
type
==
"matmul"
:
out_name
=
op
.
output
(
'Out'
)[
0
]
out_var
=
block
.
vars
[
out_name
]
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
1
assert
op_dist_attr
.
impl_type
==
"matmul"
out_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
assert
out_dims_mapping
==
[
0
,
-
1
,
-
1
]
tensor_dist_attr
=
dist_ctx
.
get_tensor_dist_attr_for_program
(
out_var
)
assert
tensor_dist_attr
.
dims_mapping
==
[
0
,
-
1
,
-
1
]
if
op
.
type
==
"matmul_grad"
:
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
1
assert
op_dist_attr
.
impl_type
==
"matmul"
assert
ops
==
ref_ops
class
TestDistMatmulCol
(
TestDistMatmul
):
def
init
(
self
,
trans_x
,
trans_y
):
dist_main_prog
,
dist_ctx
=
parallelizer
(
matmul_dp2mp2
,
init_x_col
,
init_y_col
,
trans_x
,
trans_y
)
return
dist_main_prog
,
dist_ctx
def
test_matmul_col
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
False
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
False
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
True
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
True
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
class
TestDistMatmulRow
(
TestDistMatmul
):
def
init
(
self
,
trans_x
,
trans_y
):
dist_main_prog
,
dist_ctx
=
parallelizer
(
matmul_dp2mp2
,
init_x_row
,
init_y_row
,
trans_x
,
trans_y
)
return
dist_main_prog
,
dist_ctx
def
test_matmul_row
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
False
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
False
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
True
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
True
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
class
TestDistMatmulV2
(
unittest
.
TestCase
):
def
check_col_program
(
self
,
main_program
,
dist_ctx
):
# [0, -1] * [-1, 1] --> [0, 1]
ref_ops
=
[
"c_identity"
,
"matmul_v2"
,
"reduce_mean"
,
"fill_constant"
,
"reduce_mean_grad"
,
"matmul_v2_grad"
]
ops
=
[]
block
=
main_program
.
global_block
()
for
op
in
block
.
ops
:
ops
.
append
(
op
.
type
)
if
op
.
type
==
"matmul_v2"
:
out_name
=
op
.
output
(
'Out'
)[
0
]
out_var
=
block
.
vars
[
out_name
]
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
0
assert
op_dist_attr
.
impl_type
==
"matmul_v2"
out_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
assert
out_dims_mapping
==
[
0
,
1
]
tensor_dist_attr
=
dist_ctx
.
get_tensor_dist_attr_for_program
(
out_var
)
assert
tensor_dist_attr
.
dims_mapping
==
[
0
,
1
]
if
op
.
type
==
"matmul_v2_grad"
:
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
0
assert
op_dist_attr
.
impl_type
==
"matmul_v2"
assert
ops
==
ref_ops
def
check_row_program
(
self
,
main_program
,
dist_ctx
):
# [0, -1, 1] * [1, -1] --> [0, -1, -1]
ref_ops
=
[
"matmul_v2"
,
"c_allreduce_sum"
,
"reduce_mean"
,
"fill_constant"
,
"reduce_mean_grad"
,
"matmul_v2_grad"
]
ops
=
[]
block
=
main_program
.
global_block
()
for
op
in
block
.
ops
:
ops
.
append
(
op
.
type
)
if
op
.
type
==
"matmul_v2"
:
out_name
=
op
.
output
(
'Out'
)[
0
]
out_var
=
block
.
vars
[
out_name
]
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
1
assert
op_dist_attr
.
impl_type
==
"matmul_v2"
out_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
assert
out_dims_mapping
==
[
0
,
-
1
,
-
1
]
tensor_dist_attr
=
dist_ctx
.
get_tensor_dist_attr_for_program
(
out_var
)
assert
tensor_dist_attr
.
dims_mapping
==
[
0
,
-
1
,
-
1
]
if
op
.
type
==
"matmul_v2_grad"
:
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
1
assert
op_dist_attr
.
impl_type
==
"matmul_v2"
assert
ops
==
ref_ops
class
TestDistMatmulV2Col
(
TestDistMatmulV2
):
def
init
(
self
,
trans_x
,
trans_y
):
dist_main_prog
,
dist_ctx
=
parallelizer
(
matmulv2_dp2mp2
,
init_x_col
,
init_y_col
,
trans_x
,
trans_y
)
return
dist_main_prog
,
dist_ctx
def
test_matmul_col
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
False
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
False
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
True
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
True
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
class
TestDistMatmulV2Row
(
TestDistMatmulV2
):
def
init
(
self
,
trans_x
,
trans_y
):
dist_main_prog
,
dist_ctx
=
parallelizer
(
matmulv2_dp2mp2
,
init_x_row
,
init_y_row
,
trans_x
,
trans_y
)
return
dist_main_prog
,
dist_ctx
def
test_matmul_row
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
False
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
False
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
True
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
True
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录