Unverified commit 695dd371
Authored Mar 31, 2021 by lilong12; committed via GitHub on Mar 31, 2021
Adjust pipeline optimizer for 3d parallelism (#31939)
* update, test=develop
Parent: 6f85e241
Showing 4 changed files with 168 additions and 161 deletions
paddle/fluid/framework/pipeline_trainer.cc  (+3, -24)
python/paddle/distributed/fleet/meta_optimizers/common.py  (+5, -0)
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py  (+1, -10)
python/paddle/fluid/optimizer.py  (+159, -127)
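Taken together, the four files below adjust the pipeline optimizer so that pipeline parallelism composes with data parallelism and sharding (3D parallelism): pipeline_trainer.cc drops the special minibatch-scope handling of merged gradients, the fleet meta optimizer forwards a use_sharding flag through program._pipeline_opt, and the fluid PipelineOptimizer moves to index-based op lookup plus one communication ring per pair of adjacent stages. As rough orientation, the sketch below shows how this optimizer is usually driven from fleet; the strategy fields and concrete values are illustrative assumptions, not part of this commit.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

strategy = fleet.DistributedStrategy()
strategy.pipeline = True
strategy.pipeline_configs = {
    "micro_batch_size": 2,   # assumed value; surfaces below as _pipeline_opt['micro_batch_size']
    "accumulate_steps": 4,   # assumed value; micro-batches accumulated per mini-batch
}

# Under `python -m paddle.distributed.launch ...` training would continue with:
#   fleet.init(is_collective=True)
#   opt = fleet.distributed_optimizer(paddle.optimizer.SGD(0.01), strategy)
#   opt.minimize(loss)  # dispatches into the PipelineOptimizer code changed below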
paddle/fluid/framework/pipeline_trainer.cc
@@ -71,37 +71,16 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
                                      const ProgramDesc& program,
                                      const platform::Place& place) {
   auto& global_block = program.Block(0);
-  std::map<std::string, int> param_map;
-  for (auto& var : global_block.AllVars()) {
-    if (var->Persistable()) {
-      param_map[var->Name()] = 1;
-    }
-  }
 
   for (auto& var : global_block.AllVars()) {
-    bool is_param_grad = false;
-    size_t pos = 0;
-    // A magic suffix to indicate the merged gradient
-    std::string magicSuffix = std::string(kGradVarSuffix) + "@MERGED";
-    if ((pos = var->Name().find(magicSuffix)) != std::string::npos) {
-      auto prefix_name = var->Name().substr(0, pos);
-      if (param_map.find(prefix_name) != param_map.end()) {
-        is_param_grad = true;
-      }
-    }
     if (var->Persistable() && microbatch_id == 0) {
       auto* ptr = root_scope_->Var(var->Name());
       InitializeVariable(ptr, var->GetType());
-      VLOG(3) << "Create persistable var: " << var->Name()
-              << ", which pointer is " << ptr;
-    } else if (is_param_grad && microbatch_id == 0) {
-      auto* ptr = minibatch_scope_->Var(var->Name());
-      InitializeVariable(ptr, var->GetType());
-      VLOG(3) << "Create grad for persistable var: " << var->Name()
+      VLOG(5) << "Create persistable var: " << var->Name()
               << ", which pointer is " << ptr;
-    } else if (!var->Persistable() && !is_param_grad) {
+    } else if (!var->Persistable()) {
       auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
-      VLOG(3) << "Create variable " << var->Name() << " for microbatch "
+      VLOG(5) << "Create variable " << var->Name() << " for microbatch "
               << microbatch_id << ", which pointer is " << ptr;
       InitializeVariable(ptr, var->GetType());
     }
python/paddle/distributed/fleet/meta_optimizers/common.py
@@ -106,6 +106,11 @@ class CollectiveHelper(object):
                     'use_calc_stream': True,
                     OP_ROLE_KEY: OpRole.Forward
                 })
+            block.append_op(
+                type='c_sync_calc_stream',
+                inputs={'X': sync_var},
+                outputs={'Out': sync_var},
+                attrs={OP_ROLE_KEY: OpRole.Forward})
 
         block = program.global_block()
         if core.is_compiled_with_cuda():
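The five added lines appear to act as a barrier: a c_sync_calc_stream op on sync_var makes later work wait until the preceding broadcast on the calculation stream has finished. For reference, a minimal sketch of the block.append_op pattern used here, with a plain scale op standing in for c_sync_calc_stream because that op is only registered in builds with collective support:

import paddle

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    sync_var = paddle.static.data(name="sync_var", shape=[1], dtype="float32")
block = prog.global_block()
block.append_op(
    type="scale",                        # stand-in for 'c_sync_calc_stream'
    inputs={"X": sync_var},
    outputs={"Out": sync_var},
    attrs={"scale": 1.0})
print([op.type for op in block.ops])     # -> ['scale']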
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
@@ -171,6 +171,7 @@ class PipelineOptimizer(MetaOptimizerBase):
         program._pipeline_opt['ring_id'] = self.start_pipeline_ring_id
         program._pipeline_opt['micro_batch_size'] = self.micro_batch_size
         program._pipeline_opt['schedule_mode'] = self.schedule_mode
+        program._pipeline_opt['use_sharding'] = False
         optimize_ops, params_grads, prog_list, pp_pair, ring_map = self.wrapped_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set)
         self.startup_program = orig_startup_program._pipeline_opt[
@@ -218,7 +219,6 @@ class PipelineOptimizer(MetaOptimizerBase):
             grad = None
             processed_param_name = set()
             first_optimize_op_idx = None
-            add_sync_calc_stream = False
             for idx, op in reversed(list(enumerate(block.ops))):
                 if is_backward_op(op) and not first_optimize_op_idx:
                     first_optimize_op_idx = idx + 1
@@ -242,15 +242,6 @@ class PipelineOptimizer(MetaOptimizerBase):
                     origin_param = origin_block.vars[op_role_var[i]]
                     if origin_param.is_distributed:
                         continue
-                    if not add_sync_calc_stream:
-                        add_sync_calc_stream = True
-                        block._insert_op(
-                            first_optimize_op_idx + offset,
-                            type='c_sync_calc_stream',
-                            inputs={'X': grad},
-                            outputs={'Out': grad},
-                            attrs={OP_ROLE_KEY: OpRole.Optimize})
-                        offset += 1
                     block._insert_op(
                         first_optimize_op_idx + offset,
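For orientation, this is roughly the _pipeline_opt dictionary the meta optimizer now hands to the underlying fluid PipelineOptimizer; the keys mirror the required_keys check added in optimizer.py below, while the concrete values are illustrative assumptions.

pipeline_opt = {
    'local_rank': 0,
    'global_ring_id': 1,      # assumed value
    'ring_id': 20,            # assumed start ring id for pipeline peer-to-peer pairs
    'micro_batch_size': 2,
    'schedule_mode': '1F1B',  # or 'F-then-B'
    'use_sharding': False,    # set to False here; a sharding-driven setup would pass True
}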
python/paddle/fluid/optimizer.py
@@ -3805,7 +3805,6 @@ class PipelineOptimizer(object):
         self._param_device_map = None
         self._pipeline_pair = []
         self._pp_ring_map = dict()
-        self._global_ring_id = None
 
     # insert allreduce op to sync global information for global
     # gradient clip and amp
@@ -3841,7 +3840,7 @@ class PipelineOptimizer(object):
                 inputs={'X': temp_var if op.type == "reduce_any" else out_var},
                 outputs={'Out': temp_var if op.type == "reduce_any" else out_var},
                 attrs={
-                    'ring_id': self._global_ring_id,
+                    'ring_id': self.global_ring_id,
                     self._op_role_key: self._op_role.Optimize,
                     'use_calc_stream': True
                 })
@@ -3887,6 +3886,16 @@ class PipelineOptimizer(object):
                         reserved_x.append(input_name)
                 op.desc.set_input('X', reserved_x)
                 op.desc.set_output('Out', reserved_x)
+            elif op.type == 'check_finite_and_unscale':
+                for input_name in op.desc.input("X"):
+                    if block._find_var_recursive(input_name):
+                        reserved_x.append(input_name)
+                op.desc.set_input('X', reserved_x)
+                op.desc.set_output('Out', reserved_x)
+                if len(reserved_x) == 0:
+                    block._remove_op(op_idx)
+                    op_size -= 1
+                    continue
             elif op.type == 'sum' and self._is_gradient_clip_op(op):
                 for input_name in op.desc.input("X"):
                     if block._find_var_recursive(input_name):
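The new check_finite_and_unscale branch keeps only the inputs that still exist in the current pipeline stage's sub-program and drops the op when nothing is left. A paddle-free sketch of that pruning rule (variable names are hypothetical):

stage_vars = {"w0@GRAD"}                      # vars present in this sub-program
op_inputs = ["w0@GRAD", "w1@GRAD"]            # original inputs of the op

reserved_x = [name for name in op_inputs if name in stage_vars]
if len(reserved_x) == 0:
    print("remove check_finite_and_unscale")  # mirrors block._remove_op(op_idx)
else:
    print("prune inputs to", reserved_x)      # mirrors op.desc.set_input('X', ...)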
@@ -4020,63 +4029,32 @@ class PipelineOptimizer(object):
         self._create_vars(new_startup_program.global_block(), block)
         return new_startup_program
 
-    def _find_post_op(self, ops, cur_op, var_name):
+    def _find_post_op(self, index, var_name):
         """
-        Find the real post op that has variable named var_name as input.
-        Args:
-            ops (list): A list of ops.
-            cur_op (Operator): Current operator which has variable named
-                var_name as output.
-            var_name (string): Variable name.
+        Find the post op that has variable named var_name as input.
         """
         # To skip the cast op added by amp which has no op_device set
         if '.cast_fp32' in var_name:
             var_name = var_name.replace('.cast_fp32', '')
         elif '.cast_fp16' in var_name:
             var_name = var_name.replace('.cast_fp16', '')
-        post_op = []
-        before = True
-        for op in ops:
-            if op == cur_op:
-                before = False
-                continue
-            if before:
-                continue
-            for in_var_name in op.input_arg_names:
-                if in_var_name == var_name:
-                    post_op.append(op)
-                    break
-        if post_op:
-            return post_op[0]
-        return None
+        post_ops = self.input_var_to_op[var_name]
+        if post_ops == None:
+            return None
+        result_op = None
+        for post_op, post_idx in reversed(post_ops):
+            if post_idx > index:
+                result_op = post_op
+                break
+        return result_op
 
-    def _find_real_prev_op(self, ops, cur_op, var_name):
+    def _find_prev_op(self, index, var_name):
         """
-        Find the real previous op that outputs variable named var_name.
-        Args:
-            ops (list): A list of ops.
-            cur_op (Operator): Current operator which has variable named
-                var_name as input.
-            var_name (string): Variable name.
+        Find the previous op of op with index that outputs
+        variable named var_name.
         """
-        prev_op = []
-        for op in ops:
-            if op.type == 'send_v2' or op.type == 'recv_v2' \
-                or op.type == 'c_broadcast':
-                continue
-            if op == cur_op:
+        prev_ops = self.output_var_to_op[var_name]
+        if prev_ops == None:
+            return None
+        result_op = None
+        for prev_op, prev_idx in reversed(prev_ops):
+            if prev_idx < index:
+                result_op = prev_op
                 break
-            for out_var_name in op.output_arg_names:
-                if out_var_name == var_name:
-                    prev_op.append(op)
-        if prev_op:
-            # A op may have more than one prev op,
-            # e.g., for 'learning_rate', there may be multiple ops have it as
-            # output.
-            return prev_op[-1]
-        return None
+        return result_op
 
     def _rename_arg(self, op, old_name, new_name):
         op._rename_input(old_name, new_name)
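The rewritten helpers replace a linear scan over block.ops on every call with precomputed var-to-op index maps (built by _get_input_output_info further down). A minimal, paddle-free sketch of the same idea, with made-up op tuples standing in for Operator objects:

ops = [("read", [], ["x"]), ("matmul", ["x", "w"], ["y"]), ("relu", ["y"], ["z"])]

input_var_to_op, output_var_to_op = {}, {}
for idx, (op_type, ins, outs) in enumerate(ops):
    for name in ins:
        input_var_to_op.setdefault(name, []).append((op_type, idx))
    for name in outs:
        output_var_to_op.setdefault(name, []).append((op_type, idx))

def find_prev_op(index, var_name):
    # nearest producer of var_name located before position `index`
    for op_type, op_idx in reversed(output_var_to_op.get(var_name, [])):
        if op_idx < index:
            return op_type
    return None

print(find_prev_op(2, "y"))  # -> 'matmul'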
@@ -4136,23 +4114,21 @@ class PipelineOptimizer(object):
             # For LRSched ops, we should put them on all sub-programs to
             # make sure each sub-program update the lr correctly
             op._set_attr(self._op_device_key, "gpu:all")
-        elif (op.type == "cast" or
-              op.type == "scale") and self._is_backward_op(op):
-            prev_op = self._find_real_prev_op(block.ops, op,
-                                              op.desc.input("X")[0])
+        elif op.type == "scale" and self._is_backward_op(op):
+            prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
             op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
         elif op.type == "memcpy" and not self._is_optimize_op(op):
+            # for checkpoint offloading
             assert len(op.input_arg_names) == 1 and len(
                 op.output_arg_names) == 1
             input_name = op.input_arg_names[0]
             output_name = op.output_arg_names[0]
             if '@Fetch' in output_name:
-                post_op = self._find_post_op(block.ops, op, output_name)
+                post_op = self._find_post_op(idx, output_name)
                 op._set_attr(self._op_device_key,
                              post_op.attr(self._op_device_key))
             else:
-                prev_op = self._find_real_prev_op(block.ops, op,
-                                                  op.desc.input("X")[0])
+                prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
                 op._set_attr(self._op_device_key,
                              prev_op.attr(self._op_device_key))
         elif self._is_loss_op(op):
@@ -4165,16 +4141,11 @@ class PipelineOptimizer(object):
             assert device, "Please put you program within device_guard scope."
             for i in range(offset):
                 block.ops[idx + i]._set_attr(self._op_device_key, device)
-        elif self._is_optimize_op(op) and op.type == "check_finite_and_unscale":
-            op_role_var = op.attr(self._op_role_var_key)
-            param_name = op_role_var[0]
-            device = self._param_device_map[param_name]
-            op._set_attr(self._op_device_key, device)
         elif self._is_optimize_op(op) and op.type == "cast":
             # For fp16-->fp32 cast added by AMP
             grad_name = op.output('Out')
             assert len(grad_name) == 1
-            param_name = grad_name[0].strip(core.grad_var_suffix())
+            param_name = self._strip_grad_suffix(grad_name[0])
             device = self._param_device_map[param_name]
             op._set_attr(self._op_device_key, device)
         elif self._is_gradient_clip_op(op) or self._is_regularization_op(op):
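The cast branch also fixes a subtle naming bug: str.strip treats its argument as a character set rather than a suffix, so stripping core.grad_var_suffix() can eat legitimate leading or trailing characters of the parameter name. A small illustration with a hypothetical gradient name:

grad_name = "DARTS_0.w@GRAD"       # hypothetical parameter gradient name
print(grad_name.strip("@GRAD"))    # -> 'TS_0.w'  (leading 'DAR' removed as well)
print(grad_name[:-len("@GRAD")])   # -> 'DARTS_0.w', roughly what _strip_grad_suffix returns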
@@ -4197,7 +4168,11 @@ class PipelineOptimizer(object):
             op._set_attr(self._op_device_key, device)
         else:
             other_known_ops = [
-                'update_loss_scaling', 'reduce_any', 'concat', 'sum'
+                'update_loss_scaling',
+                'reduce_any',
+                'concat',
+                'sum',
+                'check_finite_and_unscale',
             ]
             assert op.type in other_known_ops, "For other ops without " \
                 "op_device set, they must be one of {}, but it " \
@@ -4274,41 +4249,70 @@ class PipelineOptimizer(object):
         Insert a pair of send and recv ops for every two
         consecutive ops on different devices.
         """
-        extra_index = 0
+        extra_index_info = {'index': 0}
 
         # A map from var to device where op takes it as input,
         # avoiding multiple send and recv ops.
-        var_dev_map = dict()
+        input_var_to_device = dict()
 
         for index, op in enumerate(list(block.ops)):
             cur_device = op.attr(self._op_device_key)
             if cur_device == "gpu:all": continue
             for var_name in op.input_arg_names:
                 # i.e., lod_tensor_blocking_queue created by DataLoader,
                 # which only exists in startup program.
                 var = block.var(var_name)
-                # skip data, because we will process it later
+                # skip data var
                 if var.is_data: continue
                 prev_device = None
-                if var_name in self._param_device_map:
+
+                generate_ops = self.output_var_to_op.get(var_name)
+                if generate_ops is None:
+                    if var_name not in self._param_device_map:
+                        continue
                     prev_device = self._param_device_map[var_name]
-                prev_op = self._find_real_prev_op(block.ops, op, var_name)
+
+                prev_op = self._find_prev_op(index, var_name)
+
                 if not prev_device:
                     prev_device = prev_op.attr(self._op_device_key) \
                         if prev_op else None
-                if not prev_device or prev_device == 'gpu:all': continue
-                if prev_device != cur_device:
-                    if var_name not in var_dev_map: var_dev_map[var_name] = []
-                    if cur_device in var_dev_map[var_name]: continue
-                    var_dev_map[var_name].append(cur_device)
-                    op_role = op.all_attrs()[self._op_role_key]
+
+                if prev_device is None or prev_device == "gpu:all": continue
+
+                if prev_device == cur_device: continue
+
+                if var_name not in input_var_to_device:
+                    input_var_to_device[var_name] = []
+                if (cur_device, prev_device) in input_var_to_device[var_name]:
+                    continue
+
+                device_type = cur_device.split(':')[0] + ':'
+
+                def _insert_send_recv(cur_id, prev_id):
+                    cur_dev = device_type + str(cur_id)
+                    prev_dev = device_type + str(prev_id)
+                    if (cur_dev, prev_dev) in input_var_to_device[var_name]:
+                        return
+                    if cur_id - prev_id > 1:
+                        _insert_send_recv(cur_id - 1, prev_id)
+                        _insert_send_recv(cur_id, cur_id - 1)
+                        input_var_to_device[var_name].append(
+                            (cur_dev, prev_dev))
+                        return
+                    elif cur_id - prev_id < -1:
+                        _insert_send_recv(cur_id + 1, prev_id)
+                        _insert_send_recv(cur_id, cur_id + 1)
+                        input_var_to_device[var_name].append(
+                            (cur_dev, prev_dev))
+                        return
+                    assert abs(cur_id - prev_id) == 1
+                    input_var_to_device[var_name].append((cur_dev, prev_dev))
+
+                    op_role = op.attr(self._op_role_key)
                     var = block.vars[var_name]
-                    prev_device_index = int(prev_device.split(':')[1])
-                    cur_device_index = int(cur_device.split(':')[1])
-                    pair = (prev_device_index, cur_device_index)
-                    pair_key = prev_device_index * 1000 + cur_device_index
+                    pair = (prev_id, cur_id)
+                    # 1000 is just a magic number
+                    pair_key = prev_id * 1000 + cur_id
                     if pair not in self._pipeline_pair:
                         self._pipeline_pair.append(pair)
                         self._pp_ring_map[pair_key] = self.ring_id
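The nested _insert_send_recv above decomposes a dependency that jumps several pipeline stages into hops between adjacent stages, and keys one communication ring per (prev, cur) stage pair. A paddle-free sketch of that recursion and bookkeeping (the start ring id and stage ids are illustrative assumptions):

pipeline_pair, pp_ring_map, hops = [], {}, []
ring_id = [20]  # assumed start ring id

def insert_send_recv(cur_id, prev_id):
    if cur_id - prev_id > 1:                  # forward jump over several stages
        insert_send_recv(cur_id - 1, prev_id)
        insert_send_recv(cur_id, cur_id - 1)
        return
    assert abs(cur_id - prev_id) == 1
    pair = (prev_id, cur_id)
    pair_key = prev_id * 1000 + cur_id        # the "magic number" keying from the diff
    if pair not in pipeline_pair:
        pipeline_pair.append(pair)
        pp_ring_map[pair_key] = ring_id[0]
        ring_id[0] += 1
    hops.append((prev_id, cur_id, pp_ring_map[pair_key]))

insert_send_recv(cur_id=3, prev_id=0)
print(hops)  # -> [(0, 1, 20), (1, 2, 21), (2, 3, 22)]

This sketch only covers forward jumps; the real code also handles the symmetric backward case (cur_id - prev_id < -1).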
@@ -4316,89 +4320,95 @@ class PipelineOptimizer(object):
                         self.ring_id += 1
                     else:
                         ring_id = self._pp_ring_map[pair_key]
                     if self.schedule_mode == 'F-then-B':  # F-then-B
                         block._insert_op(
-                            index=index + extra_index,
+                            index=index + extra_index_info['index'],
                             type='send_v2',
                             inputs={'X': var},
                             attrs={
-                                self._op_device_key: prev_device,
+                                self._op_device_key: prev_dev,
                                 self._op_role_key: op_role,
                                 'use_calc_stream': True,
                                 'peer': 1,
                                 'ring_id': ring_id
                             })
-                        extra_index += 1
+                        extra_index_info['index'] += 1
                         block._insert_op(
-                            index=index + extra_index,
+                            index=index + extra_index_info['index'],
                             type='recv_v2',
                             outputs={'Out': [var]},
                             attrs={
                                 'out_shape': var.shape,
                                 'dtype': var.dtype,
-                                self._op_device_key: cur_device,
+                                self._op_device_key: cur_dev,
                                 self._op_role_key: op_role,
                                 'use_calc_stream': True,
                                 'peer': 0,
                                 'ring_id': ring_id
                             })
-                        extra_index += 1
+                        extra_index_info['index'] += 1
                     elif self.schedule_mode == '1F1B':  # 1F1B
                         block._insert_op(
-                            index=index + extra_index,
+                            index=index + extra_index_info['index'],
                             type='c_sync_calc_stream',
                             inputs={'X': [var]},
                             outputs={'Out': [var]},
                             attrs={
-                                self._op_device_key: prev_device,
+                                self._op_device_key: prev_dev,
                                 self._op_role_key: op_role,
                             })
-                        extra_index += 1
+                        extra_index_info['index'] += 1
                         block._insert_op(
-                            index=index + extra_index,
+                            index=index + extra_index_info['index'],
                             type='send_v2',
                             inputs={'X': var},
                             attrs={
-                                self._op_device_key: prev_device,
+                                self._op_device_key: prev_dev,
                                 self._op_role_key: op_role,
                                 'use_calc_stream': False,
                                 'ring_id': ring_id,
                                 'peer': 1,
                             })
-                        extra_index += 1
+                        extra_index_info['index'] += 1
                         block._insert_op(
-                            index=index + extra_index,
+                            index=index + extra_index_info['index'],
                             type='c_sync_comm_stream',
                             inputs={'X': [var]},
                             outputs={'Out': [var]},
                             attrs={
-                                self._op_device_key: prev_device,
+                                self._op_device_key: prev_dev,
                                 self._op_role_key: self._op_role.Backward,
                                 'ring_id': ring_id,
                             })
-                        extra_index += 1
+                        extra_index_info['index'] += 1
                         var_shape = list(var.shape)
                         var_shape[0] = self.micro_batch_size if var_shape[
                             0] < 0 else var_shape[0]
                         block._insert_op(
-                            index=index + extra_index,
+                            index=index + extra_index_info['index'],
                             type='recv_v2',
                             outputs={'Out': [var]},
                             attrs={
                                 'out_shape': var_shape,
                                 'dtype': var.dtype,
-                                self._op_device_key: cur_device,
+                                self._op_device_key: cur_dev,
                                 self._op_role_key: op_role,
                                 'use_calc_stream': True,
                                 'peer': 0,
                                 'ring_id': ring_id
                             })
-                        extra_index += 1
+                        extra_index_info['index'] += 1
                     else:
                         raise ValueError(
                             "Now only 'F-then-B' and '1F1B' are supported."
                             "The given value is {}.".format(self.schedule_mode))
+
+                _insert_send_recv(
+                    int(cur_device.split(':')[1]),
+                    int(prev_device.split(':')[1]))
         block._sync_with_cpp()
 
     def _insert_loss_scale(self, block):
         """
         Scale the loss corresponding to number of micro-batches.
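One detail worth noting in the 1F1B branch: recv_v2 now takes an out_shape whose dynamic batch dimension is replaced by the known micro-batch size, since a -1 dimension cannot be used to size the receive buffer. A minimal sketch of that substitution (the shape values are illustrative):

micro_batch_size = 2
var_shape = [-1, 1024]   # static shape with a dynamic batch dimension
var_shape[0] = micro_batch_size if var_shape[0] < 0 else var_shape[0]
print(var_shape)         # -> [2, 1024]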
@@ -4675,6 +4685,23 @@ class PipelineOptimizer(object):
         return op.desc.has_attr("op_namescope") \
             and op.desc.attr("op_namescope").startswith("/regularization")
 
+    def _get_input_output_info(self, block):
+        '''
+        Get info of op input and output.
+        '''
+        # A map from output var to op which generate it.
+        self.output_var_to_op = dict()
+        # A map from var to op which takes it as input.
+        self.input_var_to_op = dict()
+
+        for index, op in enumerate(list(block.ops)):
+            for var_name in op.input_arg_names:
+                ops = self.input_var_to_op.setdefault(var_name, [])
+                ops.append([op, index])
+            for var_name in op.output_arg_names:
+                ops = self.output_var_to_op.setdefault(var_name, [])
+                ops.append([op, index])
+
     def minimize(self,
                  loss,
                  startup_program=None,
@@ -4682,30 +4709,35 @@ class PipelineOptimizer(object):
                  no_grad_set=None):
         main_block = loss.block
         self.origin_main_block = main_block
         main_program = main_block.program
         if startup_program is None:
             startup_program = default_startup_program()
-        optimize_ops, params_grads = self._optimizer.minimize(
-            loss, startup_program, parameter_list, no_grad_set)
-        self._param_device_map = self._origin_optimizer._param_device_map
 
-        assert main_block.program._pipeline_opt \
-            and 'local_rank' in main_block.program._pipeline_opt, \
-            'Please use pipeline with fleet.'
-        local_rank = main_block.program._pipeline_opt['local_rank']
-        self._global_ring_id = main_block.program._pipeline_opt['global_ring_id']
-        schedule_mode = 0
-        if 'schedule_mode' in main_block.program._pipeline_opt:
-            schedule_mode = main_block.program._pipeline_opt['schedule_mode']
-        self.schedule_mode = schedule_mode
-        # micro batch size
+        assert main_program._pipeline_opt, 'Please use pipeline with fleet.'
+        required_keys = [
+            'local_rank',
+            'schedule_mode',
+            'micro_batch_size',
+            'ring_id',
+            'global_ring_id',
+            'use_sharding',
+        ]
+        for key in required_keys:
+            assert key in main_program._pipeline_opt, \
+                'Please use pipeline with fleet to use {}.'.format(key)
+        self.local_rank = main_block.program._pipeline_opt['local_rank']
+        self.schedule_mode = main_block.program._pipeline_opt['schedule_mode']
         self.micro_batch_size = main_block.program._pipeline_opt['micro_batch_size']
-        self.use_sharding = False
-        if 'use_sharding' in main_block.program._pipeline_opt:
-            self.use_sharding = main_block.program._pipeline_opt['use_sharding']
+        self.use_sharding = main_block.program._pipeline_opt['use_sharding']
         self.ring_id = main_block.program._pipeline_opt['ring_id']
         self.global_ring_id = main_block.program._pipeline_opt['global_ring_id']
 
+        optimize_ops, params_grads = self._optimizer.minimize(
+            loss, startup_program, parameter_list, no_grad_set)
+        self._param_device_map = self._origin_optimizer._param_device_map
+        self._get_input_output_info(main_block)
         # Step1: add default op_device attribute for ops.
         self._add_op_device_attr(main_block)
         device_list = self._check_validation(main_block)
@@ -4742,20 +4774,20 @@ class PipelineOptimizer(object):
         # Step5: Add sub blocks for section programs
         self._add_sub_blocks(main_block, program_list)
 
-        local_rank = main_program._pipeline_opt['local_rank'] % len(device_list)
+        self.local_rank %= len(device_list)
         place_list = []
         for dev in device_list:
             dev_index = int(dev.split(":")[1])
-            place_list.append(core.CUDAPlace(dev_index % 8))
+            place_list.append(core.CUDAPlace(0))
 
         # Step6: Split startup program
         new_startup_program = self._split_startup_program(startup_program,
-                                                          local_rank)
+                                                          self.local_rank)
 
         startup_program._pipeline_opt = {
             "startup_program": new_startup_program,
         }
-        real_block = program_list[local_rank].global_block()
+        real_block = program_list[self.local_rank].global_block()
         self._insert_loss_scale(real_block)
         if not self.use_sharding:
             # Step7: clear gradients before each mini-batch and
@@ -4769,12 +4801,12 @@ class PipelineOptimizer(object):
         main_program._pipeline_opt = {
             "trainer": "PipelineTrainer",
             "device_worker": "Section",
-            "pipeline_stage": local_rank,
+            "pipeline_stage": self.local_rank,
             "num_pipeline_stages": len(device_list),
             "schedule_mode": self.schedule_mode,
            "inner_parallelism": len(device_list),
-            "section_program": program_list[local_rank],
-            "place": place_list[local_rank],
+            "section_program": program_list[self.local_rank],
+            "place": place_list[self.local_rank],
             "place_id": place_id,
             "sync_steps": -1,
             "num_microbatches": self._num_microbatches,