Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2747de2b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2747de2b
编写于
3月 10, 2022
作者:
C
caozhou
提交者:
GitHub
3月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel]Update reshard for while sub block (#40366)
* update reshard for while sub block * fix code format error
上级
575dea8f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
511 addition
and
181 deletion
+511
-181
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/reshard.py
+505
-181
python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py
...paddle/fluid/tests/unittests/auto_parallel_autoconvert.py
+6
-0
未找到文件。
python/paddle/distributed/auto_parallel/reshard.py
浏览文件 @
2747de2b
...
...
@@ -29,6 +29,7 @@ from .process_group import new_process_group, ProcessGroup, _g_process_group_map
# NOTE: If op in _g_special_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
# Bookkeeping for while ops, keyed by the id of the while op's sub block.
# Each value is a dict; the keys used by the code in this file are
# "op_id" (compared against `op.desc.id()` to find the parent while op),
# "actual_process_mesh" (compared against op process meshes in `reshard`) and
# "var_reshard_mapping" (original input var name -> name of the resharded var).
# NOTE(review): where "op_id" and "actual_process_mesh" are filled in is not
# visible here - confirm against the rest of `reshard`.
while_block_info = {}
class
AllGatherOpDesc
:
...
...
@@ -280,8 +281,20 @@ def _is_overlapped(shape_x, shape_y):
return
overlapped
def
_need_reshard
(
dist_tensor
,
dist_op
,
op_input
=
True
):
def
_need_reshard
(
dist_tensor
,
dist_op
,
actual_process_mesh
,
program
,
dist_context
,
op_input
=
True
):
"""Judge the tensor whether needs to be resharded."""
def
_is_unshard
(
dims_mapping
):
for
dim
in
dims_mapping
:
if
dim
!=
-
1
:
return
False
return
True
is_reshard
=
False
tensor_dist_attr
=
dist_tensor
.
dist_attr
tensor_name
=
dist_tensor
.
serial_tensor
.
name
...
...
@@ -289,32 +302,74 @@ def _need_reshard(dist_tensor, dist_op, op_input=True):
tensor_process_mesh
=
tensor_dist_attr
.
process_mesh
op_dist_attr
=
dist_op
.
dist_attr
op_input_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
tensor_name
)
op_process_mesh
=
op_dist_attr
.
process_mesh
op_process_mesh
=
actual_
process_mesh
if
op_input
:
op_input_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
tensor_name
)
op_process_mesh
=
op_dist_attr
.
process_mesh
if
all
(
map
(
lambda
x
:
x
is
not
None
,
[
tensor_dims_mapping
,
tensor_process_mesh
,
op_input_dims_mapping
,
op_process_mesh
])):
if
tensor_dims_mapping
!=
op_input_dims_mapping
or
tensor_process_mesh
!=
op_process_mesh
:
is_reshard
=
True
# dims_mapping
if
tensor_dims_mapping
!=
op_input_dims_mapping
:
if
dist_op
.
serial_op
.
type
==
"while"
:
sub_block
=
program
.
blocks
[
dist_op
.
serial_op
.
attr
(
"sub_block"
).
id
]
for
op
in
sub_block
.
ops
:
for
var_name
in
op
.
input_arg_names
:
if
var_name
==
tensor_name
:
dist_op_attr
=
dist_context
.
get_dist_op_for_program
(
op
).
dist_attr
var_dims_mapping
=
dist_op_attr
.
get_input_dims_mapping
(
var_name
)
if
var_dims_mapping
!=
tensor_dims_mapping
:
is_reshard
=
True
break
else
:
is_reshard
=
True
# process_mesh
if
tensor_process_mesh
!=
op_process_mesh
:
# when processes length is not the same, the dims mapping must be replicative now
if
len
(
tensor_process_mesh
.
processes
)
!=
len
(
op_process_mesh
.
processes
):
assert
_is_unshard
(
tensor_dims_mapping
)
assert
_is_unshard
(
op_input_dims_mapping
)
else
:
if
dist_tensor
.
serial_tensor
.
dtype
==
paddle
.
bool
:
raise
ValueError
(
"Bool var is not supported reshard."
)
# for while op, it should find the process mesh of op actually used the tensor as input
if
dist_op
.
serial_op
.
type
==
"while"
:
sub_block
=
program
.
blocks
[
dist_op
.
serial_op
.
attr
(
"sub_block"
).
id
]
for
op
in
sub_block
.
ops
:
for
var_name
in
op
.
input_arg_names
:
if
var_name
==
tensor_name
:
dist_op_attr
=
dist_context
.
get_dist_op_for_program
(
op
).
dist_attr
process_mesh
=
dist_op_attr
.
process_mesh
if
process_mesh
==
op_process_mesh
:
is_reshard
=
True
break
else
:
is_reshard
=
True
else
:
op_output_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
tensor_name
)
op_process_mesh
=
op_dist_attr
.
process_mesh
if
all
(
map
(
lambda
x
:
x
is
not
None
,
[
tensor_dims_mapping
,
tensor_process_mesh
,
op_output_dims_mapping
,
op_process_mesh
])):
if
tensor_process_mesh
!=
op_process_mesh
:
if
dist_tensor
.
serial_tensor
.
dtype
==
paddle
.
bool
:
raise
ValueError
(
"Bool var is not supported reshard."
)
is_reshard
=
True
if
tensor_dims_mapping
!=
op_output_dims_mapping
:
raise
ValueError
(
"It is not supported that tensor dims mapping is different from op output dims mapping."
)
return
is_reshard
...
...
@@ -329,13 +384,14 @@ def _compute_complete_shape(slice_shape, process_shape, dims_mapping):
return
complete_shape
def
find_op_desc_seq
(
dist_tensor
,
dist_op
):
def
find_op_desc_seq
(
dist_tensor
,
dist_op
,
actual_process_mesh
,
batch_size
):
"""
Find the op description sequence to reshard the source tensor for matching the op requirement.
Args:
dist_tensor (DistributedTensor): A distributed tensor.
dist_op (DistributedOperator): A distributed operator.
actual_process_mesh (ProcessMesh): The actual op process mesh.
Returns:
Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
...
...
@@ -350,11 +406,16 @@ def find_op_desc_seq(dist_tensor, dist_op):
source_process_shape
=
source_process_mesh
.
topology
op_dist_attr
=
dist_op
.
dist_attr
target_process_mesh
=
op_dist_attr
.
process_mesh
target_process_mesh
=
actual_
process_mesh
target_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
tensor_name
)
target_process_group
=
target_process_mesh
.
processes
target_process_shape
=
target_process_mesh
.
topology
if
source_tensor
.
shape
[
0
]
<
0
:
new_shape
=
list
(
source_tensor
.
shape
)
new_shape
[
0
]
=
batch_size
source_tensor
.
desc
.
set_shape
(
new_shape
)
complete_shape
=
_compute_complete_shape
(
source_tensor
.
shape
,
source_process_shape
,
source_dims_mapping
)
op_desc_seq
=
{}
...
...
@@ -503,7 +564,7 @@ def find_op_desc_seq(dist_tensor, dist_op):
return
op_desc_seq
def
_insert_send_op
(
block
,
idx
,
tensor
,
dst
):
def
_insert_send_op
(
block
,
idx
,
tensor
,
dst
,
op_role
):
"""Insert send op into block at the given index."""
op_type
=
'send_v2'
block
.
_insert_op
(
...
...
@@ -514,10 +575,11 @@ def _insert_send_op(block, idx, tensor, dst):
'ring_id'
:
0
,
'peer'
:
dst
,
'use_calc_stream'
:
True
,
'op_role'
:
op_role
})
def
_insert_recv_op
(
block
,
idx
,
tensor
,
src
):
def
_insert_recv_op
(
block
,
idx
,
tensor
,
src
,
op_role
):
"""Insert recv op into block at the given index."""
op_type
=
'recv_v2'
block
.
_insert_op
(
...
...
@@ -531,14 +593,16 @@ def _insert_recv_op(block, idx, tensor, src):
'out_shape'
:
tensor
.
shape
,
'dtype'
:
tensor
.
dtype
,
'use_calc_stream'
:
True
,
'op_role'
:
op_role
})
def
_insert_concat_op
(
block
,
idx
,
tensors
,
axis
):
def
_insert_concat_op
(
block
,
idx
,
tensors
,
axis
,
op_role
):
"""Insert concat op into block at the given block."""
inputs
=
{
'X'
:
tensors
}
attrs
=
{}
attrs
[
'axis'
]
=
axis
attrs
[
'op_role'
]
=
op_role
helper
=
LayerHelper
(
'concat'
,
**
locals
())
with
paddle
.
static
.
program_guard
(
block
.
program
):
out
=
helper
.
create_variable_for_type_inference
(
...
...
@@ -548,7 +612,8 @@ def _insert_concat_op(block, idx, tensors, axis):
return
out
def
_insert_slice_op
(
block
,
idx
,
tensor
,
starts
,
ends
,
axes
,
new_var_name
):
def
_insert_slice_op
(
block
,
idx
,
tensor
,
starts
,
ends
,
axes
,
new_var_name
,
op_role
):
"""Insert slice op into block at the given block."""
inputs
=
{
'Input'
:
tensor
}
infer_flags
=
list
(
1
for
i
in
range
(
len
(
axes
)))
...
...
@@ -556,24 +621,23 @@ def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name):
"axes"
:
axes
,
"starts"
:
starts
,
"ends"
:
ends
,
"infer_flags"
:
infer_flags
"infer_flags"
:
infer_flags
,
'op_role'
:
op_role
}
helper
=
LayerHelper
(
'slice'
,
**
locals
())
out
=
block
.
create_var
(
name
=
new_var_name
,
dtype
=
tensor
.
dtype
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
name
=
new_var_name
,
dtype
=
tensor
.
dtype
,
type
=
tensor
.
type
)
block
.
_insert_op
(
idx
,
type
=
"slice"
,
inputs
=
inputs
,
outputs
=
{
'Out'
:
[
out
]},
attrs
=
attrs
)
return
out
def
_insert_split_op
(
block
,
idx
,
tensor
,
num_or_sections
):
def
_insert_split_op
(
block
,
idx
,
tensor
,
num_or_sections
,
op_role
):
"""Insert split op into block at the given index."""
helper
=
LayerHelper
(
'split'
,
**
locals
())
input_shape
=
tensor
.
shape
inputs
=
{
'X'
:
tensor
}
attrs
=
{
'num'
:
num_or_sections
,
"axis"
:
0
}
attrs
=
{
'num'
:
num_or_sections
,
'axis'
:
0
,
'op_role'
:
op_role
}
with
paddle
.
static
.
program_guard
(
block
.
program
):
outs
=
[
helper
.
create_variable_for_type_inference
(
...
...
@@ -584,7 +648,7 @@ def _insert_split_op(block, idx, tensor, num_or_sections):
return
outs
def
_insert_allgather_op
(
block
,
idx
,
tensor
,
ranks
):
def
_insert_allgather_op
(
block
,
idx
,
tensor
,
ranks
,
op_role
):
"""Insert allgather op into block at the given index."""
def
_insert_fill_constant_op
(
block
,
idx
):
...
...
@@ -597,6 +661,7 @@ def _insert_allgather_op(block, idx, tensor, ranks):
attrs
[
'str_value'
]
=
str
(
int
(
"1"
))
attrs
[
'value'
]
=
int
(
"1"
)
attrs
[
'dtype'
]
=
out
.
dtype
attrs
[
'op_role'
]
=
op_role
utils
.
get_shape_tensor_inputs
(
inputs
=
inputs
,
attrs
=
attrs
,
shape
=
[
0
],
op_type
=
'fill_constant'
)
block
.
_insert_op
(
...
...
@@ -625,14 +690,16 @@ def _insert_allgather_op(block, idx, tensor, ranks):
inputs
=
{
'X'
:
[
fill_constant_out
]},
outputs
=
{
'Out'
:
[
fill_constant_out
]},
attrs
=
{
'ring_id'
:
0
,
'use_calc_stream'
:
True
})
'use_calc_stream'
:
True
,
'op_role'
:
op_role
})
# insert c_sync_calc_stream op
block
.
_insert_op
(
idx
+
2
,
type
=
"c_sync_calc_stream"
,
inputs
=
{
'X'
:
[
fill_constant_out
]},
outputs
=
{
'Out'
:
[
fill_constant_out
]})
outputs
=
{
'Out'
:
[
fill_constant_out
]},
attrs
=
{
'op_role'
:
op_role
})
idx_offset
=
3
# insert c_allgather op
...
...
@@ -649,20 +716,21 @@ def _insert_allgather_op(block, idx, tensor, ranks):
attrs
=
{
'ring_id'
:
group
.
id
,
'use_calc_stream'
:
True
,
'nranks'
:
group
.
nranks
'nranks'
:
group
.
nranks
,
'op_role'
:
op_role
})
idx_offset
+=
1
# insert split op
split_out
=
_insert_split_op
(
block
,
idx
+
idx_offset
,
allgather_out
,
group
.
nranks
)
group
.
nranks
,
op_role
)
idx_offset
+=
1
tensor_list
.
extend
(
split_out
)
return
tensor_list
,
idx_offset
def
_concat_partitions_with_op
(
partition_tensor_list
,
tensor
,
partition_index
,
block
,
idx
):
block
,
idx
,
op_role
):
"""Concat the tensors and insert concat op."""
if
not
partition_tensor_list
:
partition_tensor_list
.
append
((
tensor
,
partition_index
))
...
...
@@ -674,13 +742,13 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
partition_tensor_list
[
i
][
1
],
partition_index
)
if
concat_axis
!=
-
1
:
has_concat
=
True
_
=
_insert_concat_op
(
block
,
idx
[
0
],
[
partition_tensor_list
[
i
][
0
],
tensor
],
concat_axis
)
\
_
=
_insert_concat_op
(
block
,
idx
[
0
],
[
partition_tensor_list
[
i
][
0
],
tensor
],
concat_axis
,
op_role
)
\
if
first_order
==
0
else
\
_insert_concat_op
(
block
,
idx
[
0
],
[
tensor
,
partition_tensor_list
[
i
][
0
]],
concat_axis
)
_insert_concat_op
(
block
,
idx
[
0
],
[
tensor
,
partition_tensor_list
[
i
][
0
]],
concat_axis
,
op_role
)
partition_tensor_list
.
pop
(
i
)
idx
[
0
]
+=
1
_concat_partitions_with_op
(
partition_tensor_list
,
_
,
new_partition
,
block
,
idx
)
new_partition
,
block
,
idx
,
op_role
)
break
i
+=
1
if
not
has_concat
:
...
...
@@ -692,8 +760,47 @@ HAS_RECV = {}
HAS_ALLGATHER
=
{}
def
parse_op_desc
(
program
,
rank_id
,
op_desc_seq
,
var_name
,
reshard_op
,
dist_context
):
def _get_while_op_actual_process_mesh(op, program, rank_id, dist_context):
    """Get the process mesh that ``rank_id`` actually uses inside a while op.

    The ops of the while op's sub block are scanned in order. The first op
    whose process mesh differs from the while op's own process mesh and
    contains ``rank_id`` decides the result. If no such op exists, the while
    op's own process mesh is returned, provided it contains ``rank_id``.

    Args:
        op: The while operator; ``op.type`` must be ``"while"``.
        program: The program holding the block of the while op and its sub block.
        rank_id: The rank whose process mesh is looked up.
        dist_context: The distributed context used to fetch distributed operators.

    Returns:
        The process mesh containing ``rank_id``.

    Raises:
        AssertionError: If ``op`` is not a while op or no process mesh
            containing ``rank_id`` can be found.
    """
    assert op.type == "while"
    while_op_process_mesh = dist_context.get_dist_op_for_program(
        op).dist_attr.process_mesh
    sub_block = program.blocks[op.attr("sub_block").id]

    actual_process_mesh = None
    # Use a dedicated loop name so the `op` parameter is not shadowed.
    for sub_op in sub_block.ops:
        dist_op = dist_context.get_dist_op_for_program(sub_op)
        # Ops without a distributed operator carry no process mesh to inspect.
        if not dist_op:
            continue
        process_mesh = dist_op.dist_attr.process_mesh
        # The while op's own mesh is only the fallback, never a scan result.
        if process_mesh == while_op_process_mesh:
            continue
        if rank_id in process_mesh.processes:
            # Record the match in the variable that is returned; storing it
            # anywhere else silently discards the result of the scan.
            actual_process_mesh = process_mesh
            break

    if actual_process_mesh is None and rank_id in while_op_process_mesh.processes:
        actual_process_mesh = while_op_process_mesh

    assert actual_process_mesh is not None
    return actual_process_mesh
def _get_var(var_name, block, program):
    """Get a variable from ``block`` or, if absent, from its ancestor blocks.

    The lookup starts in ``block`` and then follows ``parent_idx`` upwards
    through ``program.blocks`` until the variable is found or a block
    without a parent is reached.

    Args:
        var_name: The name of the variable to look up.
        block: The block where the lookup starts.
        program: The program that owns ``block`` and its ancestors.

    Returns:
        The variable stored under ``var_name`` in the nearest block that has it.

    Raises:
        AssertionError: If neither ``block`` nor any ancestor holds ``var_name``.
    """
    var = None
    current_block = block
    while current_block is not None:
        if var_name in current_block.vars:
            var = current_block.vars[var_name]
            break
        parent_idx = current_block.parent_idx
        # A negative parent index means "no parent" (presumably -1 for the
        # global block - confirm against the Block API). Using it as a list
        # index would wrap around to the last block of the program and could
        # return an unrelated variable.
        current_block = program.blocks[parent_idx] if parent_idx >= 0 else None

    assert var is not None, \
        "The variable {} cannot be found in the block or its ancestors.".format(
            var_name)
    return var
def
parse_op_desc
(
block
,
rank_id
,
op_desc_seq
,
var_name
,
reshard_op
,
dist_context
,
program
,
actual_process_mesh
):
"""Parse op desc sequence and insert op in the block"""
global
HAS_SENT
global
HAS_RECV
...
...
@@ -703,9 +810,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
if
rank_id
not
in
op_desc_seq
.
keys
():
return
op_desc_list
=
op_desc_seq
[
rank_id
]
block
=
program
.
global_block
()
assert
var_name
in
block
.
vars
.
keys
(
),
"The {} cannot be found in the {} program."
.
format
(
var_name
,
rank_id
)
idx
=
None
for
index
,
op
in
list
(
enumerate
(
block
.
ops
)):
...
...
@@ -716,7 +820,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
rank_id
)
matched_op
=
block
.
ops
[
idx
]
source_tensor
=
block
.
vars
[
var_name
]
source_tensor
=
_get_var
(
var_name
,
block
,
program
)
for
op_desc
in
op_desc_list
:
if
isinstance
(
op_desc
,
AllGatherOpDesc
):
# noqa: F401
if
var_name
not
in
HAS_ALLGATHER
.
keys
():
...
...
@@ -724,7 +828,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
if
not
HAS_ALLGATHER
[
var_name
]
or
op_desc
.
group
not
in
list
(
map
(
lambda
x
:
x
[
0
],
HAS_ALLGATHER
[
var_name
])):
tensor_list
,
idx_offset
=
_insert_allgather_op
(
block
,
idx
,
source_tensor
,
op_desc
.
group
)
block
,
idx
,
source_tensor
,
op_desc
.
group
,
reshard_op
.
attr
(
'op_role'
))
idx
+=
idx_offset
tensor_name_list
=
[
var
.
name
for
var
in
tensor_list
]
HAS_ALLGATHER
[
var_name
].
append
(
...
...
@@ -743,7 +848,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
if
var_name
not
in
HAS_SENT
.
keys
():
HAS_SENT
[
var_name
]
=
[]
if
op_desc
.
dst
not
in
HAS_SENT
[
var_name
]:
_insert_send_op
(
block
,
idx
,
source_tensor
,
op_desc
.
dst
)
_insert_send_op
(
block
,
idx
,
source_tensor
,
op_desc
.
dst
,
reshard_op
.
attr
(
'op_role'
))
idx
+=
1
HAS_SENT
[
var_name
].
append
(
op_desc
.
dst
)
...
...
@@ -758,8 +864,10 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
recv_tensor
=
block
.
create_var
(
name
=
unique_name
.
generate
(
var_name
+
"@recv"
),
shape
=
shape
,
dtype
=
source_tensor
.
dtype
)
_insert_recv_op
(
block
,
idx
,
recv_tensor
,
op_desc
.
src
)
dtype
=
source_tensor
.
dtype
,
type
=
source_tensor
.
type
)
_insert_recv_op
(
block
,
idx
,
recv_tensor
,
op_desc
.
src
,
reshard_op
.
attr
(
'op_role'
))
tensor_list
.
append
(
recv_tensor
)
idx
+=
1
HAS_RECV
[
var_name
][
op_desc
.
src
]
=
recv_tensor
...
...
@@ -772,7 +880,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
for
index
,
tensor
in
enumerate
(
tensor_list
):
_concat_partitions_with_op
(
partition_tensor_list
,
tensor
,
partition_index_list
[
index
],
block
,
idx_list
)
idx_list
,
reshard_op
.
attr
(
'op_role'
)
)
idx
=
idx_list
[
0
]
elif
isinstance
(
op_desc
,
SliceOpDesc
):
...
...
@@ -787,11 +895,11 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
starts
=
op_desc
.
starts
,
ends
=
op_desc
.
ends
,
axes
=
op_desc
.
axes
,
new_var_name
=
new_name
)
new_var_name
=
new_name
,
op_role
=
reshard_op
.
attr
(
'op_role'
))
tensor_attr
=
TensorDistributedAttribute
()
process_mesh
=
dist_context
.
get_op_dist_attr_for_program
(
matched_op
).
process_mesh
process_mesh
=
actual_process_mesh
dims_mapping
=
dist_context
.
get_op_dist_attr_for_program
(
matched_op
).
get_input_dims_mapping
(
var_name
)
tensor_attr
.
dims_mapping
=
dims_mapping
...
...
@@ -799,11 +907,29 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
dist_context
.
set_tensor_dist_attr_for_program
(
target_tensor
,
tensor_attr
)
if
op
.
type
==
"while"
:
global
while_block_info
# var_reshard_mapping means the while op input need be changed to
if
"var_reshard_mapping"
not
in
while_block_info
[
op
.
attr
(
"sub_block"
).
id
].
keys
():
while_block_info
[
op
.
attr
(
"sub_block"
).
id
][
"var_reshard_mapping"
]
=
{}
while_block_info
[
op
.
attr
(
"sub_block"
).
id
][
"var_reshard_mapping"
][
var_name
]
=
target_tensor
.
name
# rename op input name according to new name
for
op
in
block
.
ops
:
for
name
in
op
.
input_arg_names
:
op_dist_attr
=
dist_context
.
get_op_dist_attr_for_program
(
op
)
if
name
==
var_name
and
op_dist_attr
is
not
None
:
if
op
.
desc
.
id
()
==
matched_op
.
desc
.
id
():
op
.
desc
.
_rename_input
(
name
,
target_tensor
.
name
)
op_dist_attr
.
set_input_dims_mapping
(
target_tensor
.
name
,
dims_mapping
)
op_dist_attr
.
set_input_dist_attr
(
name
,
None
)
continue
# NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation.
op_process_mesh
=
op_dist_attr
.
process_mesh
op_input_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
var_name
)
...
...
@@ -819,102 +945,166 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
not_remove_op_ref
=
[
"create_py_reader"
,
"create_double_buffer_reader"
,
"read"
]
remove_op_idx
=
[]
block
=
auto_parallel_main_prog
.
global_block
()
ops
=
block
.
ops
vars
=
block
.
vars
for
idx
,
op
in
enumerate
(
ops
):
# handle read op in the pipeline scene specially, it will be removed in the future.
if
op
.
type
==
"read"
:
dim_list
=
[]
for
var_name
in
op
.
output_arg_names
:
dim_list
.
extend
(
vars
[
var_name
].
shape
)
for
i
in
range
(
idx
,
-
1
,
-
1
):
if
ops
[
i
].
type
==
"create_py_reader"
:
ops
[
i
].
_set_attr
(
"shape_concat"
,
dim_list
)
break
continue
# replace the input and output of c_sync_comm_stream op when in pipeline scene.
if
op
.
type
==
"c_sync_comm_stream"
:
need_save
=
[]
for
var_name
in
op
.
input_arg_names
:
process_mesh
=
dist_context
.
get_tensor_dist_attr_for_program
(
vars
[
var_name
]).
process_mesh
if
rank_id
in
process_mesh
.
processes
:
need_save
.
append
(
var_name
)
if
not
need_save
:
remove_op_idx
.
append
(
idx
)
global
while_block_info
# NOTE: The nested sub block is not be supported now.
remove_block_order
=
[]
for
block_idx
in
while_block_info
:
remove_block_order
.
append
(
block_idx
)
for
block_idx
,
block
in
enumerate
(
auto_parallel_main_prog
.
blocks
):
if
block_idx
not
in
remove_block_order
:
remove_block_order
.
append
(
block_idx
)
# the sub block should be removed first
for
block_idx
in
remove_block_order
:
remove_op_idx
=
[]
block
=
auto_parallel_main_prog
.
blocks
[
block_idx
]
ops
=
block
.
ops
vars
=
block
.
vars
for
idx
,
op
in
enumerate
(
ops
):
if
op
.
type
==
"read"
:
dim_list
=
[]
for
var_name
in
op
.
output_arg_names
:
dim_list
.
extend
(
_get_var
(
var_name
,
block
,
auto_parallel_main_prog
)
.
shape
)
for
i
in
range
(
idx
,
-
1
,
-
1
):
if
ops
[
i
].
type
==
"create_py_reader"
:
ops
[
i
].
_set_attr
(
"shape_concat"
,
dim_list
)
break
continue
proto
=
OpProtoHolder
.
instance
().
get_op_proto
(
op
.
type
)
op
.
desc
.
set_input
(
proto
.
inputs
[
0
].
name
,
need_save
)
op
.
desc
.
set_output
(
proto
.
outputs
[
0
].
name
,
need_save
)
continue
# replace the input and output of c_sync_comm_stream op when in pipeline scene.
if
op
.
type
==
"c_sync_comm_stream"
:
need_save
=
[]
for
var_name
in
op
.
input_arg_names
:
process_mesh
=
dist_context
.
get_tensor_dist_attr_for_program
(
_get_var
(
var_name
,
block
,
auto_parallel_main_prog
)).
process_mesh
if
rank_id
in
process_mesh
.
processes
:
need_save
.
append
(
var_name
)
if
not
need_save
:
remove_op_idx
.
append
(
idx
)
continue
# judge the other op whether should be removed.
op_dist_attr
=
dist_context
.
get_op_dist_attr_for_program
(
op
)
if
op_dist_attr
is
not
None
:
op_process_mesh
=
op_dist_attr
.
process_mesh
if
rank_id
not
in
op_process_mesh
.
processes
and
op
.
type
not
in
not_remove_op_ref
:
remove_op_idx
.
append
(
idx
)
proto
=
OpProtoHolder
.
instance
().
get_op_proto
(
op
.
type
)
op
.
desc
.
set_input
(
proto
.
inputs
[
0
].
name
,
need_save
)
op
.
desc
.
set_output
(
proto
.
outputs
[
0
].
name
,
need_save
)
continue
for
idx
in
remove_op_idx
[::
-
1
]:
block
.
_remove_op
(
idx
)
# judge the other op whether should be removed.
op_dist_attr
=
dist_context
.
get_op_dist_attr_for_program
(
op
)
if
op_dist_attr
is
not
None
:
op_process_mesh
=
op_dist_attr
.
process_mesh
if
rank_id
not
in
op_process_mesh
.
processes
and
op
.
type
not
in
not_remove_op_ref
:
remove_op_idx
.
append
(
idx
)
for
idx
in
remove_op_idx
[::
-
1
]:
block
.
_remove_op
(
idx
)
def
_remove_no_need_vars
(
auto_parallel_main_prog
,
dist_params_grads
):
"""Remove no need vars in the main program"""
remove_vars
=
set
()
block
=
auto_parallel_main_prog
.
global_block
()
ops
=
block
.
ops
vars
=
block
.
vars
need_vars
=
set
()
for
op
in
ops
:
for
var_name
in
op
.
input_arg_names
:
if
var_name
in
vars
:
need_vars
.
add
(
var_name
)
for
var_name
in
op
.
output_arg_names
:
if
var_name
in
vars
:
need_vars
.
add
(
var_name
)
for
var
in
vars
:
if
var
not
in
need_vars
:
remove_vars
.
add
(
var
)
# change dist_params_grads
param_grad_map
=
{}
for
op
in
ops
:
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Optimize
):
if
"Param"
in
op
.
input_names
and
"Grad"
in
op
.
input_names
:
param_name
=
op
.
input
(
"Param"
)[
0
]
grad_name
=
op
.
input
(
"Grad"
)[
0
]
param_grad_map
[
param_name
]
=
grad_name
need_remove_idx
=
[]
for
idx
,
item
in
enumerate
(
dist_params_grads
):
if
item
[
0
].
name
not
in
param_grad_map
.
keys
():
need_remove_idx
.
append
(
idx
)
for
idx
in
need_remove_idx
[::
-
1
]:
dist_params_grads
.
pop
(
idx
)
idx
=
0
while
idx
<
len
(
dist_params_grads
):
param_name
=
dist_params_grads
[
idx
][
0
].
name
grad_name
=
dist_params_grads
[
idx
][
1
].
name
if
grad_name
!=
param_grad_map
[
param_name
]:
dist_params_grads
[
idx
]
=
(
vars
[
param_name
],
vars
[
param_grad_map
[
param_name
]])
idx
+=
1
for
block_idx
,
block
in
enumerate
(
auto_parallel_main_prog
.
blocks
):
remove_vars
=
set
()
ops
=
block
.
ops
vars
=
block
.
vars
need_vars
=
set
()
for
op
in
ops
:
for
var_name
in
op
.
input_arg_names
:
if
var_name
in
vars
:
need_vars
.
add
(
var_name
)
for
var_name
in
op
.
output_arg_names
:
if
var_name
in
vars
:
need_vars
.
add
(
var_name
)
for
var
in
vars
:
if
var
not
in
need_vars
:
remove_vars
.
add
(
var
)
# change dist_params_grads, the optimize op just in block 0.
if
block_idx
==
0
:
param_grad_map
=
{}
for
op
in
ops
:
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Optimize
):
if
"Param"
in
op
.
input_names
and
"Grad"
in
op
.
input_names
:
param_name
=
op
.
input
(
"Param"
)[
0
]
grad_name
=
op
.
input
(
"Grad"
)[
0
]
param_grad_map
[
param_name
]
=
grad_name
need_remove_idx
=
[]
for
idx
,
item
in
enumerate
(
dist_params_grads
):
if
item
[
0
].
name
not
in
param_grad_map
.
keys
():
need_remove_idx
.
append
(
idx
)
for
idx
in
need_remove_idx
[::
-
1
]:
dist_params_grads
.
pop
(
idx
)
idx
=
0
while
idx
<
len
(
dist_params_grads
):
param_name
=
dist_params_grads
[
idx
][
0
].
name
grad_name
=
dist_params_grads
[
idx
][
1
].
name
if
grad_name
!=
param_grad_map
[
param_name
]:
dist_params_grads
[
idx
]
=
(
vars
[
param_name
],
vars
[
param_grad_map
[
param_name
]])
idx
+=
1
for
var
in
remove_vars
:
block
.
_remove_var
(
var
)
for
var
in
remove_vars
:
block
.
_remove_var
(
var
)
def _change_while_op_input_and_output(auto_parallel_main_prog, dist_context):
    """Reset the "X" inputs and "Out" outputs of every recorded while op.

    For each sub block recorded in ``while_block_info`` the input and output
    names of the ops remaining in that sub block are collected, and the
    parent while op is updated so that it only refers to names still in use.

    Args:
        auto_parallel_main_prog: The main program whose while ops are updated.
        dist_context: The distributed context used to tell original ops from
            the ops inserted in the reshard phase.
    """
    for sub_block_idx, block_info in while_block_info.items():
        sub_block = auto_parallel_main_prog.blocks[sub_block_idx]
        target_op_id = block_info["op_id"]
        parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx]

        consumed_names = set()
        # A list keeps the first-seen order, which matters for the outputs.
        produced_names = []
        for sub_op in sub_block.ops:
            # ops inserted in the reshard phase have no dist op; ignore
            # their inputs and outputs
            if not dist_context.get_dist_op_for_program(sub_op):
                continue
            for name in sub_op.output_arg_names:
                if name not in produced_names:
                    produced_names.append(name)
            consumed_names.update(sub_op.input_arg_names)

        # locate the while op in the parent block that owns this sub block
        while_op = next(
            (candidate for candidate in parent_block.ops
             if candidate.desc.id() == target_op_id and
             candidate.type == "while"), None)
        assert while_op is not None

        proto = OpProtoHolder.instance().get_op_proto(while_op.type)

        # keep only the inputs still read by an op of the sub block
        new_X = [
            name for name in while_op.input("X") if name in consumed_names
        ]
        assert new_X
        while_op.desc.set_input(proto.inputs[0].name, new_X)

        # an output is kept under every sub block output name that contains
        # the original name, visiting the sub block outputs in reverse order
        new_Out = [
            produced
            for name in while_op.output("Out")
            for produced in reversed(produced_names) if name in produced
        ]
        assert new_Out
        while_op.desc.set_output(proto.outputs[0].name, new_Out)
def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
                           dist_params_grads):
    """Remove the ops and vars that are not needed in the main program.

    The three steps run in a fixed order: unneeded ops are removed first,
    then the while ops are adjusted to their remaining sub block ops, and
    finally the vars no longer referenced are removed.

    Args:
        auto_parallel_main_prog: The main program to clean up.
        dist_context: The distributed context of the program.
        rank_id: The rank the program is pruned for.
        dist_params_grads: The parameter/gradient pairs, passed on to the
            variable clean-up step.
    """
    main_prog = auto_parallel_main_prog

    # 1. drop the ops that are not needed
    _remove_no_need_ops(main_prog, dist_context, rank_id)
    # 2. sync while op inputs/outputs with what is left in their sub blocks
    _change_while_op_input_and_output(main_prog, dist_context)
    # 3. drop the vars that are not needed any more
    _remove_no_need_vars(main_prog, dist_params_grads)
...
...
@@ -992,8 +1182,70 @@ def remove_no_need_in_startup(auto_parallel_main_prog,
startup_block
.
_remove_op
(
idx
)
def
reshard
(
auto_parallel_main_prog
,
auto_parallel_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
):
def _get_process_meshes(op, program, dist_context):
    """Collect the process meshes used by the ops in the sub block of ``op``.

    Args:
        op: An operator that has a "sub_block" attribute.
        program: The program holding the sub block.
        dist_context: The distributed context used to fetch distributed operators.

    Returns:
        A list of the distinct process meshes of the sub block ops, in
        first-seen order and excluding the process mesh of ``op`` itself.
        If that leaves nothing, the list holds only the process mesh of ``op``.
    """
    assert op.has_attr("sub_block")
    sub_block = program.blocks[op.attr("sub_block").id]
    sub_ops = sub_block.ops
    owner_mesh = dist_context.get_dist_op_for_program(op).dist_attr.process_mesh

    collected = []
    for sub_op in sub_ops:
        sub_dist_op = dist_context.get_dist_op_for_program(sub_op)
        if sub_dist_op:
            candidate = sub_dist_op.dist_attr.process_mesh
            if candidate not in collected and candidate != owner_mesh:
                collected.append(candidate)

    # fall back to the mesh of the op itself when the sub block adds none
    if not collected:
        collected.append(owner_mesh)
    return collected
def _is_condition_replicative(op, program, dist_context):
    """Check whether every "Condition" input of a while op is replicated.

    Args:
        op: The while operator; ``op.type`` must be ``"while"``.
        program: The program holding the sub block of the while op.
        dist_context: The distributed context used to fetch the distributed
            tensors of the condition variables.

    Returns:
        True if every entry of the dims mapping of every condition tensor
        is -1, otherwise False.
    """
    assert op.type == "while"
    sub_block = program.blocks[op.attr("sub_block").id]
    # The distributed attribute of the while op is fetched but its value is
    # not used below.
    _ = dist_context.get_dist_op_for_program(op).dist_attr

    # the dims mapping of condition tensor should be replicative
    for cond_name in op.input("Condition"):
        cond_var = _get_var(cond_name, sub_block, program)
        cond_dist_tensor = dist_context.get_dist_tensor_for_program(cond_var)
        cond_dims_mapping = cond_dist_tensor.dist_attr.dims_mapping
        if any(dim != -1 for dim in cond_dims_mapping):
            return False

    return True
def _get_op_process_meshes(op, dist_context):
    """Get the process meshes of the dist context that belong to ``op``.

    A process mesh of ``dist_context.process_meshes`` is selected when it
    shares at least one process with the process mesh of ``op`` and does not
    have more processes than it.

    Args:
        op: The operator whose process meshes are requested.
        dist_context: The distributed context holding the distributed
            operator of ``op`` and the list of all process meshes.

    Returns:
        A list with the selected process meshes in their original order, or a
        list holding only the process mesh of ``op`` when none is selected.
    """
    op_process_mesh = dist_context.get_dist_op_for_program(
        op).dist_attr.process_mesh

    def _belongs_to_op(mesh):
        # shares a process with the op mesh and is not larger than it
        shared = set(mesh.processes) & set(op_process_mesh.processes)
        return bool(shared) and len(mesh.processes) <= len(
            op_process_mesh.processes)

    selected = [
        mesh for mesh in dist_context.process_meshes if _belongs_to_op(mesh)
    ]

    # it means the process mesh is not a union when process meshes is null
    if not selected:
        selected.append(op_process_mesh)

    return selected
def
reshard
(
auto_parallel_main_prog
,
auto_parallel_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
,
batch_size
=
None
):
"""
Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
...
...
@@ -1019,65 +1271,137 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
return
True
return
False
block
=
auto_parallel_main_prog
.
global_block
()
idx
=
0
while
idx
<
len
(
block
.
ops
):
pre_op_count
=
len
(
block
.
ops
)
op
=
block
.
ops
[
idx
]
global
while_block_info
for
block_idx
,
block
in
enumerate
(
auto_parallel_main_prog
.
blocks
):
if
block_idx
in
while_block_info
:
if
"var_reshard_mapping"
in
while_block_info
[
block_idx
]:
var_reshard_mapping
=
while_block_info
[
block_idx
][
"var_reshard_mapping"
]
for
op
in
block
.
ops
:
for
var_name
in
op
.
input_arg_names
:
if
var_name
in
var_reshard_mapping
:
op
.
desc
.
_rename_input
(
var_name
,
var_reshard_mapping
[
var_name
])
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
op_dist_attr
=
dist_op
.
dist_attr
if
op_dist_attr
.
process_mesh
==
while_block_info
[
block_idx
][
"actual_process_mesh"
]:
dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
var_name
)
op_dist_attr
.
set_input_dims_mapping
(
var_reshard_mapping
[
var_name
],
dims_mapping
)
op_dist_attr
.
set_input_dist_attr
(
var_name
,
None
)
# the outputs also need to be renamed when the output name is the same as the input name
for
var_name
in
op
.
output_arg_names
:
if
var_name
in
var_reshard_mapping
:
op
.
desc
.
_rename_output
(
var_name
,
var_reshard_mapping
[
var_name
])
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
op_dist_attr
=
dist_op
.
dist_attr
if
op_dist_attr
.
process_mesh
==
while_block_info
[
block_idx
][
"actual_process_mesh"
]:
dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
var_name
)
op_dist_attr
.
set_output_dims_mapping
(
var_reshard_mapping
[
var_name
],
dims_mapping
)
op_dist_attr
.
set_output_dist_attr
(
var_name
,
None
)
idx
=
0
while
idx
<
len
(
block
.
ops
):
pre_op_count
=
len
(
block
.
ops
)
op
=
block
.
ops
[
idx
]
if
_is_special_op
(
op
):
idx
+=
1
continue
if
_is_special_op
(
op
):
idx
+=
1
continue
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
dist_op
is
not
None
:
process_meshes
=
[]
if
op
.
type
==
"while"
:
if
not
_is_condition_replicative
(
op
,
auto_parallel_main_prog
,
dist_context
):
raise
ValueError
(
"Please check the condition due to the dims mapping is not replicative."
)
process_meshes
=
_get_process_meshes
(
op
,
auto_parallel_main_prog
,
dist_context
)
assert
process_meshes
if
op
.
attr
(
"sub_block"
).
id
not
in
while_block_info
:
while_block_info
[
op
.
attr
(
"sub_block"
).
id
]
=
{}
while_block_info
[
op
.
attr
(
"sub_block"
).
id
][
"op_id"
]
=
op
.
desc
.
id
()
while_block_info
[
op
.
attr
(
"sub_block"
).
id
][
"actual_process_mesh"
]
=
_get_while_op_actual_process_mesh
(
op
,
auto_parallel_main_prog
,
rank_id
,
dist_context
)
else
:
process_meshes
=
_get_op_process_meshes
(
op
,
dist_context
)
input_vars
=
None
if
op
.
type
==
"while"
:
input_var_names
=
op
.
input
(
"X"
)
else
:
input_var_names
=
op
.
input_arg_names
idx_offset
=
0
for
var_name
in
op
.
input_arg_names
:
# skip lod_tensor_blocking_queue_0
if
var_name
==
"lod_tensor_blocking_queue_0"
:
continue
var
=
_get_var
(
var_name
,
block
,
auto_parallel_main_prog
)
dist_tensor
=
dist_context
.
get_dist_tensor_for_program
(
var
)
for
process_mesh
in
process_meshes
:
if
dist_tensor
is
not
None
and
_need_reshard
(
dist_tensor
,
dist_op
,
process_mesh
,
auto_parallel_main_prog
,
dist_context
):
reshard_op_desc
=
find_op_desc_seq
(
dist_tensor
,
dist_op
,
process_mesh
,
batch_size
)
parse_op_desc
(
block
,
rank_id
,
reshard_op_desc
,
var_name
,
op
,
dist_context
,
auto_parallel_main_prog
,
process_mesh
)
cur_op_count
=
len
(
block
.
ops
)
idx_offset
=
idx_offset
+
cur_op_count
-
pre_op_count
pre_op_count
=
cur_op_count
idx
=
idx
+
idx_offset
+
1
else
:
idx
+=
1
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
dist_op
is
not
None
:
idx_offset
=
0
for
var_name
in
op
.
input_arg_names
:
# skip lod_tensor_blocking_queue_0
if
var_name
==
"lod_tensor_blocking_queue_0"
:
continue
var
=
block
.
vars
[
var_name
]
dist_tensor
=
dist_context
.
get_dist_tensor_for_program
(
var
)
if
dist_tensor
is
not
None
and
_need_reshard
(
dist_tensor
,
dist_op
):
reshard_op_desc
=
find_op_desc_seq
(
dist_tensor
,
dist_op
)
parse_op_desc
(
auto_parallel_main_prog
,
rank_id
,
reshard_op_desc
,
var_name
,
op
,
dist_context
)
cur_op_count
=
len
(
block
.
ops
)
idx_offset
=
idx_offset
+
cur_op_count
-
pre_op_count
pre_op_count
=
cur_op_count
idx
=
idx
+
idx_offset
+
1
else
:
idx
+=
1
# insert send and recv op if output process mesh is different from tensor process mesh
idx
=
0
skip_ops
=
[
"create_py_reader"
,
"create_double_buffer_reader"
,
"read"
]
skip_ops
+=
_g_special_ops
while
idx
<
len
(
block
.
ops
):
pre_op_count
=
len
(
block
.
ops
)
op
=
block
.
ops
[
idx
]
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
dist_op
is
not
None
and
op
.
type
not
in
skip_ops
:
for
var_name
in
op
.
output_arg_names
:
var
=
block
.
vars
[
var_name
]
dist_tensor
=
dist_context
.
get_dist_tensor_for_program
(
var
)
if
dist_tensor
is
not
None
and
_need_reshard
(
dist_tensor
,
dist_op
,
False
):
for
index
,
item
in
enumerate
(
dist_op
.
dist_attr
.
process_mesh
.
processes
):
recv_rank
=
dist_tensor
.
dist_attr
.
process_mesh
.
processes
[
index
]
if
rank_id
==
item
:
_insert_send_op
(
block
,
idx
+
1
,
var
,
recv_rank
)
if
rank_id
==
recv_rank
:
_insert_recv_op
(
block
,
idx
+
1
,
var
,
item
)
cur_op_count
=
len
(
block
.
ops
)
idx_offset
=
idx_offset
+
cur_op_count
-
pre_op_count
pre_op_count
=
cur_op_count
idx
=
idx
+
idx_offset
+
1
else
:
idx
+=
1
# insert send and recv op if output process mesh is different from tensor process mesh
idx
=
0
# skip reader and ops whose process mesh is union
skip_ops
=
[
"create_py_reader"
,
"create_double_buffer_reader"
,
"read"
,
"while"
,
"write_to_array"
,
"read_from_array"
]
skip_ops
+=
_g_special_ops
while
idx
<
len
(
block
.
ops
):
pre_op_count
=
len
(
block
.
ops
)
op
=
block
.
ops
[
idx
]
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
dist_op
is
not
None
and
op
.
type
not
in
skip_ops
:
for
var_name
in
op
.
output_arg_names
:
var
=
_get_var
(
var_name
,
block
,
auto_parallel_main_prog
)
dist_tensor
=
dist_context
.
get_dist_tensor_for_program
(
var
)
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
if
dist_tensor
is
not
None
and
_need_reshard
(
dist_tensor
,
dist_op
,
process_mesh
,
auto_parallel_main_prog
,
dist_context
,
False
):
for
index
,
item
in
enumerate
(
dist_op
.
dist_attr
.
process_mesh
.
processes
):
recv_rank
=
dist_tensor
.
dist_attr
.
process_mesh
.
processes
[
index
]
if
rank_id
==
item
:
_insert_send_op
(
block
,
idx
+
1
,
var
,
recv_rank
,
op
.
attr
(
'op_role'
))
if
rank_id
==
recv_rank
:
_insert_recv_op
(
block
,
idx
+
1
,
var
,
item
,
op
.
attr
(
'op_role'
))
cur_op_count
=
len
(
block
.
ops
)
idx_offset
=
idx_offset
+
cur_op_count
-
pre_op_count
pre_op_count
=
cur_op_count
idx
=
idx
+
idx_offset
+
1
else
:
idx
+=
1
# remove unneeded vars and ops from the main program
remove_no_need_in_main
(
auto_parallel_main_prog
,
dist_context
,
rank_id
,
...
...
python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py
浏览文件 @
2747de2b
...
...
@@ -32,6 +32,7 @@ from paddle.fluid.initializer import NumpyArrayInitializer
from
paddle.distributed.auto_parallel.utils
import
save_distributed_checkpoint
,
load_distributed_checkpoint
,
load_checkpoint_into_program
from
paddle.distributed.auto_parallel.utils
import
get_dist_attr
,
merge_and_slice_parameter
,
load_parameter_into_program
from
paddle.distributed.auto_parallel.reshard
import
HAS_SENT
,
HAS_RECV
,
HAS_ALLGATHER
from
paddle.distributed.auto_parallel.dist_context
import
set_default_distributed_context
paddle
.
enable_static
()
_global_parallel_strategy
=
None
...
...
@@ -185,6 +186,7 @@ class TestMLPAutoConvert(unittest.TestCase):
str
(
paddle
.
distributed
.
get_rank
())))
def
test_mlp_mp2pp
(
self
):
set_default_distributed_context
(
None
)
global
_global_parallel_strategy
_global_parallel_strategy
=
"mp"
global
_global_process_mesh
...
...
@@ -211,6 +213,7 @@ class TestMLPAutoConvert(unittest.TestCase):
fetch_list
=
[
loss
])
last_res
=
res
[
0
]
set_default_distributed_context
(
None
)
_global_parallel_strategy
=
"pp"
_global_process_mesh
=
auto
.
ProcessMesh
([
0
,
1
])
global
PP_MESH_0
...
...
@@ -266,6 +269,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
str
(
paddle
.
distributed
.
get_rank
())))
def
test_mlp_pp2mp
(
self
):
set_default_distributed_context
(
None
)
global
_global_parallel_strategy
_global_parallel_strategy
=
"pp"
global
_global_process_mesh
...
...
@@ -302,6 +306,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
if
paddle
.
distributed
.
get_rank
()
in
[
1
]:
last_res
=
res
[
0
]
set_default_distributed_context
(
None
)
_global_parallel_strategy
=
"mp"
_global_process_mesh
=
auto
.
ProcessMesh
([
0
,
1
])
...
...
@@ -345,6 +350,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase):
np
.
random
.
seed
(
2021
)
def
test_input_invalid
(
self
):
set_default_distributed_context
(
None
)
global
_global_parallel_strategy
_global_parallel_strategy
=
"mp"
global
_global_process_mesh
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录