Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e79f24f8
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e79f24f8
编写于
3月 05, 2021
作者:
S
sandyhouse
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sync for c_broadcast
上级
779fde8d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
150 addition
and
8 deletion
+150
-8
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+2
-0
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+147
-8
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
e79f24f8
...
...
@@ -37,6 +37,7 @@ message ShardingConfig {
optional
bool
use_pipeline
=
6
[
default
=
false
];
optional
int32
acc_steps
=
7
[
default
=
1
];
optional
int32
schedule_mode
=
8
[
default
=
0
];
optional
int32
pp_bz
=
9
[
default
=
1
];
}
message
AMPConfig
{
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
e79f24f8
...
...
@@ -98,6 +98,7 @@ class ShardingOptimizer(MetaOptimizerBase):
"acc_steps"
]
self
.
schedule_mode
=
self
.
user_defined_strategy
.
sharding_configs
[
"schedule_mode"
]
self
.
pp_bz
=
self
.
user_defined_strategy
.
sharding_configs
[
"pp_bz"
]
if
self
.
inner_opt
is
None
:
raise
ValueError
(
...
...
@@ -108,6 +109,7 @@ class ShardingOptimizer(MetaOptimizerBase):
main_program
=
loss
.
block
.
program
main_program
.
_pipeline_opt
=
dict
()
main_program
.
_pipeline_opt
[
'schedule_mode'
]
=
self
.
schedule_mode
main_program
.
_pipeline_opt
[
'pp_bz'
]
=
self
.
pp_bz
pp_rank
=
self
.
role_maker
.
_worker_index
()
//
(
self
.
user_defined_strategy
.
sharding_configs
[
'sharding_group_size'
]
*
self
.
_inner_parallelism_size
)
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
e79f24f8
...
...
@@ -4416,7 +4416,11 @@ class PipelineOptimizer(object):
var
=
block
.
var
(
var_name
)
# skip data, because we will process it later
if
var
.
is_data
:
continue
prev_device
=
None
if
var_name
in
self
.
_param_device_map
:
prev_device
=
self
.
_param_device_map
[
var_name
]
prev_op
=
self
.
_find_real_prev_op
(
block
.
ops
,
op
,
var_name
)
if
not
pre_device
:
prev_device
=
prev_op
.
attr
(
self
.
_op_device_key
)
\
if
prev_op
else
None
if
not
prev_device
or
prev_device
==
'gpu:all'
:
continue
...
...
@@ -4494,6 +4498,20 @@ class PipelineOptimizer(object):
'root'
:
0
,
})
extra_index
+=
1
block
.
_insert_op
(
index
=
index
+
extra_index
,
type
=
'c_sync_comm_stream'
,
inputs
=
{
'X'
:
[
var
]},
outputs
=
{
'Out'
:
[
var
]},
attrs
=
{
self
.
_op_device_key
:
cur_device
,
self
.
_op_role_key
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
,
#self._op_role_key: op_role,
'ring_id'
:
ring_id
,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index
+=
1
#block._insert_op(
# index=index + extra_index,
# type='c_sync_comm_stream',
...
...
@@ -4508,7 +4526,7 @@ class PipelineOptimizer(object):
# })
#extra_index += 1
fill_shape
=
list
(
var
.
shape
)
fill_shape
[
0
]
=
1
fill_shape
[
0
]
=
self
.
pp_bz
block
.
_insert_op
(
index
=
index
+
extra_index
,
#type='recv_v2',
...
...
@@ -4523,6 +4541,19 @@ class PipelineOptimizer(object):
'value'
:
float
(
0.0
),
})
extra_index
+=
1
block
.
_insert_op
(
index
=
index
+
extra_index
,
type
=
'c_sync_comm_stream'
,
inputs
=
{
'X'
:
[
var
]},
outputs
=
{
'Out'
:
[
var
]},
attrs
=
{
self
.
_op_device_key
:
cur_device
,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self
.
_op_role_key
:
op_role
,
'ring_id'
:
ring_id
,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index
+=
1
block
.
_insert_op
(
index
=
index
+
extra_index
,
#type='recv_v2',
...
...
@@ -4591,8 +4622,12 @@ class PipelineOptimizer(object):
# continue
#input_var_to_device[var_name].append(cur_device)
prev_device
=
None
generate_ops
=
output_var_to_op
.
get
(
var_name
)
if
generate_ops
is
None
:
continue
if
generate_ops
is
None
:
if
var_name
not
in
self
.
_param_device_map
:
continue
prev_device
=
self
.
_param_device_map
[
var_name
]
prev_op
=
None
for
gen_op
,
gen_idx
in
reversed
(
generate_ops
):
...
...
@@ -4600,6 +4635,7 @@ class PipelineOptimizer(object):
prev_op
=
gen_op
break
if
not
prev_device
:
prev_device
=
prev_op
.
attr
(
self
.
_op_device_key
)
\
if
prev_op
else
None
...
...
@@ -5134,6 +5170,7 @@ class PipelineOptimizer(object):
if
'schedule_mode'
in
main_block
.
program
.
_pipeline_opt
:
schedule_mode
=
main_block
.
program
.
_pipeline_opt
[
'schedule_mode'
]
self
.
schedule_mode
=
schedule_mode
self
.
pp_bz
=
main_block
.
program
.
_pipeline_opt
[
'pp_bz'
]
self
.
use_sharding
=
False
if
'use_sharding'
in
main_block
.
program
.
_pipeline_opt
:
...
...
@@ -5175,15 +5212,117 @@ class PipelineOptimizer(object):
# send and recv ops for data var.
main_program
=
main_block
.
program
program_list
=
self
.
_split_program
(
main_program
,
device_list
)
#cur_device_index = 0
#device_num = len(program_list)
for
p
in
program_list
:
self
.
_create_vars
(
p
[
"program"
].
block
(
0
),
main_block
)
# # Add send/recv pair to sync the execution.
# block = p['program'].block(0)
# prev_device_index = cur_device_index - 1
# next_device_index = cur_device_index + 1
# add_send_for_forward = False
# add_send_for_backward = False
# add_recv_for_backward = False
# extra_index = 0
# new_var = block.create_var(
# name=unique_name.generate('sync'),
# shape=[1],
# dtype='float32',
# persistable=False,
# stop_gradient=True)
# block._insert_op(
# index=0,
# type='fill_constant',
# inputs={},
# outputs={'Out': [new_var]},
# attrs={
# 'shape': [1],
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'value': float(0.0),
# })
# extra_index += 1
# for op_idx, op in enumerate(list(block.ops)):
# if op_idx == extra_index:
# if cur_device_index > 0:
# pair_key = prev_device_index * 1000 + cur_device_index
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx,
# type='recv_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'peer': 0,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# extra_index += 1
# continue
# if op.type == "send_v2" and self._is_forward_op(op) \
# and not add_send_for_forward \
# and cur_device_index < device_num - 1:
# add_send_for_forward = True
# pair_key = cur_device_index * 1000 + next_device_index
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='send_v2',
# inputs={'Out': new_var},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'peer': 1,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# extra_index += 1
# if self._is_backward_op(op) and not add_recv_for_backward \
# and cur_device_index < device_num - 1:
# pair_key = next_device_index * 1000 + cur_device_index
# add_recv_for_backward = True
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='recv_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Backward,
# 'peer': 0,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# if op.type == "send_v2" and self._is_backward_op(op) \
# and not add_send_for_backward \
# and cur_device_index > 0:
# pair_key = cur_device_index * 1000 + prev_device_index
# add_send_for_backward = True
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='send_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Backward,
# 'peer': 1,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# cur_device_index += 1
#self._insert_sendrecv_for_data_var(main_block, program_list,
# startup_program, device_list)
# Step4: Special Case: process persistable vars that exist in
# multiple sections
self
.
_process_persistable_vars_in_multi_sections
(
main_program
,
startup_program
,
program_list
)
#
self._process_persistable_vars_in_multi_sections(
#
main_program, startup_program, program_list)
# Step5: Add sub blocks for section programs
self
.
_add_sub_blocks
(
main_block
,
program_list
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录