Crayon鑫 / Paddle · Commit 529f1425 (forked from PaddlePaddle / Paddle, in sync with the upstream project)
Unverified commit 529f1425, authored on Jan 25, 2022 by caozhou, committed by GitHub on Jan 25, 2022.
【Auto Parallel】Update reshard for complete (#39073)
* update reshard for newest completion
* update unitest
* merge newest
Parent: 0c3657ad
Showing 2 changed files with 87 additions and 11 deletions (+87 -11):

* python/paddle/distributed/auto_parallel/reshard.py (+56 -9)
* python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py (+31 -2)
python/paddle/distributed/auto_parallel/reshard.py (view file @ 529f1425)
@@ -279,7 +279,7 @@ def _is_overlapped(shape_x, shape_y):
     return overlapped
 
 
-def _need_reshard(dist_tensor, dist_op):
+def _need_reshard(dist_tensor, dist_op, op_input=True):
     """Judge the tensor whether needs to be resharded."""
     is_reshard = False
     tensor_dist_attr = dist_tensor.dist_attr
@@ -289,13 +289,31 @@ def _need_reshard(dist_tensor, dist_op):
     op_dist_attr = dist_op.dist_attr
-    op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
-    op_process_mesh = op_dist_attr.process_mesh
-    if all(
-            map(lambda x: x is not None, [
-                tensor_dims_mapping, tensor_process_mesh,
-                op_input_dims_mapping, op_process_mesh
-            ])):
-        if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
-            is_reshard = True
+    if op_input:
+        op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
+        op_process_mesh = op_dist_attr.process_mesh
+        if all(
+                map(lambda x: x is not None, [
+                    tensor_dims_mapping, tensor_process_mesh,
+                    op_input_dims_mapping, op_process_mesh
+                ])):
+            if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
+                is_reshard = True
+    else:
+        op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(tensor_name)
+        op_process_mesh = op_dist_attr.process_mesh
+        if all(
+                map(lambda x: x is not None, [
+                    tensor_dims_mapping, tensor_process_mesh,
+                    op_output_dims_mapping, op_process_mesh
+                ])):
+            if tensor_process_mesh != op_process_mesh:
+                is_reshard = True
+            if tensor_dims_mapping != op_output_dims_mapping:
+                raise ValueError(
+                    "It is not supported that tensor dims mapping is different from op output dims mapping."
+                )
     return is_reshard
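The new branch structure makes the check asymmetric: an op input may legitimately differ from the tensor in either dims mapping or process mesh, while an op output is only allowed to differ in process mesh. A minimal standalone sketch of that rule, using hypothetical attribute objects (types.SimpleNamespace stand-ins, not Paddle's dist_attr classes):

# A minimal standalone sketch of the rule above. The attribute objects are
# hypothetical stand-ins, not Paddle's dist_attr classes.
from types import SimpleNamespace


def needs_reshard(tensor_attr, op_attr, op_input=True):
    # Compare against the op slot that actually consumes/produces the tensor.
    op_dims = op_attr.input_dims_mapping if op_input else op_attr.output_dims_mapping
    if None in (tensor_attr.dims_mapping, tensor_attr.process_mesh, op_dims,
                op_attr.process_mesh):
        return False  # incomplete annotation: nothing to compare
    if op_input:
        # An input may differ in sharding or placement; either triggers a reshard.
        return (tensor_attr.dims_mapping != op_dims or
                tensor_attr.process_mesh != op_attr.process_mesh)
    # An output may only differ in placement; a sharding mismatch is an error.
    if tensor_attr.dims_mapping != op_dims:
        raise ValueError("tensor dims mapping must match op output dims mapping")
    return tensor_attr.process_mesh != op_attr.process_mesh


# Same sharding but different meshes: the output must be sent/received.
t = SimpleNamespace(dims_mapping=[-1, 0], process_mesh=[0, 1])
o = SimpleNamespace(input_dims_mapping=None, output_dims_mapping=[-1, 0],
                    process_mesh=[2, 3])
assert needs_reshard(t, o, op_input=False)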
@@ -948,12 +966,13 @@ def remove_no_need_in_startup(auto_parallel_main_prog,
 def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
             dist_context):
     """
-    Reshard tensor in the program according to its dist attr and corresponding op dist attr.
+    Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
 
     Args:
         auto_parallel_main_prog (Program): An auto parallel main program.
         auto_parallel_startup_prog (Program): An auto parallel startup program.
         rank_id (int): The process id.
+        dist_context (DistributedContext): The distributed context of this rank.
     """
     assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \
         "but got {}.".format(type(auto_parallel_main_prog))
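For reference, the updated unit test below calls this entry point exactly as the docstring describes; the variable names are the test's own, shown only to illustrate the call order:

# Illustrative call order, mirroring the unit test further down: reshard the
# partitioned programs for this rank, then dump them with their dist attrs.
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
print_program_with_dist_attr(dist_main_prog, dist_context)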
@@ -1001,6 +1020,34 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
         else:
             idx += 1
 
+    # insert send and recv op if output process mesh is different from tensor process mesh
+    idx = 0
+    skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
+    while idx < len(block.ops):
+        pre_op_count = len(block.ops)
+        op = block.ops[idx]
+        dist_op = dist_context.get_dist_op_for_program(op)
+        if dist_op is not None and op.type not in skip_ops:
+            idx_offset = 0
+            for var_name in op.output_arg_names:
+                var = block.vars[var_name]
+                dist_tensor = dist_context.get_dist_tensor_for_program(var)
+                if dist_tensor is not None and _need_reshard(dist_tensor,
+                                                             dist_op, False):
+                    for index, item in enumerate(
+                            dist_op.dist_attr.process_mesh.processes):
+                        recv_rank = dist_tensor.dist_attr.process_mesh.processes[
+                            index]
+                        if rank_id == item:
+                            _insert_send_op(block, idx + 1, var, recv_rank)
+                        if rank_id == recv_rank:
+                            _insert_recv_op(block, idx + 1, var, item)
+                    cur_op_count = len(block.ops)
+                    idx_offset = idx_offset + cur_op_count - pre_op_count
+                    pre_op_count = cur_op_count
+            idx = idx + idx_offset + 1
+        else:
+            idx += 1
+
     # remove no need vars and ops in the main program
     remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id)
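The send/recv pairing walks the two process meshes position by position: the process at index i of the op's mesh sends the variable to the process at index i of the tensor's mesh, and each rank only inserts the communication op it participates in. A self-contained sketch of that pairing, with plain integer lists standing in for Paddle's process meshes:

# Self-contained sketch of the rank pairing above: plain integer lists stand
# in for the two process meshes; returns the comm ops a given rank inserts.
def plan_send_recv(op_mesh_processes, tensor_mesh_processes, rank_id):
    actions = []
    for index, item in enumerate(op_mesh_processes):
        recv_rank = tensor_mesh_processes[index]
        if rank_id == item:
            actions.append(("send_to", recv_rank))  # producer side
        if rank_id == recv_rank:
            actions.append(("recv_from", item))     # consumer side
    return actions


# Op runs on mesh [0, 1]; the tensor is placed on mesh [2, 3].
assert plan_send_recv([0, 1], [2, 3], rank_id=1) == [("send_to", 3)]
assert plan_send_recv([0, 1], [2, 3], rank_id=3) == [("recv_from", 1)]

The pre_op_count/cur_op_count bookkeeping in the real loop then advances idx past the freshly inserted send/recv ops so the scan does not revisit them.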
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py (view file @ 529f1425)
@@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
 from paddle.distributed.auto_parallel.process_group import _g_process_group_map
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
@@ -143,7 +143,11 @@ def mlp_forward(train_program, start_program):
     return loss, train_program, start_program
 
 
-def get_dist_prog(train_program, startup_program, dist_context, rank_id):
+def get_dist_prog(train_program,
+                  startup_program,
+                  dist_context,
+                  rank_id,
+                  change_process_mesh=False):
     loss, train_program, startup_program = mlp_forward(train_program,
                                                        startup_program)
@@ -157,6 +161,12 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     complete_train_program = completer.complete_forward_annotation(
         train_program)
 
+    if change_process_mesh:
+        global PP_MESH_1
+        dist_context.get_tensor_dist_attr_for_program(
+            train_program.global_block().vars[
+                "gelu_0.tmp_0"]).process_mesh = PP_MESH_1
+
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
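Overriding the process mesh of the intermediate tensor "gelu_0.tmp_0" appears to place it on a different mesh (PP_MESH_1) than the ops around it, deliberately creating the tensor/output mesh mismatch that the new send/recv insertion pass in reshard.py is designed to handle.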
@@ -308,6 +318,25 @@ class TestMLPReshard(unittest.TestCase):
         # parameter initialization of every rank should be different in the pipeline scene
         self.assertTrue(check_initialization(dist_startup_prog, rank_id))
 
+    def test_mlp_pp_diff_process_mesh(self):
+        HAS_SENT.clear()
+        HAS_RECV.clear()
+        HAS_ALLGATHER.clear()
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        dist_context = DistributedContext()
+        rank_id = 1
+        dist_main_prog, dist_startup_prog = get_dist_prog(
+            train_program, startup_program, dist_context, rank_id, True)
+        for key in list(_g_process_group_map.keys()):
+            del _g_process_group_map[key]
+        reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
+        print_program_with_dist_attr(dist_main_prog, dist_context)
+
+        # check send and recv result
+        self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
+        self.assertTrue(check_initialization(dist_startup_prog, rank_id))
+
     def test_mlp_dp(self):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp"