Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9db507f1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9db507f1
编写于
11月 07, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
11月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] update naive data parallel completion (#47578)
* expand op donot use naive data parallel * fix unittest
上级
b0c38568
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
94 addition
and
129 deletion
+94
-129
python/paddle/distributed/auto_parallel/completion.py
python/paddle/distributed/auto_parallel/completion.py
+33
-108
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+27
-20
python/paddle/distributed/auto_parallel/planner_v2.py
python/paddle/distributed/auto_parallel/planner_v2.py
+3
-1
python/paddle/distributed/auto_parallel/utils.py
python/paddle/distributed/auto_parallel/utils.py
+31
-0
未找到文件。
python/paddle/distributed/auto_parallel/completion.py
浏览文件 @
9db507f1
...
...
@@ -13,10 +13,11 @@
# limitations under the License.
import
copy
import
time
import
logging
from
paddle.fluid
import
core
from
.utils
import
is_naive_data_parallel
,
get_logger
from
.utils
import
is_gradient_clip_op
,
__not_shape_var_type__
from
.operators
import
find_compatible_distributed_operator_impls
from
.dist_context
import
_node_id
...
...
@@ -142,6 +143,7 @@ class Completer:
assert
dist_context
is
not
None
self
.
_dist_context
=
dist_context
self
.
_has_prepared
=
False
self
.
_logger
=
get_logger
(
logging
.
INFO
,
"Completer"
)
def
_update_tensor_node_dims_mapping
(
self
,
tensor_node
,
fwd
=
True
):
changed
=
False
...
...
@@ -974,138 +976,60 @@ class Completer:
else
:
self
.
_dist_context
.
_serial_main_program
=
serial_main_program
start_time
=
time
.
time
()
# print("start time", start_time, flush=True)
if
not
self
.
_dist_context
.
data_parallel
:
if
not
is_naive_data_parallel
(
self
.
_dist_context
):
self
.
_dist_context
.
initialize
(
with_graph
=
True
)
# self._dist_context.validate_dist_attr_for_program()
self
.
_prepare
()
self
.
_update_process_mesh
()
self
.
_update_dims_mapping
()
# Copy the corresponding distributed attribute from graph to serial_main_program
self
.
_dist_context
.
copy_dist_attr_from_graph_to_program
()
else
:
self
.
_logger
.
info
(
"Default data parallel will be set."
)
self
.
_dist_context
.
initialize
(
with_graph
=
False
)
# A fast and special completion for data parallel
self
.
_update_dist_attr_for_dp
()
# print_program_with_dist_attr(self._dist_context.serial_main_program,
# self._dist_context)
# NOTE:[HighOrderGrad] update vars and ops distributed attribute in high order gradient
self
.
_complete_high_order_grad_annotation
(
serial_main_program
)
# Do the validation check and amend some completion
self
.
_dist_context
.
amend_dist_attr_for_program
()
self
.
_dist_context
.
validate_dist_attr_for_program
()
end_time
=
time
.
time
()
# print("end time", end_time, flush=True)
# print("elapsed time", end_time - start_time, flush=True)
return
serial_main_program
def
_update_dist_attr_for_dp
(
self
):
# TODO: we must ensure the world process group contains all ranks
ranks
=
get_world_process_group
().
ranks
process_mesh
=
ProcessMesh
(
ranks
)
for
(
dist_tensor
)
in
self
.
_dist_context
.
_dist_tensors_for_program
.
values
():
serial_tensor
=
dist_tensor
.
serial_tensor
tensor_dist_attr
=
dist_tensor
.
dist_attr
tensor_dist_attr
.
process_mesh
=
process_mesh
for
dist_op
in
self
.
_dist_context
.
_dist_ops_for_program
.
values
():
dist_tensors
=
self
.
_dist_context
.
_dist_tensors_for_program
for
dist_tensor
in
dist_tensors
.
values
():
dist_tensor
.
dist_attr
.
process_mesh
=
process_mesh
dist_ops
=
self
.
_dist_context
.
_dist_ops_for_program
for
dist_op
in
dist_ops
.
values
():
serial_op
=
dist_op
.
serial_op
op_desc
=
serial_op
.
desc
op_dist_attr
=
dist_op
.
dist_attr
op_dist_attr
.
process_mesh
=
process_mesh
original_op_dist_attr
=
copy
.
deepcopy
(
op_dist_attr
)
input_xshape_arg_names
=
[]
if
"XShape"
in
op_desc
.
input_names
():
input_xshape_arg_names
=
op_desc
.
input
(
"XShape"
)
for
arg_name
in
serial_op
.
input_arg_names
:
serial_tensor
=
dist_op
.
get_serial_input
(
arg_name
)
if
not
serial_tensor
.
is_parameter
:
if
arg_name
not
in
input_xshape_arg_names
:
old_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
arg_name
dist_tensor
=
(
self
.
_dist_context
.
get_dist_tensor_for_program
(
serial_tensor
)
if
len
(
old_dims_mapping
)
>
0
:
new_dims_mapping
=
[
0
]
+
[
-
1
for
_
in
range
(
len
(
old_dims_mapping
)
-
1
)
]
op_dist_attr
.
set_input_dims_mapping
(
arg_name
,
new_dims_mapping
)
else
:
old_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
arg_name
)
if
len
(
old_dims_mapping
)
>
1
:
new_dims_mapping
=
[
-
1
,
0
]
+
[
-
1
for
_
in
range
(
len
(
old_dims_mapping
)
-
2
)
]
op_dist_attr
.
set_input_dims_mapping
(
arg_name
,
new_dims_mapping
)
# Set tensor's dims_mapping by the op's
tensor_dist_attr
=
(
self
.
_dist_context
.
get_tensor_dist_attr_for_program
(
serial_tensor
)
)
tensor_dist_attr
.
dims_mapping
=
(
op_dist_attr
.
get_input_dims_mapping
(
arg_name
)
)
output_xshape_arg_names
=
[]
if
"XShape"
in
op_desc
.
output_names
():
output_xshape_arg_names
=
op_desc
.
output
(
"XShape"
)
for
arg_name
in
serial_op
.
output_arg_names
:
serial_tensor
=
dist_op
.
get_serial_output
(
arg_name
)
if
not
serial_tensor
.
is_parameter
:
if
arg_name
not
in
output_xshape_arg_names
:
old_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
arg_name
)
if
len
(
old_dims_mapping
)
>
0
:
new_dims_mapping
=
[
0
]
+
[
-
1
for
_
in
range
(
len
(
old_dims_mapping
)
-
1
)
]
op_dist_attr
.
set_output_dims_mapping
(
arg_name
,
new_dims_mapping
)
else
:
old_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
arg_name
)
if
len
(
old_dims_mapping
)
>
1
:
new_dims_mapping
=
[
-
1
,
0
]
+
[
-
1
for
_
in
range
(
len
(
old_dims_mapping
)
-
2
)
]
op_dist_attr
.
set_output_dims_mapping
(
arg_name
,
new_dims_mapping
)
# Set tensor's dims_mapping by the op's
tensor_dist_attr
=
(
self
.
_dist_context
.
get_tensor_dist_attr_for_program
(
serial_tensor
op_dist_attr
=
dist_op
.
dist_attr
op_dist_attr
.
process_mesh
=
(
dist_tensor
.
dist_attr
.
process_mesh
)
op_dist_attr
.
set_input_dims_mapping
(
arg_name
,
dist_tensor
.
dist_attr
.
dims_mapping
)
)
tensor_dist_attr
.
dims_mapping
=
(
op_dist_attr
.
get_output_dims_mapping
(
arg_name
)
)
op_dist_impls
=
find_compatible_distributed_operator_impls
(
dist_op
,
partial
=
Fals
e
dist_op
,
fwd
=
Tru
e
)
if
op_dist_impls
is
not
None
:
not_compatible
=
True
...
...
@@ -1127,6 +1051,16 @@ class Completer:
else
:
dist_op
.
dist_attr
=
original_op_dist_attr
for
arg_name
in
serial_op
.
output_arg_names
:
op_dist_attr
=
dist_op
.
dist_attr
serial_tensor
=
dist_op
.
get_serial_output
(
arg_name
)
dist_tensor
=
self
.
_dist_context
.
get_dist_tensor_for_program
(
serial_tensor
)
dist_tensor
.
dist_attr
.
dims_mapping
=
(
op_dist_attr
.
get_output_dims_mapping
(
arg_name
)
)
def
_complete_tensor_dist_attr_by_op
(
self
,
serial_main_program
=
None
):
if
serial_main_program
is
None
:
serial_main_program
=
self
.
_dist_context
.
serial_main_program
...
...
@@ -1942,19 +1876,10 @@ class Completer:
else
:
self
.
_dist_context
.
_serial_main_program
=
serial_main_program
import
time
start_time
=
time
.
time
()
self
.
_dist_context
.
_is_initialized
=
True
start_time
=
time
.
time
()
self
.
_dist_context
.
_init_dist_attr_for_program
()
start_time
=
time
.
time
()
self
.
_init_global_mesh_for_program
()
# Do the validation check and amend some completion
start_time
=
time
.
time
()
self
.
_dist_context
.
amend_dist_attr_for_program
()
self
.
_dist_context
.
validate_dist_attr_for_program
()
...
...
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
9db507f1
...
...
@@ -22,6 +22,7 @@ from collections import defaultdict
import
paddle
import
paddle.utils
as
utils
import
paddle.distributed.auto_parallel.utils
as
auto_utils
from
paddle
import
fluid
,
static
from
paddle.metric
import
Metric
...
...
@@ -47,12 +48,10 @@ from .dist_loader import (
DistributedDataLoaderFromGenerator
,
DistributedDataLoader
,
)
from
.strategy
import
Strategy
from
.process_group
import
new_process_group
,
get_all_process_groups
from
.dist_context
import
DistributedContext
,
get_default_distributed_context
from
.strategy
import
Strategy
from
.interface
import
CollectionNames
,
get_collection
from
.utils
import
to_list
,
get_dist_attr
,
get_lr
,
validate_opt
from
.utils
import
initialize_pg_in_full_mode
,
get_input_split_info
from
.cost.estimate_cost
import
get_cost_from_engine
from
..utils.log_utils
import
get_logger
...
...
@@ -159,18 +158,18 @@ class Engine:
"'optimizer' must be object of class `paddle.optimizer.Optimizer`"
" or `paddle.fluid.optimizer.Optimizer`."
)
self
.
_optimizer
=
validate_opt
(
optimizer
)
self
.
_optimizer
=
auto_utils
.
validate_opt
(
optimizer
)
self
.
_orig_optimizer
=
copy
.
deepcopy
(
self
.
_optimizer
)
metrics
=
metrics
or
[]
for
metric
in
to_list
(
metrics
):
for
metric
in
auto_utils
.
to_list
(
metrics
):
if
metric
and
not
isinstance
(
metric
,
Metric
):
raise
TypeError
(
"{} is not sub class of Metric"
.
format
(
metric
.
__class__
.
__name__
)
)
self
.
_metrics
=
to_list
(
metrics
)
self
.
_metrics
=
auto_utils
.
to_list
(
metrics
)
if
cluster
and
not
isinstance
(
cluster
,
Cluster
):
raise
TypeError
(
...
...
@@ -253,8 +252,8 @@ class Engine:
type
(
data
).
__name__
)
)
inputs
=
to_list
(
inputs
)
labels
=
to_list
(
labels
)
inputs
=
auto_utils
.
to_list
(
inputs
)
labels
=
auto_utils
.
to_list
(
labels
)
num_shards
=
self
.
_strategy
.
dataset
.
num_shards
...
...
@@ -481,7 +480,7 @@ class Engine:
if
metric_out
:
metric
.
update
(
*
metric_out
)
results
=
metric
.
accumulate
()
for
i
,
res
in
enumerate
(
to_list
(
results
)):
for
i
,
res
in
enumerate
(
auto_utils
.
to_list
(
results
)):
logs
[
metric
.
name
()[
i
]]
=
res
group_idx
+=
1
# logging outputs
...
...
@@ -562,7 +561,7 @@ class Engine:
s
.
_create_feed_layer
()
for
s
in
self
.
_labels_spec
]
outputs
=
to_list
(
self
.
_model
(
*
self
.
_inputs
))
outputs
=
auto_utils
.
to_list
(
self
.
_model
(
*
self
.
_inputs
))
if
mode
!=
"predict"
and
self
.
_loss
:
assert
isinstance
(
...
...
@@ -570,14 +569,14 @@ class Engine:
)
or
callable
(
self
.
_loss
),
"the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function."
self
.
_losses
=
to_list
(
self
.
_losses
=
auto_utils
.
to_list
(
self
.
_loss
(
*
(
outputs
+
self
.
_labels
))
)
if
mode
!=
"predict"
and
(
outputs
or
self
.
_labels
):
for
metric
in
self
.
_metrics
:
metrics
.
append
(
to_list
(
auto_utils
.
to_list
(
metric
.
compute
(
*
(
outputs
+
self
.
_labels
))
)
)
...
...
@@ -585,7 +584,7 @@ class Engine:
assert
isinstance
(
self
.
_loss
,
Variable
),
"the type of `loss` of the Engine arguments should be Variable."
self
.
_losses
=
to_list
(
self
.
_loss
)
self
.
_losses
=
auto_utils
.
to_list
(
self
.
_loss
)
default_ctx
=
get_default_distributed_context
()
if
not
default_ctx
.
has_annotation
:
...
...
@@ -593,6 +592,12 @@ class Engine:
# needs all ranks by default.
new_process_group
(
list
(
range
(
self
.
_nranks
)))
default_ctx
.
data_parallel
=
True
self
.
_inputs
=
[
auto_utils
.
set_data_parallel
(
var
)
for
var
in
self
.
_inputs
]
self
.
_labels
=
[
auto_utils
.
set_data_parallel
(
var
)
for
var
in
self
.
_labels
]
feed_vars
=
{
"inputs"
:
self
.
_inputs
,
"labels"
:
self
.
_labels
}
...
...
@@ -684,7 +689,7 @@ class Engine:
self
.
_dp_world_sizes
=
[]
self
.
_dp_ranks
=
[]
for
feed_var
in
feed_list
:
dp_world_size
,
dp_rank
=
get_input_split_info
(
dp_world_size
,
dp_rank
=
auto_utils
.
get_input_split_info
(
self
.
_cur_rank
,
feed_var
,
self
.
_dist_contexts
[
mode
]
)
self
.
_dp_world_sizes
.
append
(
dp_world_size
)
...
...
@@ -749,7 +754,9 @@ class Engine:
cur_rank
=
self
.
_cur_rank
# NOTE: After the implementation of the unified dynamic and static communication group initialization mode in the future, the initialization logic of full mode will be removed because port occupation error may occur.
if
self
.
_strategy
.
auto_mode
==
"full"
:
initialize_pg_in_full_mode
(
all_process_groups
,
cur_rank
)
auto_utils
.
initialize_pg_in_full_mode
(
all_process_groups
,
cur_rank
)
else
:
for
process_group
in
all_process_groups
:
if
cur_rank
not
in
process_group
.
ranks
:
...
...
@@ -927,7 +934,7 @@ class Engine:
)
except
core
.
EOFException
:
break
lr
=
get_lr
(
self
.
_optimizer
)
lr
=
auto_utils
.
get_lr
(
self
.
_optimizer
)
logs
=
self
.
_prepare_logger
(
outs
,
epoch
,
...
...
@@ -1474,7 +1481,7 @@ class Engine:
self
.
_optimization_tuning
(
self
.
_mode
,
tune_data
,
batch_size
)
def
_validate_spec
(
self
,
specs
):
specs
=
to_list
(
specs
)
specs
=
auto_utils
.
to_list
(
specs
)
self
.
_k_steps
=
self
.
_strategy
.
gradient_merge
.
k_steps
if
specs
is
not
None
:
for
i
,
spec
in
enumerate
(
specs
):
...
...
@@ -1500,7 +1507,7 @@ class Engine:
return
specs
or
[]
def
_validate_vars
(
self
,
vars
):
vars
=
to_list
(
vars
)
vars
=
auto_utils
.
to_list
(
vars
)
if
vars
is
not
None
:
for
i
,
var
in
enumerate
(
vars
):
if
not
isinstance
(
var
,
Variable
):
...
...
@@ -1547,7 +1554,7 @@ class Engine:
def
_metrics_name
(
self
):
metrics_name
=
[
'loss'
]
if
self
.
_loss
else
[]
for
m
in
self
.
_metrics
:
metrics_name
.
extend
(
to_list
(
m
.
name
()))
metrics_name
.
extend
(
auto_utils
.
to_list
(
m
.
name
()))
return
metrics_name
def
_switch_mode
(
self
,
mode
):
...
...
@@ -1568,7 +1575,7 @@ class Engine:
def
_set_state_dict
(
self
,
mode
,
strict
,
state_dict
,
dist_attr
):
program
=
self
.
_dist_main_progs
[
mode
][
self
.
_cur_rank
]
dist_context
=
self
.
_dist_contexts
[
mode
]
cur_dist_attr
=
get_dist_attr
(
program
,
dist_context
)
cur_dist_attr
=
auto_utils
.
get_dist_attr
(
program
,
dist_context
)
converter
=
Converter
(
state_dict
,
dist_attr
,
cur_dist_attr
)
state_dict
=
converter
.
convert
(
strict
=
strict
)
program
.
set_state_dict
(
state_dict
)
...
...
python/paddle/distributed/auto_parallel/planner_v2.py
浏览文件 @
9db507f1
...
...
@@ -15,6 +15,7 @@
from
.completion
import
Completer
from
.dist_context
import
get_default_distributed_context
from
.tuner.parallel_tuner
import
ParallelTuner
from
.utils
import
is_naive_data_parallel
class
Planner
:
...
...
@@ -26,7 +27,8 @@ class Planner:
# dependency of backward-forward ops in forward completion.
default_ctx
=
get_default_distributed_context
()
self
.
_dist_context
.
_dist_op_context
=
default_ctx
.
dist_op_context
if
not
default_ctx
.
data_parallel
:
self
.
_dist_context
.
data_parallel
=
default_ctx
.
data_parallel
if
not
is_naive_data_parallel
(
self
.
_dist_context
):
# Use SSA graph for complex parallism
self
.
_dist_context
.
initialize
(
with_graph
=
True
)
else
:
...
...
python/paddle/distributed/auto_parallel/utils.py
浏览文件 @
9db507f1
...
...
@@ -37,6 +37,8 @@ __not_shape_var_type__ = [
core
.
VarDesc
.
VarType
.
STEP_SCOPES
,
]
__not_naive_data_parallel_op__
=
[
"expand_v2"
]
def
get_logger
(
log_level
,
name
=
"auto_parallel"
):
logger
=
logging
.
getLogger
(
name
)
...
...
@@ -1909,6 +1911,35 @@ def validate_opt(optimizer):
return
optimizer
def
set_data_parallel
(
x
):
from
.process_group
import
get_world_process_group
from
.interface
import
shard_tensor
,
ProcessMesh
world_ranks
=
get_world_process_group
().
ranks
process_mesh
=
ProcessMesh
(
world_ranks
,
[
'dp'
])
shard_spec
=
[
'dp'
if
len
(
world_ranks
)
>
1
else
None
]
+
[
None
for
_
in
range
(
len
(
x
.
shape
)
-
1
)
]
return
shard_tensor
(
x
,
process_mesh
,
shard_spec
)
def
is_naive_data_parallel
(
dist_context
):
# Navie data parallel only completes dist_attr once from the front to back.
if
not
dist_context
.
data_parallel
:
return
False
ops_type
=
[
op
.
type
for
op
in
dist_context
.
_original_serial_main_program
.
global_block
().
ops
]
if
(
not
set
(
ops_type
)
&
set
(
__not_naive_data_parallel_op__
)
)
and
dist_context
.
data_parallel
:
return
True
return
False
def
_copy_tensor_dist_attr_to_cpp
(
cpp_dist_attr
,
py_dist_attr
):
py_process_mesh
=
py_dist_attr
.
process_mesh
if
py_process_mesh
is
not
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录