Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)
Commit c3974d0e (unverified)
Authored by lilong12 on Mar 26, 2021; committed via GitHub on Mar 26, 2021
[3D-parallel] Reformat pipeline parallel (#31786)
* update, test=develop
Parent: 01aa2526
Showing 8 changed files with 816 additions and 569 deletions:
paddle/fluid/framework/section_worker.cc  (+10, -10)
python/paddle/distributed/fleet/meta_optimizers/common.py  (+38, -3)
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py  (+136, -172)
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py  (+7, -3)
python/paddle/fluid/device_worker.py  (+1, -1)
python/paddle/fluid/executor.py  (+15, -8)
python/paddle/fluid/optimizer.py  (+589, -365)
python/paddle/fluid/tests/unittests/pipeline_mnist.py  (+20, -7)
paddle/fluid/framework/section_worker.cc

@@ -39,13 +39,13 @@ void SectionWorker::RunForward(
     int op_role = op->Attr<int>(std::string("op_role"));
     // We run op with op_role = kLRSched only for the first microbatch
     // to avoid increasing the @LR_DECAY_STEP@ multiple times.
-    bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
-                            op_role == (static_cast<int>(OpRole::kForward) |
-                                        static_cast<int>(OpRole::kLoss)) ||
-                            op_role == static_cast<int>(OpRole::kLRSched);
-    bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
-                      op_role == (static_cast<int>(OpRole::kForward) |
-                                  static_cast<int>(OpRole::kLoss));
+    bool run_first_mbatch = (op_role == static_cast<int>(OpRole::kForward)) ||
+                            (op_role == (static_cast<int>(OpRole::kForward) |
+                                         static_cast<int>(OpRole::kLoss))) ||
+                            (op_role == static_cast<int>(OpRole::kLRSched));
+    bool run_others = (op_role == static_cast<int>(OpRole::kForward)) ||
+                      (op_role == (static_cast<int>(OpRole::kForward) |
+                                   static_cast<int>(OpRole::kLoss)));
     if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
       VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
               << micro_id;

@@ -64,9 +64,9 @@ void SectionWorker::RunBackward(
         &unused_vars_) {
   for (auto &op : ops_) {
     int op_role = op->Attr<int>(std::string("op_role"));
-    if (op_role == static_cast<int>(OpRole::kBackward) ||
-        op_role == (static_cast<int>(OpRole::kBackward) |
-                    static_cast<int>(OpRole::kLoss))) {
+    if ((op_role == static_cast<int>(OpRole::kBackward)) ||
+        (op_role == (static_cast<int>(OpRole::kBackward) |
+                     static_cast<int>(OpRole::kLoss)))) {
       VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
               << micro_id;
       op->Run(*microbatch_scopes_[micro_id], place_);
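A note on the hunk above: the only behavior here is the gating of ops by role and micro-batch id. Below is a minimal, self-contained Python sketch of that gating rule (the numeric role bits are illustrative stand-ins, not Paddle's actual OpRole constants): ops tagged kLRSched run only for micro-batch 0, so the learning-rate schedule advances once per mini-batch rather than once per micro-batch.

# Illustrative role bit values; Paddle's real OpRole enum values may differ.
K_FORWARD = 0x0
K_LOSS = 0x100
K_LRSCHED = 0x200

def should_run_forward_op(op_role, micro_id):
    """Mirror of the run_first_mbatch / run_others predicates in RunForward."""
    run_first_mbatch = (op_role == K_FORWARD or
                        op_role == (K_FORWARD | K_LOSS) or
                        op_role == K_LRSCHED)
    run_others = (op_role == K_FORWARD or
                  op_role == (K_FORWARD | K_LOSS))
    return run_first_mbatch if micro_id == 0 else run_others

# LR-schedule ops run once (micro-batch 0 only); forward ops run every micro-batch.
assert should_run_forward_op(K_LRSCHED, micro_id=0)
assert not should_run_forward_op(K_LRSCHED, micro_id=1)
assert should_run_forward_op(K_FORWARD | K_LOSS, micro_id=3)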
python/paddle/distributed/fleet/meta_optimizers/common.py

@@ -47,7 +47,7 @@ def is_optimizer_op(op):
 class CollectiveHelper(object):
-    def __init__(self, role_maker, nrings=1, wait_port='6174'):
+    def __init__(self, role_maker, nrings=1, wait_port=True):
         self.nrings = nrings
         self.wait_port = wait_port
         self.role_maker = role_maker

@@ -65,14 +65,48 @@ class CollectiveHelper(object):
                 self.role_maker._worker_index(), ring_id, self.wait_port)
         self._broadcast_params()

-    def _init_communicator(self, program, current_endpoint, endpoints, rank,
-                           ring_id, wait_port):
+    def _init_communicator(self,
+                           program,
+                           current_endpoint,
+                           endpoints,
+                           rank,
+                           ring_id,
+                           wait_port,
+                           global_ring_id=None,
+                           sync=True):
         nranks = len(endpoints)
         other_endpoints = endpoints[:]
         other_endpoints.remove(current_endpoint)
         if rank == 0 and wait_port:
             wait_server_ready(other_endpoints)
+
+        def _add_sync_by_allreduce(block):
+            sync_var = block.create_var(
+                name=unique_name.generate('sync_var'),
+                dtype=core.VarDesc.VarType.INT32,
+                persistable=False,
+                stop_gradient=True)
+            block.append_op(
+                type='fill_constant',
+                inputs={},
+                outputs={'Out': [sync_var]},
+                attrs={
+                    'shape': [1],
+                    'dtype': sync_var.dtype,
+                    'value': 1,
+                    'force_cpu': False,
+                    OP_ROLE_KEY: OpRole.Forward
+                })
+            block.append_op(
+                type='c_allreduce_sum',
+                inputs={'X': [sync_var]},
+                outputs={'Out': [sync_var]},
+                attrs={
+                    'ring_id': global_ring_id,
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Forward
+                })
+
         block = program.global_block()
         if core.is_compiled_with_cuda():
             comm_id_var = block.create_var(

@@ -128,6 +162,7 @@ class CollectiveHelper(object):
             raise ValueError(
                 "comm_id must be generated in paddlepaddle-xpu or paddlepaddle-xpu."
             )
+        if sync: _add_sync_by_allreduce(block)

     def _wait(self, current_endpoint, endpoints):
         assert (self.wait_port)
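The new _add_sync_by_allreduce helper turns communicator initialization into a synchronization point: every rank fills a scalar with 1 and allreduce-sums it over the global ring, and since the collective cannot complete until every rank reaches it, no rank moves on before all communicators exist. The sketch below reproduces that barrier-by-allreduce idea in plain Python with threads; ToyRing and its allreduce_sum are made-up stand-ins for the c_allreduce_sum op, not Paddle APIs.

import threading

class ToyRing:
    """Stand-in for a communication ring: the allreduce returns only when all ranks call it."""
    def __init__(self, nranks):
        self.nranks = nranks
        self._barrier = threading.Barrier(nranks)

    def allreduce_sum(self, value):
        self._barrier.wait()          # collective semantics: wait for every rank
        return value * self.nranks    # sum of one `value` contributed per rank

def init_communicator(rank, ring, done):
    # ... the real code would create the NCCL/BKCL communicator here ...
    total = ring.allreduce_sum(1)     # fill_constant(1) + c_allreduce_sum on the global ring
    assert total == ring.nranks       # reaching this line implies every rank finished init
    done[rank] = True

nranks = 4
ring = ToyRing(nranks)
done = [False] * nranks
threads = [threading.Thread(target=init_communicator, args=(r, ring, done))
           for r in range(nranks)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert all(done)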
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py

@@ -19,130 +19,21 @@ from paddle.fluid import core, unique_name
 from ..base.private_helper_function import wait_server_ready
 from paddle.fluid.optimizer import PipelineOptimizer as PO
 from .meta_optimizer_base import MetaOptimizerBase
-from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
-
-
-def _get_node_num(endpoints):
-    ss = set()
-    for ep in endpoints:
-        ip = ep.split(":")[0].strip()
-        if ip not in ss:
-            ss.add(ip)
-    return len(ss)
-
-
-class PipelineHelper(object):
-    def __init__(self, role_maker, wait_port='6174'):
-        self.wait_port = wait_port
-        self.role_maker = role_maker
-
-    def update_startup_program(self,
-                               startup_program=None,
-                               inner_parallelism=None):
-        self.startup_program = startup_program
-        nranks = self.role_maker._worker_num()
-        rank = self.role_maker._worker_index()
-        endpoints = self.role_maker._get_trainer_endpoints()
-        current_endpoint = endpoints[rank]
-        node_num = _get_node_num(endpoints)
-        assert nranks % node_num == 0
-
-        # Create ring 0 for all gpus in the same pipeline
-        if inner_parallelism > 1:
-            pipeline_rank = rank % inner_parallelism
-            pipeline_id = rank // inner_parallelism
-            start_index = pipeline_id * inner_parallelism
-            pipeline_endpoints = endpoints[start_index:start_index +
-                                           inner_parallelism]
-            self._init_communicator(self.startup_program, current_endpoint,
-                                    pipeline_endpoints, pipeline_rank, 0,
-                                    self.wait_port)
-
-        pipeline_num = len(endpoints) // inner_parallelism
-        if pipeline_num == 1: return
-        # Create rings for gpus with the same pipeline id for data parallel
-        eps = []
-        pipeline_rank = rank % inner_parallelism
-        ring_id = pipeline_rank + 1
-        for i in range(pipeline_num):
-            eps.append(endpoints[i * inner_parallelism + pipeline_rank])
-        # rank in a ring of gpus with the same pipeline id for data parallel
-        dp_rank = rank // inner_parallelism
-        self._init_communicator(self.startup_program, current_endpoint, eps,
-                                dp_rank, ring_id, self.wait_port)
-        self._broadcast_params(ring_id)
-
-    def _init_communicator(self, program, current_endpoint, endpoints, rank,
-                           ring_id, wait_port):
-        nranks = len(endpoints)
-        other_endpoints = endpoints[:]
-        other_endpoints.remove(current_endpoint)
-        if rank == 0 and wait_port:
-            wait_server_ready(other_endpoints)
-
-        block = program.global_block()
-        nccl_id_var = block.create_var(
-            name=unique_name.generate('nccl_id'),
-            persistable=True,
-            type=core.VarDesc.VarType.RAW)
-        block.append_op(
-            type='c_gen_nccl_id',
-            inputs={},
-            outputs={'Out': nccl_id_var},
-            attrs={
-                'rank': rank,
-                'endpoint': current_endpoint,
-                'other_endpoints': other_endpoints,
-                OP_ROLE_KEY: OpRole.Forward,
-            })
-        block.append_op(
-            type='c_comm_init',
-            inputs={'X': nccl_id_var},
-            outputs={},
-            attrs={
-                'nranks': nranks,
-                'rank': rank,
-                'ring_id': ring_id,
-                OP_ROLE_KEY: OpRole.Forward,
-            })
-
-    def _broadcast_params(self, ring_id):
-        block = self.startup_program.global_block()
-        for var_name in block.vars:
-            if "nccl_id" in var_name: continue
-            param = block.var(var_name)
-            if not param.persistable: continue
-            block.append_op(
-                type='c_broadcast',
-                inputs={'X': param},
-                outputs={'Out': param},
-                attrs={
-                    'ring_id': ring_id,
-                    'root': 0,
-                    OP_ROLE_KEY: OpRole.Forward
-                })
-        block.append_op(
-            type='c_sync_comm_stream',
-            inputs={'X': param},
-            outputs={'Out': param},
-            attrs={'ring_id': ring_id,
-                   OP_ROLE_KEY: OpRole.Forward})
+from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_loss_grad_op, is_backward_op, is_optimizer_op


 class PipelineOptimizer(MetaOptimizerBase):
     def __init__(self, optimizer):
         super(PipelineOptimizer, self).__init__(optimizer)
         self.inner_opt = optimizer
         # we do not allow meta optimizer to be inner optimizer currently
         self.meta_optimizers_white_list = [
             "RecomputeOptimizer",
             "AMPOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
+        self.global_ring_id = 1
+        self.dp_ring_id = 2
+        self.start_pipeline_ring_id = 20  # Just a magic number

     def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                         user_defined_strategy):

@@ -165,7 +56,11 @@ class PipelineOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.pipeline = False
-        dist_strategy.pipeline_configs = {}
+        dist_strategy.pipeline_configs = {
+            "micro_batch_size": 1,
+            "accumulate_steps": 1,
+            "schedule_mode": "1F1B",
+        }

     def _enable_strategy(self, dist_strategy, context):
         dist_strategy.pipeline = True

@@ -175,61 +70,134 @@ class PipelineOptimizer(MetaOptimizerBase):
             "schedule_mode": "1F1B",
         }

+    def _broadcast_params(self, ring_id):
+        block = self.startup_program.global_block()
+        param = None
+        for param in block.iter_parameters():
+            if param.is_distributed:
+                continue
+
+            block.append_op(
+                type='c_broadcast',
+                inputs={'X': param},
+                outputs={'Out': param},
+                attrs={
+                    'ring_id': ring_id,
+                    'root': 0,
+                    OP_ROLE_KEY: OpRole.Forward
+                })
+
+        if not param: return  # no parameter on this device
+        block.append_op(
+            type='c_sync_comm_stream',
+            inputs={'X': param},
+            outputs={'Out': param},
+            attrs={'ring_id': ring_id,
+                   OP_ROLE_KEY: OpRole.Forward})
+
+    def _get_process_group_info(self):
+        # global ring info
+        self.global_endpoints = self.endpoints
+        self.global_rank = self.rank
+        self.global_nranks = self.nranks
+
+        # data parallel ring info
+        if self.pipeline_num > 1:
+            self.dp_rank = self.rank // self.inner_parallelism
+            self.dp_nranks = self.nranks // self.inner_parallelism
+            start_index = self.rank % self.inner_parallelism
+            self.dp_endpoints = [
+                self.endpoints[start_index + i * self.inner_parallelism]
+                for i in range(self.pipeline_num)
+            ]
+
+    def _init_process_group(self, pipeline_pair, pipeline_ring_map):
+        self._get_process_group_info()
+        collective_helper = CollectiveHelper(self.role_maker, wait_port=False)
+        # Create global ring for all gpus (ring_id = 0)
+        collective_helper._init_communicator(
+            self.startup_program, self.current_endpoint, self.global_endpoints,
+            self.global_rank, self.global_ring_id, True, self.global_ring_id,
+            True)
+        # Create pipeline rings
+        if self.inner_parallelism > 1:
+            pipeline_id = self.rank // self.inner_parallelism
+            start_index = pipeline_id * self.inner_parallelism
+            for pair in pipeline_pair:
+                pair_key = pair[0] * 1000 + pair[1]
+                ring_id = pipeline_ring_map[pair_key]
+                assert ring_id >= self.start_pipeline_ring_id
+                first_node = pair[0] + start_index
+                second_node = pair[1] + start_index
+                if self.rank != first_node and self.rank != second_node:
+                    continue
+                pipeline_endpoints = [
+                    self.endpoints[first_node], self.endpoints[second_node]
+                ]
+                pipeline_rank = 0 if self.rank == first_node else 1
+                pipeline_nranks = 2
+                collective_helper._init_communicator(
+                    self.startup_program, self.current_endpoint,
+                    pipeline_endpoints, pipeline_rank, ring_id, False,
+                    self.global_ring_id, True)
+
+        # Create dp rings
+        if self.pipeline_num > 1:
+            collective_helper._init_communicator(
+                self.startup_program, self.current_endpoint, self.dp_endpoints,
+                self.dp_rank, self.dp_ring_id, True, self.global_ring_id, True)
+            self._broadcast_params(self.dp_ring_id)
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
                       parameter_list=None,
                       no_grad_set=None):
-        endpoints = self.role_maker._get_trainer_endpoints()
-        current_endpoint = endpoints[self.role_maker._worker_index()]
-        self.wrapped_opt = PO(self.inner_opt,
-                              num_microbatches=self.num_microbatches)
-        node_num = _get_node_num(endpoints)
-        gpus_per_node = len(endpoints) // node_num
-        self.startup_program = startup_program
-        if startup_program is None:
-            self.startup_program = fluid.default_startup_program()
         self.endpoints = self.role_maker._get_trainer_endpoints()
         self.current_endpoint = self.endpoints[self.role_maker._worker_index()]
         self.rank = self.role_maker._worker_index()
         self.nranks = self.role_maker._worker_num()
-        assert self.nranks % node_num == 0
-
-        loss.block.program._pipeline_opt = dict()
-        loss.block.program._pipeline_opt['local_rank'] = self.rank
-        loss.block.program._pipeline_opt[
-            'micro_batch_size'] = self.micro_batch_size
-        loss.block.program._pipeline_opt['schedule_mode'] = self.schedule_mode
-        optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
+        self.wrapped_opt = PO(self.inner_opt,
+                              num_microbatches=self.num_microbatches)
+        orig_startup_program = startup_program if startup_program else fluid.default_startup_program(
+        )
+        block = loss.block
+        program = block.program
+        program._pipeline_opt = dict()
+        program._pipeline_opt['local_rank'] = self.rank
+        program._pipeline_opt['global_ring_id'] = self.global_ring_id
+        program._pipeline_opt['ring_id'] = self.start_pipeline_ring_id
+        program._pipeline_opt['micro_batch_size'] = self.micro_batch_size
+        program._pipeline_opt['schedule_mode'] = self.schedule_mode
+        optimize_ops, params_grads, prog_list, pp_pair, ring_map = self.wrapped_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set)
-        assert prog_list
-        self.main_program_list = prog_list
-        self.main_program = loss.block.program
-        self.inner_parallelism = loss.block.program._pipeline_opt[
-            'inner_parallelism']
+        self.startup_program = orig_startup_program._pipeline_opt[
+            'startup_program']
+        self.inner_parallelism = program._pipeline_opt['inner_parallelism']
         assert self.nranks % self.inner_parallelism == 0
+        assert prog_list
         self.pipeline_num = len(self.endpoints) // self.inner_parallelism
-
-        pipeline_helper = PipelineHelper(self.role_maker)
-        pipeline_helper.update_startup_program(
-            self.startup_program._pipeline_opt["startup_program"],
-            self.inner_parallelism)
+        self._init_process_group(pp_pair, ring_map)

-        pipeline_num = self.nranks // self.inner_parallelism
-        self._transpile_main_program(loss, pipeline_num,
-                                     self.inner_parallelism)
+        self.main_program_list = prog_list
+        self.main_program = program
+        if self.pipeline_num > 1:
+            self._transpile_main_program(loss)
         return optimize_ops, params_grads

-    def _transpile_main_program(self, loss, pipeline_num, inner_parallelism):
-        if pipeline_num <= 1: return
-        self._insert_loss_grad_ops(loss, pipeline_num)
-        for ring_id in range(1, inner_parallelism + 1):
-            self._insert_allreduce_ops(ring_id)
+    def _transpile_main_program(self, loss):
+        self._insert_loss_grad_ops(loss, self.pipeline_num)
+        self._insert_allreduce_ops(self.dp_ring_id)

     def _insert_loss_grad_ops(self, loss, pipeline_num):
         """
         In order to keep the learning rate consistent in different numbers of
         training workers, we scale the loss grad by the number of workers
         """
-        block = self.main_program_list[-1]['program'].global_block()
+        block = self.main_program_list[-1].global_block()
         for idx, op in reversed(list(enumerate(block.ops))):
             if is_loss_grad_op(op):
                 loss_grad_var = block.vars[op.output_arg_names[0]]

@@ -244,57 +212,53 @@ class PipelineOptimizer(MetaOptimizerBase):
                     })

     def _insert_allreduce_ops(self, ring_id):
-        block = self.main_program_list[ring_id - 1]['program'].global_block()
+        block = self.main_program._pipeline_opt['section_program'].global_block(
+        )
         origin_block = self.main_program.global_block()
         grad = None
         processed_param_name = set()
+        first_optimize_op_idx = None
+        add_sync_calc_stream = False
         for idx, op in reversed(list(enumerate(block.ops))):
+            if is_backward_op(op) and not first_optimize_op_idx:
+                first_optimize_op_idx = idx + 1
+                # no optimize phase
+                if first_optimize_op_idx == len(block.ops): return
             if is_backward_op(op) and \
                     OP_ROLE_VAR_KEY in op.attr_names:
                 op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
                 if len(op_role_var) == 0:
                     continue
                 assert len(op_role_var) % 2 == 0
-                offset = idx
+                offset = 0
                 for i in range(0, len(op_role_var), 2):
                     param_name = op_role_var[i]
                     param = block.vars[op_role_var[i]]
                     if param_name in processed_param_name: continue
                     processed_param_name.add(param_name)
-                    grad = block.vars[op_role_var[i + 1]]
+                    grad_name = op_role_var[i + 1]
+                    if not 'MERGED' in grad_name: grad_name += '@MERGED'
+                    grad = block.vars[grad_name]
                     origin_param = origin_block.vars[op_role_var[i]]
                     if origin_param.is_distributed: continue
-                    if offset == idx:
-                        offset += 1
+                    if not add_sync_calc_stream:
+                        add_sync_calc_stream = True
                         block._insert_op(
-                            offset,
+                            first_optimize_op_idx + offset,
                             type='c_sync_calc_stream',
                             inputs={'X': grad},
                             outputs={'Out': grad},
-                            attrs={OP_ROLE_KEY: OpRole.Backward})
+                            attrs={OP_ROLE_KEY: OpRole.Optimize})
                         offset += 1
                     block._insert_op(
-                        offset,
+                        first_optimize_op_idx + offset,
                         type='c_allreduce_sum',
                         inputs={'X': grad},
                         outputs={'Out': grad},
                         attrs={
                             'ring_id': ring_id,
-                            OP_ROLE_KEY: OpRole.Backward
+                            'use_calc_stream': True,
+                            OP_ROLE_KEY: OpRole.Optimize
                         })
-
-        if grad is None:
-            return
-
-        for idx, op in enumerate(block.ops):
-            if is_optimizer_op(op):
-                block._insert_op(
-                    idx,
-                    type='c_sync_comm_stream',
-                    inputs={'X': grad},
-                    outputs={'Out': grad},
-                    attrs={'ring_id': ring_id,
-                           OP_ROLE_KEY: OpRole.Backward})
-                break
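The reworked optimizer no longer builds one ring per pipeline-stage index; it builds a global ring, one two-rank ring per communicating stage pair (keyed as pair[0] * 1000 + pair[1] in pipeline_ring_map), and a data-parallel ring over the ranks holding the same stage. The snippet below is a small, self-contained sketch of that bookkeeping only; the endpoints, the (0, 1) pair, and the ring-map contents are made-up examples (in the real flow pp_pair and ring_map come from the wrapped PipelineOptimizer), and no communicators are created.

# Hypothetical 2-node x 2-GPU job: 2 pipeline stages per replica, 2 replicas.
endpoints = ["node0:6170", "node0:6171", "node1:6170", "node1:6171"]
inner_parallelism = 2          # pipeline stages per replica
rank = 2                       # this worker
start_pipeline_ring_id = 20    # same "magic number" base as in the optimizer

# Data-parallel ring: ranks that hold the same pipeline stage.
pipeline_num = len(endpoints) // inner_parallelism
dp_rank = rank // inner_parallelism
stage_index = rank % inner_parallelism
dp_endpoints = [endpoints[stage_index + i * inner_parallelism]
                for i in range(pipeline_num)]
print("dp ring:", dp_rank, dp_endpoints)    # dp ring: 1 ['node0:6170', 'node1:6170']

# Point-to-point pipeline rings: one two-rank ring per (sender, receiver) pair.
pipeline_pair = [(0, 1)]                                    # stage 0 feeds stage 1
pipeline_ring_map = {0 * 1000 + 1: start_pipeline_ring_id}  # pair_key -> ring_id
pipeline_id = rank // inner_parallelism
start_index = pipeline_id * inner_parallelism
for pair in pipeline_pair:
    ring_id = pipeline_ring_map[pair[0] * 1000 + pair[1]]
    first_node = pair[0] + start_index
    second_node = pair[1] + start_index
    if rank in (first_node, second_node):
        pipeline_rank = 0 if rank == first_node else 1
        print("joins ring", ring_id, "as rank", pipeline_rank,
              "with", [endpoints[first_node], endpoints[second_node]])
        # joins ring 20 as rank 0 with ['node1:6170', 'node1:6171']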
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py

@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 outputs={"Out": out_var},
                 attrs={
                     "in_dtype": in_var.dtype,
-                    "out_dtype": out_var.dtype
+                    "out_dtype": out_var.dtype,
+                    "op_device": op.attr("op_device")
                 })
             num_cast_ops += 1
         _rename_arg(op, in_var.name, out_var.name)

@@ -171,8 +172,11 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
         type="cast",
         inputs={"X": target_var},
         outputs={"Out": cast_var},
-        attrs={"in_dtype": target_var.dtype,
-               "out_dtype": cast_var.dtype})
+        attrs={
+            "in_dtype": target_var.dtype,
+            "out_dtype": cast_var.dtype,
+            "op_device": op.attr("op_device")
+        })
     num_cast_ops += 1

     op_var_rename_map[block.idx][target_var.name] = cast_var.name
python/paddle/fluid/device_worker.py

@@ -427,7 +427,7 @@ class Section(DeviceWorker):
         section_param.schedule_mode = schedule_mode
         cfg = section_param.section_config
         program = pipeline_opt["section_program"]
-        cfg.program_desc.ParseFromString(program["program"]._get_desc()
+        cfg.program_desc.ParseFromString(program._get_desc()
                                          .serialize_to_string())

         # TODO: why does not work
         # cfg.program_desc.CopyFrom(program.program._get_desc())
python/paddle/fluid/executor.py

@@ -1458,7 +1458,7 @@ class Executor(object):
             dataset._prepare_to_run()
         real_fetch_list = []
         if program._pipeline_opt:
-            real_program = program._pipeline_opt["section_program"]['program']
+            real_program = program._pipeline_opt["section_program"]
             for fetch_var in fetch_list:
                 if isinstance(fetch_var, Variable):
                     fetch_var_name = fetch_var.name

@@ -1467,13 +1467,20 @@ class Executor(object):
                 if fetch_var_name in real_program.global_block().vars:
                     real_fetch_list.append(fetch_var)

-            program._pipeline_opt["section_program"][
-                'program'] = self._add_feed_fetch_ops(
-                    program=program._pipeline_opt["section_program"]['program'],
-                    feed=[],
-                    fetch_list=real_fetch_list,
-                    feed_var_name='feed',
-                    fetch_var_name='fetch')
+            program._pipeline_opt["section_program"] = self._add_feed_fetch_ops(
+                program=program._pipeline_opt["section_program"],
+                feed=[],
+                fetch_list=real_fetch_list,
+                feed_var_name='feed',
+                fetch_var_name='fetch')
+            main_block = program._pipeline_opt["section_program"].block(0)
+            for op in main_block.ops:
+                # set the op_role of fetch op to Optimize to avoid
+                # erase the fetched vars by gc for pipeline
+                if op.type == 'fetch':
+                    op._set_attr(
+                        'op_role',
+                        core.op_proto_and_checker_maker.OpRole.Optimize)
             fetch_list = None

         scope, trainer = self._prepare_trainer(
python/paddle/fluid/optimizer.py

(This diff is collapsed in the original view and is not shown here; per the summary above it accounts for +589 / -365 lines.)
python/paddle/fluid/tests/unittests/pipeline_mnist.py

@@ -66,12 +66,21 @@ def cnn_model(data):
     param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
     scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5

-    predict = fluid.layers.fc(
-        input=conv_pool_2,
-        size=SIZE,
-        act="softmax",
-        param_attr=fluid.param_attr.ParamAttr(
-            initializer=fluid.initializer.Constant(value=0.01)))
+    with fluid.device_guard("gpu:1"):
+        predict = fluid.layers.fc(
+            input=conv_pool_2,
+            size=SIZE,
+            act="softmax",
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.01)))
+        # To cover @RENAMED@GRADIENT
+        predict2 = fluid.layers.fc(
+            input=conv_pool_1,
+            size=SIZE,
+            act="softmax",
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.01)))
+        predict += predict2
     return predict

@@ -108,7 +117,10 @@ class TestDistMnist2x2(TestDistRunnerBase):
         bd = [steps_per_pass * p for p in passes]
         lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
         lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
-        opt = fluid.optimizer.Momentum(learning_rate=lr_val, momentum=0.9)
+        opt = fluid.optimizer.Momentum(
+            learning_rate=lr_val,
+            momentum=0.9,
+            grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))

         acc_steps = 2  # accumulated steps for pipeline
         if dist_strategy:

@@ -120,6 +132,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
             fleet.init(is_collective=True)
             strategy = fleet.DistributedStrategy()
             strategy.pipeline = True
+            strategy.amp = True
             strategy.pipeline_configs = {
                 'micro_batch_size': batch_size,
                 'schedule_mode': '1F1B',
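The test now enables gradient clipping by global norm with clip_norm=1.0. As a quick numeric illustration (this is the standard global-norm formula sketched independently of Paddle's GradientClipByGlobalNorm implementation): every gradient is scaled by clip_norm / max(global_norm, clip_norm), so the joint L2 norm never exceeds clip_norm while the relative directions of the gradients are preserved.

import math

def clip_by_global_norm(grads, clip_norm):
    """Scale every gradient by clip_norm / max(global_norm, clip_norm)."""
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = clip_norm / max(global_norm, clip_norm)
    return [[g * scale for g in grad] for grad in grads], global_norm

grads = [[3.0, 4.0], [0.0, 12.0]]                 # global norm = sqrt(9 + 16 + 144) = 13
clipped, norm = clip_by_global_norm(grads, clip_norm=1.0)
assert abs(norm - 13.0) < 1e-9
new_norm = math.sqrt(sum(g * g for grad in clipped for g in grad))
assert abs(new_norm - 1.0) < 1e-9                 # clipped down to the requested norm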