Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
229befc8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
229befc8
编写于
8月 23, 2022
作者:
J
JZ-LIANG
提交者:
GitHub
8月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Data Parallel Comm & Calc Overlap Optimization (#45173)
* bugfix * remove scaling * support rescale_grad opt * add unitest
上级
60e072d3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
84 addition
and
8 deletion
+84
-8
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+4
-2
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
...ibuted/passes/auto_parallel_data_parallel_optimization.py
+76
-5
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+4
-1
未找到文件。
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
229befc8
...
@@ -189,8 +189,9 @@ class Engine:
...
@@ -189,8 +189,9 @@ class Engine:
serial_main_prog
=
self
.
_orig_main_prog
.
clone
()
serial_main_prog
=
self
.
_orig_main_prog
.
clone
()
serial_startup_prog
=
self
.
_orig_startup_prog
.
clone
()
serial_startup_prog
=
self
.
_orig_startup_prog
.
clone
()
# FIXME to support grad clip
# FIXME to support grad clip
with
static
.
program_guard
(
serial_main_prog
,
serial_startup_prog
),
\
# with static.program_guard(serial_main_prog, serial_startup_prog), \
utils
.
unique_name
.
guard
():
# utils.unique_name.guard():
with
static
.
program_guard
(
serial_main_prog
,
serial_startup_prog
):
inputs_spec
=
self
.
inputs_spec
inputs_spec
=
self
.
inputs_spec
labels_spec
=
self
.
labels_spec
if
self
.
labels_spec
else
[]
labels_spec
=
self
.
labels_spec
if
self
.
labels_spec
else
[]
inputs
=
[
s
.
_create_feed_layer
()
for
s
in
inputs_spec
]
inputs
=
[
s
.
_create_feed_layer
()
for
s
in
inputs_spec
]
...
@@ -440,6 +441,7 @@ class Engine:
...
@@ -440,6 +441,7 @@ class Engine:
for
epoch
in
range
(
epochs
):
for
epoch
in
range
(
epochs
):
train_logs
=
{
"epoch: {:d} "
:
epoch
}
train_logs
=
{
"epoch: {:d} "
:
epoch
}
for
step
,
_
in
enumerate
(
train_dataloader
):
for
step
,
_
in
enumerate
(
train_dataloader
):
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
use_program_cache
=
use_cache
,
use_program_cache
=
use_cache
,
...
...
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
浏览文件 @
229befc8
...
@@ -16,6 +16,7 @@ from collections import OrderedDict
...
@@ -16,6 +16,7 @@ from collections import OrderedDict
import
paddle
import
paddle
from
paddle.fluid.framework
import
default_main_program
from
paddle.fluid.framework
import
default_main_program
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
from
paddle.distributed.auto_parallel.operators.common
import
is_data_parallel_scale_op
,
is_data_parallel_reduce_op
from
paddle.distributed.auto_parallel.operators.common
import
is_data_parallel_scale_op
,
is_data_parallel_reduce_op
from
paddle.distributed.auto_parallel.utils
import
is_loss_grad_op
,
is_optimize_op
,
ring_id_to_process_group
from
paddle.distributed.auto_parallel.utils
import
is_loss_grad_op
,
is_optimize_op
,
ring_id_to_process_group
from
.pass_base
import
PassBase
,
PassType
,
register_pass
from
.pass_base
import
PassBase
,
PassType
,
register_pass
...
@@ -26,6 +27,9 @@ __rescale_grad_supported_opts__ = [
...
@@ -26,6 +27,9 @@ __rescale_grad_supported_opts__ = [
'merge_momentum'
'merge_momentum'
]
]
# a heuristic number
__max_stream_num_allow__
=
16
@
register_pass
(
"auto_parallel_data_parallel_optimization"
)
@
register_pass
(
"auto_parallel_data_parallel_optimization"
)
class
DataParallelOptimizationPass
(
PassBase
):
class
DataParallelOptimizationPass
(
PassBase
):
...
@@ -71,7 +75,7 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -71,7 +75,7 @@ class DataParallelOptimizationPass(PassBase):
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
self
.
_analyze_program
()
self
.
_analyze_program
()
self
.
_prune_grad_scaling
()
self
.
_prune_grad_scaling
()
self
.
_
overlap_comm
()
self
.
_
calc_comm_overlap
()
self
.
_fuse_allreduce
()
self
.
_fuse_allreduce
()
def
_prune_grad_scaling
(
self
):
def
_prune_grad_scaling
(
self
):
...
@@ -86,14 +90,18 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -86,14 +90,18 @@ class DataParallelOptimizationPass(PassBase):
self
.
_remove_grad_scaling
()
self
.
_remove_grad_scaling
()
def
_overlap_comm
(
self
):
def
_calc_comm_overlap
(
self
):
pass
if
not
self
.
_could_be_overlap
():
return
self
.
_calc_overlap_comms
()
self
.
_update_wait_comms
()
def
_fuse_allreduce
(
self
):
def
_fuse_allreduce
(
self
):
pass
pass
def
_analyze_program
(
self
):
def
_analyze_program
(
self
):
"""
"""
build two maps
{param_grad_name: data_parallel_group}
{param_grad_name: data_parallel_group}
{pdata_parallel_group: aram_grad_name}
{pdata_parallel_group: aram_grad_name}
"""
"""
...
@@ -103,8 +111,9 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -103,8 +111,9 @@ class DataParallelOptimizationPass(PassBase):
scaled_grads
=
[]
scaled_grads
=
[]
for
op
in
ops
:
for
op
in
ops
:
grad_name
=
op
.
output_arg_names
[
0
]
if
is_data_parallel_reduce_op
(
op
):
if
is_data_parallel_reduce_op
(
op
):
grad_name
=
op
.
output_arg_names
[
0
]
if
grad_name
in
self
.
_grad_name_to_group_map
:
if
grad_name
in
self
.
_grad_name_to_group_map
:
continue
continue
assert
op
.
has_attr
(
assert
op
.
has_attr
(
...
@@ -123,7 +132,6 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -123,7 +132,6 @@ class DataParallelOptimizationPass(PassBase):
self
.
_group_to_grad_name_map
[
group
].
append
(
grad_name
)
self
.
_group_to_grad_name_map
[
group
].
append
(
grad_name
)
elif
is_data_parallel_scale_op
(
op
):
elif
is_data_parallel_scale_op
(
op
):
grad_name
=
op
.
output_arg_names
[
0
]
scaled_grads
.
append
(
grad_name
)
scaled_grads
.
append
(
grad_name
)
# TODO support multiple optimizers in on network in future.
# TODO support multiple optimizers in on network in future.
...
@@ -206,3 +214,66 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -206,3 +214,66 @@ class DataParallelOptimizationPass(PassBase):
assert
scaled_grads
==
set
(
self
.
_grad_name_to_group_map
.
keys
(
assert
scaled_grads
==
set
(
self
.
_grad_name_to_group_map
.
keys
(
)),
"Unexception: gradients [{}] are unscaled."
.
format
(
)),
"Unexception: gradients [{}] are unscaled."
.
format
(
set
(
self
.
_grad_name_to_group_map
.
keys
())
-
scaled_grads
)
set
(
self
.
_grad_name_to_group_map
.
keys
())
-
scaled_grads
)
def
_could_be_overlap
(
self
):
# NOTE current different nccl comm will use different cuda stream
# so if there too many dp group there will be too many stream need to be
# created and sync.
# revise here when framework support custom stream in static mode.
num_dp_comm_stream
=
len
(
set
(
self
.
_group_to_grad_name_map
.
keys
()))
if
num_dp_comm_stream
>
__max_stream_num_allow__
:
return
False
return
True
def
_calc_overlap_comms
(
self
):
# TODO support InterpreterCore executor for overlap.
# InterpreterCore has a different logic for overlapping
# which is different from use_calc_stream
block
=
default_main_program
().
global_block
()
ops
=
block
.
ops
# comm wait calc to finish
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
is_data_parallel_reduce_op
(
op
):
assert
op
.
has_attr
(
'use_calc_stream'
)
assert
op
.
has_attr
(
'ring_id'
)
op
.
_set_attr
(
'use_calc_stream'
,
False
)
ring_id
=
op
.
attr
(
"ring_id"
)
block
.
_insert_op_without_sync
(
idx
,
type
=
'c_wait_compute'
,
inputs
=
{
'X'
:
[]},
outputs
=
{
'Out'
:
[]},
attrs
=
{
'op_role'
:
OpRole
.
Backward
,
'ring_id'
:
ring_id
})
block
.
_sync_with_cpp
()
def
_update_wait_comms
(
self
):
block
=
default_main_program
().
global_block
()
ops
=
block
.
ops
# update wait comm to finish
first_optimize_op_idx
=
-
1
for
idx
,
op
in
enumerate
(
ops
):
if
is_optimize_op
(
op
):
first_optimize_op_idx
=
idx
break
assert
first_optimize_op_idx
>
-
1
,
"Unexception: not found optimizer op in program"
for
group
in
self
.
_group_to_grad_name_map
.
keys
():
ring_id
=
group
.
id
block
.
_insert_op_without_sync
(
first_optimize_op_idx
,
type
=
'c_wait_comm'
,
inputs
=
{
'X'
:
[]},
outputs
=
{
'Out'
:
[]},
attrs
=
{
'op_role'
:
OpRole
.
Backward
,
'ring_id'
:
ring_id
})
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
浏览文件 @
229befc8
...
@@ -542,9 +542,12 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
...
@@ -542,9 +542,12 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
fp16_var_names
=
to_fp16_var_names
if
to_fp16_var_names
else
set
()
fp16_var_names
=
to_fp16_var_names
if
to_fp16_var_names
else
set
()
var_scope
=
scope
if
scope
else
global_scope
()
var_scope
=
scope
if
scope
else
global_scope
()
print
(
"======================cast_parameters_to_fp16=============================="
)
for
param
in
all_parameters
:
for
param
in
all_parameters
:
if
param
.
name
in
fp16_var_names
:
if
param
.
name
in
fp16_var_names
:
_logger
.
debug
(
"---- cast {} to fp16 dtype ----"
.
format
(
param
.
name
))
print
(
"---- cast {} to fp16 dtype ----"
.
format
(
param
.
name
))
param_t
=
var_scope
.
find_var
(
param
.
name
).
get_tensor
()
param_t
=
var_scope
.
find_var
(
param
.
name
).
get_tensor
()
data
=
np
.
array
(
param_t
)
data
=
np
.
array
(
param_t
)
param_t
.
set
(
np
.
float16
(
data
),
place
)
param_t
.
set
(
np
.
float16
(
data
),
place
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录