Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4a9895b1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4a9895b1
编写于
9月 05, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
9月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] dist_matmul trans_x or trans_y (#45678)
* dist_matmul trans * update unittest * update cmakelist
上级
a12c806f
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
478 addition
and
47 deletion
+478
-47
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...paddle/distributed/auto_parallel/operators/dist_matmul.py
+103
-47
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py
...e/fluid/tests/unittests/auto_parallel/test_dist_matmul.py
+374
-0
未找到文件。
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
4a9895b1
...
...
@@ -44,6 +44,15 @@ from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
from
paddle.distributed.auto_parallel.cost.comm_op_cost
import
AllreduceSumOpCost
,
IdentityOpCost
def
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
):
if
trans_x
:
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
]
if
trans_y
:
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
]
def
copy_op_with_new_input_output
(
ctx
,
block
,
src_op
,
**
kwargs
):
dist_op_desc
=
block
.
append_op
(
type
=
'nop'
).
desc
dist_op_desc
.
copy_from
(
src_op
.
desc
)
...
...
@@ -90,6 +99,8 @@ def _update_dims_mapping_for_matmul(dist_op):
y_dims_mapping
.
insert
(
1
,
-
1
)
out_dims_mapping
.
insert
(
out_dims_mapping_len
,
0
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
new_x_dims_mapping_len
=
len
(
x_dims_mapping
)
new_y_dims_mapping_len
=
len
(
y_dims_mapping
)
new_out_dims_mapping_len
=
len
(
out_dims_mapping
)
...
...
@@ -117,6 +128,8 @@ def _update_dims_mapping_for_matmul(dist_op):
broadcast_out_dims_mapping
])
if
compatible_dims_mapping
is
None
:
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
return
False
for
i
in
range
(
new_x_dims_mapping_len
-
2
):
...
...
@@ -136,13 +149,6 @@ def _update_dims_mapping_for_matmul(dist_op):
out_dims_mapping
[
i
]
=
compatible_dims_mapping
[
i
]
changed
=
True
if
trans_x
:
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
]
if
trans_y
:
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
]
# The following which uses negative index can be work
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
dim_changed
=
compute_compatible_and_update_dim_mapping
(
...
...
@@ -160,12 +166,7 @@ def _update_dims_mapping_for_matmul(dist_op):
if
dim_changed
:
changed
=
True
if
trans_x
:
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
]
if
trans_y
:
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
]
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
# Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor
if
x_dims_mapping_len
==
1
:
...
...
@@ -188,6 +189,15 @@ def _is_auto_compatible_for_matmul(dist_op):
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
out_name
=
op_desc
.
output
(
'Out'
)[
0
]
trans_x
=
None
trans_y
=
None
if
op_desc
.
type
()
==
"matmul_v2"
:
trans_x
=
op_desc
.
attr
(
'trans_x'
)
trans_y
=
op_desc
.
attr
(
'trans_y'
)
elif
op_desc
.
type
()
==
"matmul"
:
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
# Deep copy these dims_mappings for keeping them unchanged.
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
...
...
@@ -203,17 +213,7 @@ def _is_auto_compatible_for_matmul(dist_op):
if
y_dims_mapping_len
==
1
:
y_dims_mapping
.
insert
(
1
,
-
1
)
# NOTE: Partition is not supported if matmul op has trans.
if
op_desc
.
type
()
==
"matmul_v2"
:
if
op_desc
.
attr
(
'trans_x'
)
or
op_desc
.
attr
(
'trans_y'
):
if
x_dims_mapping
[
-
2
:]
!=
[
-
1
,
-
1
]
or
y_dims_mapping
[
-
2
:]
!=
[
-
1
,
-
1
]:
return
False
elif
op_desc
.
type
()
==
"matmul"
:
if
op_desc
.
attr
(
'transpose_X'
)
or
op_desc
.
attr
(
'transpose_Y'
):
if
x_dims_mapping
[
-
2
:]
!=
[
-
1
,
-
1
]
or
y_dims_mapping
[
-
2
:]
!=
[
-
1
,
-
1
]:
return
False
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
# Deal with dim > 2 and take care of broadcasting
if
out_dims_mapping_len
>
2
:
...
...
@@ -304,9 +304,23 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
),
"left operand(X) [{}] of dist matmul should not be parameter"
.
format
(
X_var
.
name
)
X_var_dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
X_var
.
name
)
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var
.
name
)
process_mesh_shape
=
dist_attr
.
process_mesh
.
topology
process_mesh_group
=
dist_attr
.
process_mesh
.
processes
trans_x
=
None
trans_y
=
None
if
backward_op
.
desc
.
type
()
==
"matmul_v2_grad"
:
trans_x
=
backward_op
.
desc
.
attr
(
'trans_x'
)
trans_y
=
backward_op
.
desc
.
attr
(
'trans_y'
)
elif
backward_op
.
desc
.
type
()
==
"matmul_grad"
:
trans_x
=
backward_op
.
desc
.
attr
(
'transpose_X'
)
trans_y
=
backward_op
.
desc
.
attr
(
'transpose_Y'
)
if
trans_y
:
trans_x_y_dims_mapping
(
False
,
True
,
None
,
Y_var_dim_mapping
)
# assert len(
# Y_var_dim_mapping
# ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format(
...
...
@@ -431,9 +445,17 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
if
is_parameter_related
(
Y_var
.
name
,
main_block
):
out_grad_names
=
[
kwargs
[
'Y@GRAD'
][
0
]]
if
trans_x
:
trans_x_y_dims_mapping
(
True
,
False
,
X_var_dims_mapping
,
None
)
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
out_grad_names
,
rank_id
)
if
trans_x
:
trans_x_y_dims_mapping
(
True
,
False
,
X_var_dims_mapping
,
None
)
if
trans_y
:
trans_x_y_dims_mapping
(
False
,
True
,
None
,
Y_var_dim_mapping
)
def
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
):
...
...
@@ -583,8 +605,13 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
op_dist_attr
=
dist_op
.
dist_attr
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
y_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
return
False
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
...
...
@@ -660,10 +687,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
"transpose_X"
)
trans_y
=
src_op
.
attr
(
"transpose_Y"
)
# TODO infer logic comm presentation
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
if
trans_y
:
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
assert
matmul_col_dim_mapping
>=
0
,
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_col_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
...
...
@@ -723,10 +755,10 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
attrs
=
{
'transpose_X'
:
False
,
'transpose_Y'
:
False
,
'transpose_X'
:
trans_x
,
'transpose_Y'
:
trans_y
,
'alpha'
:
1
,
OP_ROLE_KEY
:
src_op
(
'op_role'
)
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
}
inputs
=
{
'X'
:
[
intermediate_var_0
],
'Y'
:
[
Weight_var
]}
matmul_op
=
main_block
.
append_op
(
type
=
'matmul'
,
...
...
@@ -902,8 +934,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
op_dist_attr
=
dist_op
.
dist_attr
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
y_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_replicate
(
x_dims_mapping
[
-
1
]):
return
False
if
is_dim_replicate
(
y_dims_mapping
[
-
2
])
or
is_dim_shard
(
...
...
@@ -932,10 +969,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
(
not
self
.
is_output_compatible
(
dist_op
)):
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
return
False
return
True
def
update_dims_mapping
(
self
,
dist_op
):
...
...
@@ -983,10 +1018,15 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
'transpose_X'
)
trans_y
=
src_op
.
attr
(
'transpose_Y'
)
# TODO infer logic comm presentation
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
if
trans_y
:
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
assert
matmul_row_dim_mapping
>=
0
,
"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_row_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
...
...
@@ -1002,8 +1042,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
attrs
=
{
'transpose_X'
:
False
,
'transpose_Y'
:
False
,
'transpose_X'
:
trans_x
,
'transpose_Y'
:
trans_y
,
'alpha'
:
1
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
}
...
...
@@ -1354,8 +1394,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
op_dist_attr
=
dist_op
.
dist_attr
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
y_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
trans_x
=
op_desc
.
attr
(
'trans_x'
)
trans_y
=
op_desc
.
attr
(
'trans_y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
return
False
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
...
...
@@ -1382,10 +1427,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
(
not
self
.
is_output_compatible
(
dist_op
)):
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
return
False
return
True
def
update_dims_mapping
(
self
,
dist_op
):
...
...
@@ -1433,10 +1476,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
'trans_x'
)
trans_y
=
src_op
.
attr
(
'trans_y'
)
# TODO infer logic comm presentation
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
if
trans_y
:
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
assert
matmul_col_dim_mapping
>=
0
,
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_col_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
...
...
@@ -1495,8 +1543,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
attrs
=
{
'trans_x'
:
False
,
'trans_y'
:
False
,
'trans_x'
:
trans_x
,
'trans_y'
:
trans_y
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
}
inputs
=
{
'X'
:
[
intermediate_var_0
],
'Y'
:
[
Weight_var
]}
...
...
@@ -1670,8 +1718,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
op_dist_attr
=
dist_op
.
dist_attr
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
y_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
trans_x
=
op_desc
.
attr
(
'trans_x'
)
trans_y
=
op_desc
.
attr
(
'trans_y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_replicate
(
x_dims_mapping
[
-
1
]):
return
False
if
is_dim_replicate
(
y_dims_mapping
[
-
2
])
or
is_dim_shard
(
...
...
@@ -1700,10 +1753,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
(
not
self
.
is_output_compatible
(
dist_op
)):
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
return
False
return
True
def
update_dims_mapping
(
self
,
dist_op
):
...
...
@@ -1751,10 +1802,15 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Out_var
=
main_block
.
var
(
kwargs
[
'Out'
][
0
])
trans_x
=
src_op
.
attr
(
'trans_x'
)
trans_y
=
src_op
.
attr
(
'trans_y'
)
# TODO infer logic comm presentation
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
if
trans_y
:
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
assert
matmul_row_dim_mapping
>=
0
,
"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_row_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
...
...
@@ -1770,8 +1826,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
attrs
=
{
'trans_x'
:
False
,
'trans_y'
:
False
,
'trans_x'
:
trans_x
,
'trans_y'
:
trans_y
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
}
inputs
=
{
'X'
:
X_var
,
'Y'
:
Weight_var
}
...
...
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
4a9895b1
...
...
@@ -71,4 +71,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_dist_attr_v2 MODULES test_dist_attr_v2
)
py_test_modules
(
test_lr_grad_clip MODULES test_lr_grad_clip
)
py_test_modules
(
test_quantization MODULES test_quantization
)
py_test_modules
(
test_dist_matmul MODULES test_dist_matmul
)
endif
()
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py
0 → 100644
浏览文件 @
4a9895b1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
paddle.distributed.auto_parallel
as
auto
from
paddle.fluid
import
program_guard
from
paddle.fluid.backward
import
append_backward
from
paddle.distributed.auto_parallel.utils
import
print_program_with_dist_attr
paddle
.
enable_static
()
mesh
=
[[
0
,
1
],
[
2
,
3
]]
def
init_x_row
(
trans_x
):
if
trans_x
:
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
10
,
6
,
8
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
x
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
0
,
1
,
-
1
]
})
return
x
else
:
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
10
,
8
,
6
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
x
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
0
,
-
1
,
1
]
})
return
x
def
init_x_col
(
trans_x
):
if
trans_x
:
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
6
,
8
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
x
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
-
1
,
0
]
})
return
x
else
:
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
8
,
6
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
x
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
0
,
-
1
]
})
return
x
def
init_y_row
(
trans_y
):
if
trans_y
:
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
[
4
,
6
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
y
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
-
1
,
1
]
})
return
y
else
:
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
[
6
,
4
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
y
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
1
,
-
1
]
})
return
y
def
init_y_col
(
trans_y
):
if
trans_y
:
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
[
4
,
6
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
y
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
1
,
-
1
]
})
return
y
else
:
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
[
6
,
4
],
dtype
=
'float32'
)
auto
.
shard_tensor
(
y
,
dist_attr
=
{
"process_mesh"
:
mesh
,
"dims_mapping"
:
[
-
1
,
1
]
})
return
y
def
matmul_dp2mp2
(
init_x
,
init_y
,
trans_x
,
trans_y
):
main_program
=
paddle
.
fluid
.
Program
()
start_program
=
paddle
.
fluid
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
start_program
):
x
=
init_x
(
trans_x
)
y
=
init_y
(
trans_y
)
x
.
stop_gradient
=
False
y
.
stop_gradient
=
False
out
=
paddle
.
fluid
.
layers
.
matmul
(
x
,
y
,
transpose_x
=
trans_x
,
transpose_y
=
trans_y
)
loss
=
paddle
.
mean
(
out
)
return
main_program
,
start_program
,
loss
def
matmulv2_dp2mp2
(
init_x
,
init_y
,
trans_x
,
trans_y
):
main_program
=
paddle
.
fluid
.
Program
()
start_program
=
paddle
.
fluid
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
start_program
):
x
=
init_x
(
trans_x
)
y
=
init_y
(
trans_y
)
x
.
stop_gradient
=
False
y
.
stop_gradient
=
False
out
=
paddle
.
matmul
(
x
,
y
,
transpose_x
=
trans_x
,
transpose_y
=
trans_y
)
loss
=
paddle
.
mean
(
out
)
return
main_program
,
start_program
,
loss
def
parallelizer
(
program_func
,
*
args
,
**
kwargs
):
from
paddle.distributed.auto_parallel.completion
import
Completer
from
paddle.distributed.auto_parallel.partitioner
import
Partitioner
from
paddle.distributed.auto_parallel.dist_context
import
DistributedContext
main_program
,
start_program
,
loss
=
program_func
(
*
args
,
**
kwargs
)
dist_context
=
DistributedContext
()
completer
=
Completer
(
dist_context
)
completer
.
complete_forward_annotation
(
main_program
)
dist_context
.
block_state
.
parse_forward_blocks
(
main_program
)
with
program_guard
(
main_program
,
start_program
):
append_backward
(
loss
,
distop_context
=
dist_context
.
dist_op_context
)
completer
.
complete_backward_annotation
(
main_program
)
dist_context
.
block_state
.
parse_backward_blocks
(
main_program
)
partitioner
=
Partitioner
(
dist_context
,
0
)
dist_main_prog
,
_
,
_
=
partitioner
.
partition
(
main_program
,
start_program
,
[])
return
dist_main_prog
,
dist_context
class
TestDistMatmul
(
unittest
.
TestCase
):
def
check_col_program
(
self
,
main_program
,
dist_ctx
):
# [0, -1] * [-1, 1] --> [0, 1]
ref_ops
=
[
"c_identity"
,
"matmul"
,
"reduce_mean"
,
"fill_constant"
,
"reduce_mean_grad"
,
"matmul_grad"
]
ops
=
[]
block
=
main_program
.
global_block
()
for
op
in
block
.
ops
:
ops
.
append
(
op
.
type
)
if
op
.
type
==
"matmul"
:
out_name
=
op
.
output
(
'Out'
)[
0
]
out_var
=
block
.
vars
[
out_name
]
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
0
assert
op_dist_attr
.
impl_type
==
"matmul"
out_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
assert
out_dims_mapping
==
[
0
,
1
]
tensor_dist_attr
=
dist_ctx
.
get_tensor_dist_attr_for_program
(
out_var
)
assert
tensor_dist_attr
.
dims_mapping
==
[
0
,
1
]
if
op
.
type
==
"matmul_grad"
:
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
0
assert
op_dist_attr
.
impl_type
==
"matmul"
assert
ops
==
ref_ops
def
check_row_program
(
self
,
main_program
,
dist_ctx
):
# [0, -1, 1] * [1, -1] --> [0, -1, -1]
ref_ops
=
[
"matmul"
,
"c_allreduce_sum"
,
"reduce_mean"
,
"fill_constant"
,
"reduce_mean_grad"
,
"matmul_grad"
]
ops
=
[]
block
=
main_program
.
global_block
()
for
op
in
block
.
ops
:
ops
.
append
(
op
.
type
)
if
op
.
type
==
"matmul"
:
out_name
=
op
.
output
(
'Out'
)[
0
]
out_var
=
block
.
vars
[
out_name
]
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
1
assert
op_dist_attr
.
impl_type
==
"matmul"
out_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
assert
out_dims_mapping
==
[
0
,
-
1
,
-
1
]
tensor_dist_attr
=
dist_ctx
.
get_tensor_dist_attr_for_program
(
out_var
)
assert
tensor_dist_attr
.
dims_mapping
==
[
0
,
-
1
,
-
1
]
if
op
.
type
==
"matmul_grad"
:
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
1
assert
op_dist_attr
.
impl_type
==
"matmul"
assert
ops
==
ref_ops
class
TestDistMatmulCol
(
TestDistMatmul
):
def
init
(
self
,
trans_x
,
trans_y
):
dist_main_prog
,
dist_ctx
=
parallelizer
(
matmul_dp2mp2
,
init_x_col
,
init_y_col
,
trans_x
,
trans_y
)
return
dist_main_prog
,
dist_ctx
def
test_matmul_col
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
False
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
False
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
True
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
True
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
class
TestDistMatmulRow
(
TestDistMatmul
):
def
init
(
self
,
trans_x
,
trans_y
):
dist_main_prog
,
dist_ctx
=
parallelizer
(
matmul_dp2mp2
,
init_x_row
,
init_y_row
,
trans_x
,
trans_y
)
return
dist_main_prog
,
dist_ctx
def
test_matmul_row
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
False
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
False
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
True
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
True
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
class
TestDistMatmulV2
(
unittest
.
TestCase
):
def
check_col_program
(
self
,
main_program
,
dist_ctx
):
# [0, -1] * [-1, 1] --> [0, 1]
ref_ops
=
[
"c_identity"
,
"matmul_v2"
,
"reduce_mean"
,
"fill_constant"
,
"reduce_mean_grad"
,
"matmul_v2_grad"
]
ops
=
[]
block
=
main_program
.
global_block
()
for
op
in
block
.
ops
:
ops
.
append
(
op
.
type
)
if
op
.
type
==
"matmul_v2"
:
out_name
=
op
.
output
(
'Out'
)[
0
]
out_var
=
block
.
vars
[
out_name
]
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
0
assert
op_dist_attr
.
impl_type
==
"matmul_v2"
out_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
assert
out_dims_mapping
==
[
0
,
1
]
tensor_dist_attr
=
dist_ctx
.
get_tensor_dist_attr_for_program
(
out_var
)
assert
tensor_dist_attr
.
dims_mapping
==
[
0
,
1
]
if
op
.
type
==
"matmul_v2_grad"
:
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
0
assert
op_dist_attr
.
impl_type
==
"matmul_v2"
assert
ops
==
ref_ops
def
check_row_program
(
self
,
main_program
,
dist_ctx
):
# [0, -1, 1] * [1, -1] --> [0, -1, -1]
ref_ops
=
[
"matmul_v2"
,
"c_allreduce_sum"
,
"reduce_mean"
,
"fill_constant"
,
"reduce_mean_grad"
,
"matmul_v2_grad"
]
ops
=
[]
block
=
main_program
.
global_block
()
for
op
in
block
.
ops
:
ops
.
append
(
op
.
type
)
if
op
.
type
==
"matmul_v2"
:
out_name
=
op
.
output
(
'Out'
)[
0
]
out_var
=
block
.
vars
[
out_name
]
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
1
assert
op_dist_attr
.
impl_type
==
"matmul_v2"
out_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
assert
out_dims_mapping
==
[
0
,
-
1
,
-
1
]
tensor_dist_attr
=
dist_ctx
.
get_tensor_dist_attr_for_program
(
out_var
)
assert
tensor_dist_attr
.
dims_mapping
==
[
0
,
-
1
,
-
1
]
if
op
.
type
==
"matmul_v2_grad"
:
op_dist_attr
=
dist_ctx
.
get_op_dist_attr_for_program
(
op
)
assert
op_dist_attr
.
impl_idx
==
1
assert
op_dist_attr
.
impl_type
==
"matmul_v2"
assert
ops
==
ref_ops
class
TestDistMatmulV2Col
(
TestDistMatmulV2
):
def
init
(
self
,
trans_x
,
trans_y
):
dist_main_prog
,
dist_ctx
=
parallelizer
(
matmulv2_dp2mp2
,
init_x_col
,
init_y_col
,
trans_x
,
trans_y
)
return
dist_main_prog
,
dist_ctx
def
test_matmul_col
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
False
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
False
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
True
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
True
)
self
.
check_col_program
(
dist_main_prog
,
dist_ctx
)
class
TestDistMatmulV2Row
(
TestDistMatmulV2
):
def
init
(
self
,
trans_x
,
trans_y
):
dist_main_prog
,
dist_ctx
=
parallelizer
(
matmulv2_dp2mp2
,
init_x_row
,
init_y_row
,
trans_x
,
trans_y
)
return
dist_main_prog
,
dist_ctx
def
test_matmul_row
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
False
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
False
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
False
,
True
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
def
test_trans_x_trans_y
(
self
):
dist_main_prog
,
dist_ctx
=
self
.
init
(
True
,
True
)
self
.
check_row_program
(
dist_main_prog
,
dist_ctx
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录