Commit 2747de2b (unverified)
Authored on Mar 10, 2022 by caozhou; committed via GitHub on Mar 10, 2022
Parent: 575dea8f

[Auto Parallel]Update reshard for while sub block (#40366)

* update reshard for while sub block
* fix code format error
Changes: 2 changed files, +511 additions and -181 deletions

python/paddle/distributed/auto_parallel/reshard.py  (+505, -181)
python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py  (+6, -0)
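Before reading the diff itself, it helps to know the shape of the bookkeeping this patch introduces: reshard.py gains a module-level `while_block_info` dict keyed by the while op's sub-block id. The sketch below is illustrative only; the ids and variable names are made up, and only the keys ("op_id", "actual_process_mesh", "var_reshard_mapping") come from the diff that follows.

# Illustrative only: structure of the new module-level registry in reshard.py.
mesh = None  # placeholder for the ProcessMesh actually used by the current rank

while_block_info = {
    1: {                                  # key: id of the while op's sub_block
        "op_id": 42,                      # desc id of the parent while op (made-up value)
        "actual_process_mesh": mesh,
        "var_reshard_mapping": {
            "tmp_0": "tmp_0@RESHARD",     # original input name -> resharded replacement (hypothetical names)
        },
    },
}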
python/paddle/distributed/auto_parallel/reshard.py

@@ -29,6 +29,7 @@ from .process_group import new_process_group, ProcessGroup, _g_process_group_map
 # NOTE: If op in _g_special_ops, it will not be resharded.
 _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
+while_block_info = {}
 
 
 class AllGatherOpDesc:
@@ -280,8 +281,20 @@ def _is_overlapped(shape_x, shape_y):
     return overlapped
 
 
-def _need_reshard(dist_tensor, dist_op, op_input=True):
+def _need_reshard(dist_tensor,
+                  dist_op,
+                  actual_process_mesh,
+                  program,
+                  dist_context,
+                  op_input=True):
     """Judge the tensor whether needs to be resharded."""
+
+    def _is_unshard(dims_mapping):
+        for dim in dims_mapping:
+            if dim != -1:
+                return False
+        return True
+
     is_reshard = False
     tensor_dist_attr = dist_tensor.dist_attr
     tensor_name = dist_tensor.serial_tensor.name
@@ -289,32 +302,74 @@ def _need_reshard(dist_tensor, dist_op, op_input=True):
     tensor_process_mesh = tensor_dist_attr.process_mesh
     op_dist_attr = dist_op.dist_attr
     op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
-    op_process_mesh = op_dist_attr.process_mesh
+    op_process_mesh = actual_process_mesh
     if op_input:
         op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
-        op_process_mesh = op_dist_attr.process_mesh
         if all(
                 map(lambda x: x is not None, [
                     tensor_dims_mapping, tensor_process_mesh,
                     op_input_dims_mapping, op_process_mesh
                 ])):
-            if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
-                is_reshard = True
+            # dims_mapping
+            if tensor_dims_mapping != op_input_dims_mapping:
+                if dist_op.serial_op.type == "while":
+                    sub_block = program.blocks[dist_op.serial_op.attr(
+                        "sub_block").id]
+                    for op in sub_block.ops:
+                        for var_name in op.input_arg_names:
+                            if var_name == tensor_name:
+                                dist_op_attr = dist_context.get_dist_op_for_program(
+                                    op).dist_attr
+                                var_dims_mapping = dist_op_attr.get_input_dims_mapping(
+                                    var_name)
+                                if var_dims_mapping != tensor_dims_mapping:
+                                    is_reshard = True
+                                    break
+                else:
+                    is_reshard = True
+            # process_mesh
+            if tensor_process_mesh != op_process_mesh:
+                # when processes length is not the same, the dims mapping must be replicative now
+                if len(tensor_process_mesh.processes) != len(
+                        op_process_mesh.processes):
+                    assert _is_unshard(tensor_dims_mapping)
+                    assert _is_unshard(op_input_dims_mapping)
+                else:
+                    if dist_tensor.serial_tensor.dtype == paddle.bool:
+                        raise ValueError("Bool var is not supported reshard.")
+                    # for while op, it should find the process mesh of op actually used the tensor as input
+                    if dist_op.serial_op.type == "while":
+                        sub_block = program.blocks[dist_op.serial_op.attr(
+                            "sub_block").id]
+                        for op in sub_block.ops:
+                            for var_name in op.input_arg_names:
+                                if var_name == tensor_name:
+                                    dist_op_attr = dist_context.get_dist_op_for_program(
+                                        op).dist_attr
+                                    process_mesh = dist_op_attr.process_mesh
+                                    if process_mesh == op_process_mesh:
+                                        is_reshard = True
+                                        break
+                    else:
+                        is_reshard = True
     else:
         op_output_dims_mapping = op_dist_attr.get_output_dims_mapping(
             tensor_name)
-        op_process_mesh = op_dist_attr.process_mesh
         if all(
                 map(lambda x: x is not None, [
                     tensor_dims_mapping, tensor_process_mesh,
                     op_output_dims_mapping, op_process_mesh
                 ])):
             if tensor_process_mesh != op_process_mesh:
+                if dist_tensor.serial_tensor.dtype == paddle.bool:
+                    raise ValueError("Bool var is not supported reshard.")
                 is_reshard = True
             if tensor_dims_mapping != op_output_dims_mapping:
                 raise ValueError(
                     "It is not supported that tensor dims mapping is different from op output dims mapping."
                 )
 
     return is_reshard
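Aside: the new `_is_unshard` helper above (and the `_is_condition_replicative` check added later in this patch) both treat a tensor as replicated when no axis is mapped onto a process-mesh dimension. A minimal, self-contained illustration of that convention follows; it is a pure-Python sketch, not part of the patch.

# dims_mapping convention in auto-parallel: entry i names the process-mesh axis
# that shards tensor axis i, and -1 means "not sharded". A tensor is fully
# replicated exactly when every entry is -1, which is what _is_unshard checks.
def is_unshard(dims_mapping):
    return all(dim == -1 for dim in dims_mapping)

print(is_unshard([-1, -1]))  # True: replicated, e.g. a while-loop Condition tensor
print(is_unshard([0, -1]))   # False: tensor axis 0 is split over mesh axis 0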
@@ -329,13 +384,14 @@ def _compute_complete_shape(slice_shape, process_shape, dims_mapping):
     return complete_shape
 
 
-def find_op_desc_seq(dist_tensor, dist_op):
+def find_op_desc_seq(dist_tensor, dist_op, actual_process_mesh, batch_size):
     """
     Find the op description sequence to reshard the source tensor for matching the op requirement.
 
     Args:
         dist_tensor (DistributedTensor): A distributed tensor.
         dist_op (DistributedOperator): A distributed operator.
+        actual_process_mesh (ProcessMesh): The actual op process mesh.
 
     Returns:
         Dict, the dict represents the required op description sequence corresponding to process, The key of dict is
@@ -350,11 +406,16 @@ def find_op_desc_seq(dist_tensor, dist_op):
     source_process_shape = source_process_mesh.topology
 
     op_dist_attr = dist_op.dist_attr
-    target_process_mesh = op_dist_attr.process_mesh
+    target_process_mesh = actual_process_mesh
     target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
     target_process_group = target_process_mesh.processes
     target_process_shape = target_process_mesh.topology
 
+    if source_tensor.shape[0] < 0:
+        new_shape = list(source_tensor.shape)
+        new_shape[0] = batch_size
+        source_tensor.desc.set_shape(new_shape)
+
     complete_shape = _compute_complete_shape(
         source_tensor.shape, source_process_shape, source_dims_mapping)
     op_desc_seq = {}
@@ -503,7 +564,7 @@ def find_op_desc_seq(dist_tensor, dist_op):
     return op_desc_seq
 
 
-def _insert_send_op(block, idx, tensor, dst):
+def _insert_send_op(block, idx, tensor, dst, op_role):
     """Insert send op into block at the given index."""
     op_type = 'send_v2'
     block._insert_op(
@@ -514,10 +575,11 @@ def _insert_send_op(block, idx, tensor, dst):
             'ring_id': 0,
             'peer': dst,
             'use_calc_stream': True,
+            'op_role': op_role
         })
 
 
-def _insert_recv_op(block, idx, tensor, src):
+def _insert_recv_op(block, idx, tensor, src, op_role):
     """Insert recv op into block at the given index."""
     op_type = 'recv_v2'
     block._insert_op(
@@ -531,14 +593,16 @@ def _insert_recv_op(block, idx, tensor, src):
             'out_shape': tensor.shape,
             'dtype': tensor.dtype,
             'use_calc_stream': True,
+            'op_role': op_role
         })
 
 
-def _insert_concat_op(block, idx, tensors, axis):
+def _insert_concat_op(block, idx, tensors, axis, op_role):
     """Insert concat op into block at the given block."""
     inputs = {'X': tensors}
     attrs = {}
     attrs['axis'] = axis
+    attrs['op_role'] = op_role
     helper = LayerHelper('concat', **locals())
     with paddle.static.program_guard(block.program):
         out = helper.create_variable_for_type_inference(
@@ -548,7 +612,8 @@ def _insert_concat_op(block, idx, tensors, axis):
     return out
 
 
-def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name):
+def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name,
+                     op_role):
     """Insert slice op into block at the given block."""
     inputs = {'Input': tensor}
     infer_flags = list(1 for i in range(len(axes)))
@@ -556,24 +621,23 @@ def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name):
         "axes": axes,
         "starts": starts,
         "ends": ends,
-        "infer_flags": infer_flags
+        "infer_flags": infer_flags,
+        'op_role': op_role
     }
     helper = LayerHelper('slice', **locals())
     out = block.create_var(
         name=new_var_name,
-        dtype=tensor.dtype, type=tensor.type)
+        dtype=tensor.dtype, type=core.VarDesc.VarType.LOD_TENSOR)
     block._insert_op(
         idx, type="slice", inputs=inputs, outputs={'Out': [out]}, attrs=attrs)
     return out
 
 
-def _insert_split_op(block, idx, tensor, num_or_sections):
+def _insert_split_op(block, idx, tensor, num_or_sections, op_role):
     """Insert split op into block at the given index."""
     helper = LayerHelper('split', **locals())
     input_shape = tensor.shape
     inputs = {'X': tensor}
-    attrs = {'num': num_or_sections, "axis": 0}
+    attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role}
     with paddle.static.program_guard(block.program):
         outs = [
             helper.create_variable_for_type_inference(
@@ -584,7 +648,7 @@ def _insert_split_op(block, idx, tensor, num_or_sections):
     return outs
 
 
-def _insert_allgather_op(block, idx, tensor, ranks):
+def _insert_allgather_op(block, idx, tensor, ranks, op_role):
     """Insert allgather op into block at the given index."""
 
     def _insert_fill_constant_op(block, idx):
@@ -597,6 +661,7 @@ def _insert_allgather_op(block, idx, tensor, ranks):
         attrs['str_value'] = str(int("1"))
         attrs['value'] = int("1")
         attrs['dtype'] = out.dtype
+        attrs['op_role'] = op_role
         utils.get_shape_tensor_inputs(
             inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant')
         block._insert_op(
@@ -625,14 +690,16 @@ def _insert_allgather_op(block, idx, tensor, ranks):
         inputs={'X': [fill_constant_out]},
         outputs={'Out': [fill_constant_out]},
         attrs={'ring_id': 0,
-               'use_calc_stream': True})
+               'use_calc_stream': True,
+               'op_role': op_role})
     # insert c_sync_calc_stream op
     block._insert_op(
         idx + 2,
         type="c_sync_calc_stream",
         inputs={'X': [fill_constant_out]},
-        outputs={'Out': [fill_constant_out]})
+        outputs={'Out': [fill_constant_out]},
+        attrs={'op_role': op_role})
     idx_offset = 3
 
     # insert c_allgather op
@@ -649,20 +716,21 @@ def _insert_allgather_op(block, idx, tensor, ranks):
         attrs={
             'ring_id': group.id,
             'use_calc_stream': True,
-            'nranks': group.nranks
+            'nranks': group.nranks,
+            'op_role': op_role
         })
     idx_offset += 1
 
     # insert split op
     split_out = _insert_split_op(block, idx + idx_offset, allgather_out,
-                                 group.nranks)
+                                 group.nranks, op_role)
     idx_offset += 1
     tensor_list.extend(split_out)
     return tensor_list, idx_offset
 
 
 def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
-                               block, idx):
+                               block, idx, op_role):
     """Concat the tensors and insert concat op."""
     if not partition_tensor_list:
         partition_tensor_list.append((tensor, partition_index))
@@ -674,13 +742,13 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
                     partition_tensor_list[i][1], partition_index)
                 if concat_axis != -1:
                     has_concat = True
-                    _ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis) \
+                    _ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role) \
                         if first_order == 0 else \
-                        _insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis)
+                        _insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role)
                     partition_tensor_list.pop(i)
                     idx[0] += 1
                     _concat_partitions_with_op(partition_tensor_list, _,
-                                               new_partition, block, idx)
+                                               new_partition, block, idx, op_role)
                     break
                 i += 1
     if not has_concat:
@@ -692,8 +760,47 @@ HAS_RECV = {}
 HAS_ALLGATHER = {}
 
 
-def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
-                  dist_context):
+def _get_while_op_actual_process_mesh(op, program, rank_id, dist_context):
+    """Get the while op actual Process mesh corresponding to rank"""
+    assert op.type == "while"
+    while_op_process_mesh = dist_context.get_dist_op_for_program(
+        op).dist_attr.process_mesh
+    sub_block = program.blocks[op.attr("sub_block").id]
+    ops = sub_block.ops
+    actual_process_mesh = None
+    for op in ops:
+        dist_op = dist_context.get_dist_op_for_program(op)
+        if not dist_op:
+            continue
+        process_mesh = dist_op.dist_attr.process_mesh
+        if process_mesh == while_op_process_mesh:
+            continue
+        if rank_id in process_mesh.processes:
+            raw_process_mesh = process_mesh
+            break
+
+    if actual_process_mesh is None and rank_id in while_op_process_mesh.processes:
+        actual_process_mesh = while_op_process_mesh
+
+    assert actual_process_mesh is not None
+    return actual_process_mesh
+
+
+def _get_var(var_name, block, program):
+    """Get var in the parent block if not found in the current block"""
+    var = None
+    if var_name in block.vars:
+        var = block.vars[var_name]
+    else:
+        parent_block = program.blocks[block.parent_idx]
+        if var_name in parent_block.vars:
+            var = parent_block.vars[var_name]
+    assert var is not None
+    return var
+
+
+def parse_op_desc(block, rank_id, op_desc_seq, var_name, reshard_op,
+                  dist_context, program, actual_process_mesh):
     """Parse op desc sequence and insert op in the block"""
     global HAS_SENT
     global HAS_RECV
@@ -703,9 +810,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
     if rank_id not in op_desc_seq.keys():
         return
     op_desc_list = op_desc_seq[rank_id]
-    block = program.global_block()
-    assert var_name in block.vars.keys(
-    ), "The {} cannot be found in the {} program.".format(var_name, rank_id)
 
     idx = None
     for index, op in list(enumerate(block.ops)):
@@ -716,7 +820,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
             rank_id)
 
     matched_op = block.ops[idx]
-    source_tensor = block.vars[var_name]
+    source_tensor = _get_var(var_name, block, program)
     for op_desc in op_desc_list:
         if isinstance(op_desc, AllGatherOpDesc):  # noqa: F401
             if var_name not in HAS_ALLGATHER.keys():
@@ -724,7 +828,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
             if not HAS_ALLGATHER[var_name] or op_desc.group not in list(
                     map(lambda x: x[0], HAS_ALLGATHER[var_name])):
                 tensor_list, idx_offset = _insert_allgather_op(
-                    block, idx, source_tensor, op_desc.group)
+                    block, idx, source_tensor, op_desc.group,
+                    reshard_op.attr('op_role'))
                 idx += idx_offset
                 tensor_name_list = [var.name for var in tensor_list]
                 HAS_ALLGATHER[var_name].append(
@@ -743,7 +848,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
             if var_name not in HAS_SENT.keys():
                 HAS_SENT[var_name] = []
             if op_desc.dst not in HAS_SENT[var_name]:
-                _insert_send_op(block, idx, source_tensor, op_desc.dst)
+                _insert_send_op(block, idx, source_tensor, op_desc.dst,
+                                reshard_op.attr('op_role'))
                 idx += 1
                 HAS_SENT[var_name].append(op_desc.dst)
@@ -758,8 +864,10 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
                 recv_tensor = block.create_var(
                     name=unique_name.generate(var_name + "@recv"),
                     shape=shape,
-                    dtype=source_tensor.dtype)
-                _insert_recv_op(block, idx, recv_tensor, op_desc.src)
+                    dtype=source_tensor.dtype,
+                    type=source_tensor.type)
+                _insert_recv_op(block, idx, recv_tensor, op_desc.src,
+                                reshard_op.attr('op_role'))
                 tensor_list.append(recv_tensor)
                 idx += 1
                 HAS_RECV[var_name][op_desc.src] = recv_tensor
@@ -772,7 +880,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
             for index, tensor in enumerate(tensor_list):
                 _concat_partitions_with_op(partition_tensor_list, tensor,
                                            partition_index_list[index], block,
-                                           idx_list)
+                                           idx_list, reshard_op.attr('op_role'))
             idx = idx_list[0]
 
         elif isinstance(op_desc, SliceOpDesc):
@@ -787,11 +895,11 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
                 starts=op_desc.starts,
                 ends=op_desc.ends,
                 axes=op_desc.axes,
-                new_var_name=new_name)
+                new_var_name=new_name,
+                op_role=reshard_op.attr('op_role'))
 
             tensor_attr = TensorDistributedAttribute()
-            process_mesh = dist_context.get_op_dist_attr_for_program(
-                matched_op).process_mesh
+            process_mesh = actual_process_mesh
             dims_mapping = dist_context.get_op_dist_attr_for_program(
                 matched_op).get_input_dims_mapping(var_name)
             tensor_attr.dims_mapping = dims_mapping
@@ -799,11 +907,29 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
             dist_context.set_tensor_dist_attr_for_program(target_tensor,
                                                           tensor_attr)
 
+            if op.type == "while":
+                global while_block_info
+                # var_reshard_mapping means the while op input need be changed to
+                if "var_reshard_mapping" not in while_block_info[
+                        op.attr("sub_block").id].keys():
+                    while_block_info[op.attr("sub_block").id][
+                        "var_reshard_mapping"] = {}
+                while_block_info[op.attr("sub_block").id][
+                    "var_reshard_mapping"][var_name] = target_tensor.name
+
             # rename op input name according to new name
             for op in block.ops:
                 for name in op.input_arg_names:
                     op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                     if name == var_name and op_dist_attr is not None:
+                        if op.desc.id() == matched_op.desc.id():
+                            op.desc._rename_input(name, target_tensor.name)
+                            op_dist_attr.set_input_dims_mapping(
+                                target_tensor.name, dims_mapping)
+                            op_dist_attr.set_input_dist_attr(name, None)
+                            continue
+
+                        # NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation.
                         op_process_mesh = op_dist_attr.process_mesh
                         op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
                             var_name)
@@ -819,102 +945,166 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
     not_remove_op_ref = [
         "create_py_reader", "create_double_buffer_reader", "read"
     ]
-    remove_op_idx = []
-    block = auto_parallel_main_prog.global_block()
-    ops = block.ops
-    vars = block.vars
-    for idx, op in enumerate(ops):
-        # handle read op in the pipeline scene specially, it will be removed in the future.
-        if op.type == "read":
-            dim_list = []
-            for var_name in op.output_arg_names:
-                dim_list.extend(vars[var_name].shape)
-            for i in range(idx, -1, -1):
-                if ops[i].type == "create_py_reader":
-                    ops[i]._set_attr("shape_concat", dim_list)
-                    break
-            continue
-
-        # replace the input and output of c_sync_comm_stream op when in pipeline scene.
-        if op.type == "c_sync_comm_stream":
-            need_save = []
-            for var_name in op.input_arg_names:
-                process_mesh = dist_context.get_tensor_dist_attr_for_program(
-                    vars[var_name]).process_mesh
-                if rank_id in process_mesh.processes:
-                    need_save.append(var_name)
-            if not need_save:
-                remove_op_idx.append(idx)
-                continue
-
-            proto = OpProtoHolder.instance().get_op_proto(op.type)
-            op.desc.set_input(proto.inputs[0].name, need_save)
-            op.desc.set_output(proto.outputs[0].name, need_save)
-            continue
-
-        # judge the other op whether should be removed.
-        op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
-        if op_dist_attr is not None:
-            op_process_mesh = op_dist_attr.process_mesh
-            if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref:
-                remove_op_idx.append(idx)
-
-    for idx in remove_op_idx[::-1]:
-        block._remove_op(idx)
+    global while_block_info
+
+    # NOTE: The nested sub block is not be supported now.
+    remove_block_order = []
+    for block_idx in while_block_info:
+        remove_block_order.append(block_idx)
+
+    for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
+        if block_idx not in remove_block_order:
+            remove_block_order.append(block_idx)
+
+    # the sub block should be removed first
+    for block_idx in remove_block_order:
+        remove_op_idx = []
+        block = auto_parallel_main_prog.blocks[block_idx]
+        ops = block.ops
+        vars = block.vars
+        for idx, op in enumerate(ops):
+            if op.type == "read":
+                dim_list = []
+                for var_name in op.output_arg_names:
+                    dim_list.extend(
+                        _get_var(var_name, block,
+                                 auto_parallel_main_prog).shape)
+                for i in range(idx, -1, -1):
+                    if ops[i].type == "create_py_reader":
+                        ops[i]._set_attr("shape_concat", dim_list)
+                        break
+                continue
+
+            # replace the input and output of c_sync_comm_stream op when in pipeline scene.
+            if op.type == "c_sync_comm_stream":
+                need_save = []
+                for var_name in op.input_arg_names:
+                    process_mesh = dist_context.get_tensor_dist_attr_for_program(
+                        _get_var(var_name, block,
+                                 auto_parallel_main_prog)).process_mesh
+                    if rank_id in process_mesh.processes:
+                        need_save.append(var_name)
+                if not need_save:
+                    remove_op_idx.append(idx)
+                    continue
+
+                proto = OpProtoHolder.instance().get_op_proto(op.type)
+                op.desc.set_input(proto.inputs[0].name, need_save)
+                op.desc.set_output(proto.outputs[0].name, need_save)
+                continue
+
+            # judge the other op whether should be removed.
+            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
+            if op_dist_attr is not None:
+                op_process_mesh = op_dist_attr.process_mesh
+                if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref:
+                    remove_op_idx.append(idx)
+
+        for idx in remove_op_idx[::-1]:
+            block._remove_op(idx)
 
 
 def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
     """Remove no need vars in the main program"""
-    remove_vars = set()
-    block = auto_parallel_main_prog.global_block()
-    ops = block.ops
-    vars = block.vars
-    need_vars = set()
-    for op in ops:
-        for var_name in op.input_arg_names:
-            if var_name in vars:
-                need_vars.add(var_name)
-        for var_name in op.output_arg_names:
-            if var_name in vars:
-                need_vars.add(var_name)
-    for var in vars:
-        if var not in need_vars:
-            remove_vars.add(var)
-
-    # change dist_params_grads
-    param_grad_map = {}
-    for op in ops:
-        if int(op.attr('op_role')) == int(OpRole.Optimize):
-            if "Param" in op.input_names and "Grad" in op.input_names:
-                param_name = op.input("Param")[0]
-                grad_name = op.input("Grad")[0]
-                param_grad_map[param_name] = grad_name
-
-    need_remove_idx = []
-    for idx, item in enumerate(dist_params_grads):
-        if item[0].name not in param_grad_map.keys():
-            need_remove_idx.append(idx)
-
-    for idx in need_remove_idx[::-1]:
-        dist_params_grads.pop(idx)
-
-    idx = 0
-    while idx < len(dist_params_grads):
-        param_name = dist_params_grads[idx][0].name
-        grad_name = dist_params_grads[idx][1].name
-        if grad_name != param_grad_map[param_name]:
-            dist_params_grads[idx] = (vars[param_name],
-                                      vars[param_grad_map[param_name]])
-        idx += 1
-
-    for var in remove_vars:
-        block._remove_var(var)
+    for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
+        remove_vars = set()
+        ops = block.ops
+        vars = block.vars
+        need_vars = set()
+        for op in ops:
+            for var_name in op.input_arg_names:
+                if var_name in vars:
+                    need_vars.add(var_name)
+            for var_name in op.output_arg_names:
+                if var_name in vars:
+                    need_vars.add(var_name)
+        for var in vars:
+            if var not in need_vars:
+                remove_vars.add(var)
+
+        # change dist_params_grads, the optimize op just in block 0.
+        if block_idx == 0:
+            param_grad_map = {}
+            for op in ops:
+                if int(op.attr('op_role')) == int(OpRole.Optimize):
+                    if "Param" in op.input_names and "Grad" in op.input_names:
+                        param_name = op.input("Param")[0]
+                        grad_name = op.input("Grad")[0]
+                        param_grad_map[param_name] = grad_name
+
+            need_remove_idx = []
+            for idx, item in enumerate(dist_params_grads):
+                if item[0].name not in param_grad_map.keys():
+                    need_remove_idx.append(idx)
+
+            for idx in need_remove_idx[::-1]:
+                dist_params_grads.pop(idx)
+
+            idx = 0
+            while idx < len(dist_params_grads):
+                param_name = dist_params_grads[idx][0].name
+                grad_name = dist_params_grads[idx][1].name
+                if grad_name != param_grad_map[param_name]:
+                    dist_params_grads[idx] = (
+                        vars[param_name], vars[param_grad_map[param_name]])
+                idx += 1
+
+        for var in remove_vars:
+            block._remove_var(var)
+
+
+def _change_while_op_input_and_output(auto_parallel_main_prog, dist_context):
+    """Change while op input and output after the corresponding sub block ops removed"""
+    global while_block_info
+    for sub_block_idx in while_block_info:
+        sub_block = auto_parallel_main_prog.blocks[sub_block_idx]
+        parent_while_op_id = while_block_info[sub_block_idx]["op_id"]
+        parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx]
+
+        sub_block_op_inputs = set()
+        sub_block_op_outputs = []
+        for op in sub_block.ops:
+            # skip the input and output of operators inserted in the reshard phase
+            dist_op = dist_context.get_dist_op_for_program(op)
+            if dist_op:
+                for var_name in op.output_arg_names:
+                    if var_name not in sub_block_op_outputs:
+                        sub_block_op_outputs.append(var_name)
+                for var_name in op.input_arg_names:
+                    sub_block_op_inputs.add(var_name)
+
+        # find the while op
+        while_op = None
+        for op in parent_block.ops:
+            if op.desc.id() == parent_while_op_id and op.type == "while":
+                while_op = op
+                break
+
+        assert while_op is not None
+
+        # find the actual input and output of while op
+        proto = OpProtoHolder.instance().get_op_proto(while_op.type)
+        new_X = []
+        for var_name in while_op.input("X"):
+            if var_name in sub_block_op_inputs:
+                new_X.append(var_name)
+        assert new_X
+        while_op.desc.set_input(proto.inputs[0].name, new_X)
+
+        new_Out = []
+        for var_name in while_op.output("Out"):
+            for output_name in sub_block_op_outputs[::-1]:
+                if output_name.find(var_name) != -1:
+                    new_Out.append(output_name)
+        assert new_Out
+        while_op.desc.set_output(proto.outputs[0].name, new_Out)
 
 
 def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
                            dist_params_grads):
     """Remove no need vars and ops in the main program."""
     _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id)
+    _change_while_op_input_and_output(auto_parallel_main_prog, dist_context)
     _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
@@ -992,8 +1182,70 @@ def remove_no_need_in_startup(auto_parallel_main_prog,
         startup_block._remove_op(idx)
 
 
-def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
-            dist_context, dist_params_grads):
+def _get_process_meshes(op, program, dist_context):
+    """Get all process meshes when op has sub block."""
+    assert op.has_attr("sub_block")
+    sub_block = program.blocks[op.attr("sub_block").id]
+    ops = sub_block.ops
+    op_process_mesh = dist_context.get_dist_op_for_program(
+        op).dist_attr.process_mesh
+    process_meshes = []
+    for op in ops:
+        dist_op = dist_context.get_dist_op_for_program(op)
+        if not dist_op:
+            continue
+        process_mesh = dist_op.dist_attr.process_mesh
+        if process_mesh not in process_meshes and process_mesh != op_process_mesh:
+            process_meshes.append(process_mesh)
+
+    if not process_meshes:
+        process_meshes.append(op_process_mesh)
+
+    return process_meshes
+
+
+def _is_condition_replicative(op, program, dist_context):
+    assert op.type == "while"
+    sub_block = program.blocks[op.attr("sub_block").id]
+    dist_op = dist_context.get_dist_op_for_program(op)
+    op_dist_attr = dist_op.dist_attr
+
+    # the dims mapping of condition tensor should be replicative
+    for var_name in op.input("Condition"):
+        var = _get_var(var_name, sub_block, program)
+        dist_tensor = dist_context.get_dist_tensor_for_program(var)
+        tensor_dist_attr = dist_tensor.dist_attr
+        var_dims_mapping = tensor_dist_attr.dims_mapping
+        for dim in var_dims_mapping:
+            if dim != -1:
+                return False
+
+    return True
+
+
+def _get_op_process_meshes(op, dist_context):
+    process_meshes = []
+    dist_op = dist_context.get_dist_op_for_program(op)
+    op_process_mesh = dist_op.dist_attr.process_mesh
+    for process_mesh in dist_context.process_meshes:
+        if set(process_mesh.processes) & (
+                set(op_process_mesh.processes)
+        ) and len(process_mesh.processes) <= len(op_process_mesh.processes):
+            process_meshes.append(process_mesh)
+
+    # it means the process mesh is not a union when process meshes is null
+    if not process_meshes:
+        process_meshes.append(op_process_mesh)
+
+    return process_meshes
+
+
+def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
+            dist_context, dist_params_grads, batch_size=None):
     """
     Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute.
@@ -1019,65 +1271,137 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
             return True
         return False
 
-    block = auto_parallel_main_prog.global_block()
-    idx = 0
-    while idx < len(block.ops):
-        pre_op_count = len(block.ops)
-        op = block.ops[idx]
-        if _is_special_op(op):
-            idx += 1
-            continue
-
-        dist_op = dist_context.get_dist_op_for_program(op)
-        if dist_op is not None:
-            idx_offset = 0
-            for var_name in op.input_arg_names:
-                # skip lod_tensor_blocking_queue_0
-                if var_name == "lod_tensor_blocking_queue_0":
-                    continue
-                var = block.vars[var_name]
-                dist_tensor = dist_context.get_dist_tensor_for_program(var)
-                if dist_tensor is not None and _need_reshard(dist_tensor,
-                                                             dist_op):
-                    reshard_op_desc = find_op_desc_seq(dist_tensor, dist_op)
-                    parse_op_desc(auto_parallel_main_prog, rank_id,
-                                  reshard_op_desc, var_name, op, dist_context)
-                    cur_op_count = len(block.ops)
-                    idx_offset = idx_offset + cur_op_count - pre_op_count
-                    pre_op_count = cur_op_count
-            idx = idx + idx_offset + 1
-        else:
-            idx += 1
-
-    # insert send and recv op if output process mesh is different from tensor process mesh
-    idx = 0
-    skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
-    skip_ops += _g_special_ops
-    while idx < len(block.ops):
-        pre_op_count = len(block.ops)
-        op = block.ops[idx]
-        dist_op = dist_context.get_dist_op_for_program(op)
-        if dist_op is not None and op.type not in skip_ops:
-            for var_name in op.output_arg_names:
-                var = block.vars[var_name]
-                dist_tensor = dist_context.get_dist_tensor_for_program(var)
-                if dist_tensor is not None and _need_reshard(dist_tensor,
-                                                             dist_op, False):
-                    for index, item in enumerate(
-                            dist_op.dist_attr.process_mesh.processes):
-                        recv_rank = dist_tensor.dist_attr.process_mesh.processes[
-                            index]
-                        if rank_id == item:
-                            _insert_send_op(block, idx + 1, var, recv_rank)
-                        if rank_id == recv_rank:
-                            _insert_recv_op(block, idx + 1, var, item)
-                    cur_op_count = len(block.ops)
-                    idx_offset = idx_offset + cur_op_count - pre_op_count
-                    pre_op_count = cur_op_count
-            idx = idx + idx_offset + 1
-        else:
-            idx += 1
+    global while_block_info
+    for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
+        if block_idx in while_block_info:
+            if "var_reshard_mapping" in while_block_info[block_idx]:
+                var_reshard_mapping = while_block_info[block_idx][
+                    "var_reshard_mapping"]
+                for op in block.ops:
+                    for var_name in op.input_arg_names:
+                        if var_name in var_reshard_mapping:
+                            op.desc._rename_input(
+                                var_name, var_reshard_mapping[var_name])
+                            dist_op = dist_context.get_dist_op_for_program(op)
+                            op_dist_attr = dist_op.dist_attr
+                            if op_dist_attr.process_mesh == while_block_info[
+                                    block_idx]["actual_process_mesh"]:
+                                dims_mapping = op_dist_attr.get_input_dims_mapping(
+                                    var_name)
+                                op_dist_attr.set_input_dims_mapping(
+                                    var_reshard_mapping[var_name], dims_mapping)
+                                op_dist_attr.set_input_dist_attr(var_name, None)
+
+                    # the outputs also need to be renamed when the output name is the same with input name
+                    for var_name in op.output_arg_names:
+                        if var_name in var_reshard_mapping:
+                            op.desc._rename_output(
+                                var_name, var_reshard_mapping[var_name])
+                            dist_op = dist_context.get_dist_op_for_program(op)
+                            op_dist_attr = dist_op.dist_attr
+                            if op_dist_attr.process_mesh == while_block_info[
+                                    block_idx]["actual_process_mesh"]:
+                                dims_mapping = op_dist_attr.get_output_dims_mapping(
+                                    var_name)
+                                op_dist_attr.set_output_dims_mapping(
+                                    var_reshard_mapping[var_name], dims_mapping)
+                                op_dist_attr.set_output_dist_attr(var_name,
+                                                                  None)
+
+        idx = 0
+        while idx < len(block.ops):
+            pre_op_count = len(block.ops)
+            op = block.ops[idx]
+            if _is_special_op(op):
+                idx += 1
+                continue
+
+            dist_op = dist_context.get_dist_op_for_program(op)
+            if dist_op is not None:
+                process_meshes = []
+                if op.type == "while":
+                    if not _is_condition_replicative(
+                            op, auto_parallel_main_prog, dist_context):
+                        raise ValueError(
+                            "Please check the condition due to the dims mapping is not replicative."
+                        )
+                    process_meshes = _get_process_meshes(
+                        op, auto_parallel_main_prog, dist_context)
+                    assert process_meshes
+                    if op.attr("sub_block").id not in while_block_info:
+                        while_block_info[op.attr("sub_block").id] = {}
+                    while_block_info[op.attr("sub_block")
+                                     .id]["op_id"] = op.desc.id()
+                    while_block_info[op.attr("sub_block").id][
+                        "actual_process_mesh"] = _get_while_op_actual_process_mesh(
+                            op, auto_parallel_main_prog, rank_id, dist_context)
+                else:
+                    process_meshes = _get_op_process_meshes(op, dist_context)
+                input_vars = None
+                if op.type == "while":
+                    input_var_names = op.input("X")
+                else:
+                    input_var_names = op.input_arg_names
+                idx_offset = 0
+                for var_name in op.input_arg_names:
+                    # skip lod_tensor_blocking_queue_0
+                    if var_name == "lod_tensor_blocking_queue_0":
+                        continue
+                    var = _get_var(var_name, block, auto_parallel_main_prog)
+                    dist_tensor = dist_context.get_dist_tensor_for_program(var)
+                    for process_mesh in process_meshes:
+                        if dist_tensor is not None and _need_reshard(
+                                dist_tensor, dist_op, process_mesh,
+                                auto_parallel_main_prog, dist_context):
+                            reshard_op_desc = find_op_desc_seq(
+                                dist_tensor, dist_op, process_mesh, batch_size)
+                            parse_op_desc(block, rank_id, reshard_op_desc,
+                                          var_name, op, dist_context,
+                                          auto_parallel_main_prog, process_mesh)
+                            cur_op_count = len(block.ops)
+                            idx_offset = idx_offset + cur_op_count - pre_op_count
+                            pre_op_count = cur_op_count
+                idx = idx + idx_offset + 1
+            else:
+                idx += 1
+
+        # insert send and recv op if output process mesh is different from tensor process mesh
+        idx = 0
+        # skip reader and ops whose process mesh is union
+        skip_ops = [
+            "create_py_reader", "create_double_buffer_reader", "read", "while",
+            "write_to_array", "read_from_array"
+        ]
+        skip_ops += _g_special_ops
+        while idx < len(block.ops):
+            pre_op_count = len(block.ops)
+            op = block.ops[idx]
+            dist_op = dist_context.get_dist_op_for_program(op)
+            if dist_op is not None and op.type not in skip_ops:
+                for var_name in op.output_arg_names:
+                    cur_op_count = len(block.ops)
+                    var = _get_var(var_name, block, auto_parallel_main_prog)
+                    dist_tensor = dist_context.get_dist_tensor_for_program(var)
+                    process_mesh = dist_op.dist_attr.process_mesh
+                    if dist_tensor is not None and _need_reshard(
+                            dist_tensor, dist_op, process_mesh,
+                            auto_parallel_main_prog, dist_context, False):
+                        for index, item in enumerate(
+                                dist_op.dist_attr.process_mesh.processes):
+                            recv_rank = dist_tensor.dist_attr.process_mesh.processes[
+                                index]
+                            if rank_id == item:
+                                _insert_send_op(block, idx + 1, var, recv_rank,
+                                                op.attr('op_role'))
+                            if rank_id == recv_rank:
+                                _insert_recv_op(block, idx + 1, var, item,
+                                                op.attr('op_role'))
+                        cur_op_count = len(block.ops)
+                        idx_offset = idx_offset + cur_op_count - pre_op_count
+                        pre_op_count = cur_op_count
+                idx = idx + idx_offset + 1
+            else:
+                idx += 1
 
     # remove no need vars and ops in the main program
     remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id,
...
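Given the new `reshard(...)` signature above (now taking `dist_params_grads` and an optional `batch_size`), here is a hedged sketch of how a caller might wire it up. `main_prog`, `startup_prog`, `dist_context`, and `params_grads` are placeholders for objects produced by the usual auto-parallel partitioning flow; they are not defined by this patch.

# Thin wrapper only; all arguments come from the caller's auto-parallel setup.
from paddle.distributed.auto_parallel.reshard import reshard

def apply_reshard(main_prog, startup_prog, rank_id, dist_context, params_grads,
                  batch_size=None):
    # batch_size is the new optional argument: find_op_desc_seq uses it to fix a
    # source tensor whose leading dimension is still -1 (data-dependent shape).
    reshard(main_prog, startup_prog, rank_id, dist_context, params_grads,
            batch_size=batch_size)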
python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py

@@ -32,6 +32,7 @@ from paddle.fluid.initializer import NumpyArrayInitializer
 from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program
 from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program
 from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER
+from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
 
 paddle.enable_static()
 _global_parallel_strategy = None
@@ -185,6 +186,7 @@ class TestMLPAutoConvert(unittest.TestCase):
                 str(paddle.distributed.get_rank())))
 
     def test_mlp_mp2pp(self):
+        set_default_distributed_context(None)
         global _global_parallel_strategy
         _global_parallel_strategy = "mp"
         global _global_process_mesh
@@ -211,6 +213,7 @@ class TestMLPAutoConvert(unittest.TestCase):
                           fetch_list=[loss])
         last_res = res[0]
 
+        set_default_distributed_context(None)
         _global_parallel_strategy = "pp"
         _global_process_mesh = auto.ProcessMesh([0, 1])
         global PP_MESH_0
@@ -266,6 +269,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
                 str(paddle.distributed.get_rank())))
 
     def test_mlp_pp2mp(self):
+        set_default_distributed_context(None)
         global _global_parallel_strategy
         _global_parallel_strategy = "pp"
         global _global_process_mesh
@@ -302,6 +306,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
         if paddle.distributed.get_rank() in [1]:
             last_res = res[0]
 
+        set_default_distributed_context(None)
         _global_parallel_strategy = "mp"
         _global_process_mesh = auto.ProcessMesh([0, 1])
@@ -345,6 +350,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase):
         np.random.seed(2021)
 
     def test_input_invalid(self):
+        set_default_distributed_context(None)
         global _global_parallel_strategy
         _global_parallel_strategy = "mp"
         global _global_process_mesh
...
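The test-side change is a single pattern repeated per test: reset the default distributed context before building the next strategy's program. Below is a hypothetical mini-test that mirrors only that reset pattern; the assumption (suggested by the import added above, not stated in the diff) is that passing None clears the global default so a fresh DistributedContext is created on next use.

# Hypothetical mini-test; only the reset calls mirror this patch.
import unittest
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context

class TestContextReset(unittest.TestCase):
    def test_reset_between_strategies(self):
        set_default_distributed_context(None)  # start the "mp" run from a clean default context
        # ... build and convert the "mp" program here ...
        set_default_distributed_context(None)  # reset again before the "pp" program
        # ... build and convert the "pp" program here ...

if __name__ == "__main__":
    unittest.main()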