PaddlePaddle / Paddle

Commit b1a23b82
Authored Aug 07, 2020 by sandyhouse
initialize gradient at the begining of each run instead of startup program, test=develop
Parent: 0360e583
Showing 2 changed files with 48 additions and 77 deletions (+48 -77):

    paddle/fluid/framework/pipeline_trainer.cc   +3 -5
    python/paddle/fluid/optimizer.py             +45 -72
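In brief: the pipeline optimizer keeps parameter gradients persistable so they can be accumulated across microbatches, which also means they survive between executions of the main program. A one-time fill_constant in the startup program therefore only zeroes them before the first run; every later minibatch would start from the previous run's sums. This commit moves the zeroing into the main program itself, inserted at index 0, so gradients are cleared at the beginning of each run. A framework-free sketch of the failure mode and the fix (plain Python, illustrative numbers, not Paddle API):

    # Persistable grads keep their values between runs of the main program.
    grad = 0.0                # startup program: fill_constant runs only once
    for run in range(3):      # each run = one minibatch of 4 microbatches
        grad = 0.0            # new behavior: fill_constant at the start of
                              # every run of the main program
        for _ in range(4):
            grad += 1.0       # per-microbatch gradient contribution
        assert grad == 4.0    # without the per-run reset, run 2 would see 8.0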
paddle/fluid/framework/pipeline_trainer.cc
@@ -130,31 +130,29 @@ void PipelineTrainer::CopyParameters(int section_id, int microbatch_id,
     }
   }
   for (auto& var : global_block.AllVars()) {
-    bool is_grad = false;
     bool is_param_grad = false;
     size_t pos = 0;
     if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos) {
-      is_grad = true;
       auto prefix_name = var->Name().substr(0, pos);
       if (param_map.find(prefix_name) != param_map.end()) {
         is_param_grad = true;
       }
     }
     VLOG(3) << "Var name: " << var->Name();
-    if ((var->Persistable() || is_grad) && microbatch_id == 0) {
+    if ((var->Persistable() || is_param_grad) && microbatch_id == 0) {
       auto* ptr = root_scope_->FindVar(var->Name());
       auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name());
       VLOG(3) << "Create persistable var " << var->Name() << " for minibatch "
               << section_id << ", which pointer is " << new_ptr;
       InitializeVariable(new_ptr, var->GetType());
-      if (!var->Persistable() && !is_param_grad) {
+      if (is_param_grad) {
         continue;
       }
       const LoDTensor& root_tensor = ptr->Get<LoDTensor>();
       LoDTensor* minibatch_tensor = new_ptr->GetMutable<LoDTensor>();
       TensorCopy(*static_cast<const Tensor*>(&root_tensor), place,
                  static_cast<Tensor*>(minibatch_tensor));
-    } else if (!var->Persistable() && !is_grad) {
+    } else if (!var->Persistable() && !is_param_grad) {
       auto* ptr =
           microbatch_scopes_[section_id][microbatch_id]->Var(var->Name());
       VLOG(3) << "Create variable " << var->Name() << " for section "
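The C++ change narrows which variables get persistable copies in the minibatch scope: instead of treating every name containing the gradient suffix as relevant (is_grad), only gradients whose prefix matches a known parameter (is_param_grad) qualify, and for those the tensor copy from the root scope is skipped, since gradients are no longer initialized in the startup program and there is nothing to copy. A hypothetical Python mirror of the classification, with "@GRAD" standing in for kGradVarSuffix:

    GRAD_SUFFIX = "@GRAD"  # stands in for kGradVarSuffix

    def is_param_grad(var_name, param_names):
        # A parameter gradient: stripping the suffix yields a known param.
        pos = var_name.find(GRAD_SUFFIX)
        return pos != -1 and var_name[:pos] in param_names

    assert is_param_grad("fc_0.w_0@GRAD", {"fc_0.w_0"})
    assert not is_param_grad("dropout_1.tmp_0@GRAD", {"fc_0.w_0"})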
python/paddle/fluid/optimizer.py
@@ -3703,7 +3703,7 @@ class PipelineOptimizer(object):
         self._op_role_key = op_maker.kOpRoleAttrName()
         self._op_role_var_key = op_maker.kOpRoleVarAttrName()
         self._op_device_key = op_maker.kOpDeviceAttrName()
-        self._param_device_map = dict()
+        self._param_device_map = None

     def _create_vars(self, block, main_program):
         # Create vars for block, copied from main_program's global block

@@ -3742,9 +3742,10 @@ class PipelineOptimizer(object):
         return 'Param' in op.input_names and 'Grad' in op.input_names and (
             "LearningRate" in op.input_names)

-    def _split_program(self, main_program):
+    def _split_program(self, main_program, devices):
         """
         Split a program into sections according to devices that ops run on.
+        The ops of the role LRSched are copied to all sections.

         Args:
             main_program (Program): the main program

@@ -3752,18 +3753,27 @@ class PipelineOptimizer(object):
         programs = []
         # Map from device to its corresponding section program info
         device_program_map = dict()
+        for device in devices:
+            p = {'program': Program()}
+            device_program_map[device] = p
         block = main_program.block(0)
         for op in block.ops:
             device = op.attr(self._op_device_key)
-            if device not in device_program_map:
-                program = {"program": Program()}
-                device_program_map[device] = program
-            program = device_program_map[device]
-            op_desc = op.desc
-            ap_op = program["program"].block(0).desc.append_op()
-            ap_op.copy_from(op_desc)
+            op_role = op.attr(self._op_role_key)
+            if int(op_role) & int(self._op_role.LRSched):
+                # Copy ops of the role LRSched to all sections.
+                for device in device_program_map.keys():
+                    program = device_program_map[device]
+                    op_desc = op.desc
+                    ap_op = program["program"].block(0).desc.append_op()
+                    ap_op.copy_from(op_desc)
+                    ap_op._set_attr(self._op_device_key, device)
+            else:
+                program = device_program_map[device]
+                op_desc = op.desc
+                ap_op = program["program"].block(0).desc.append_op()
+                ap_op.copy_from(op_desc)

         for key in sorted(device_program_map.keys()):
             program = device_program_map[key]
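The _split_program rewrite switches section creation from lazy (create a section the first time a device is seen) to eager (one section per entry in the new devices argument), and replicates learning-rate-schedule ops into every section, retagging each copy's op_device. A plain-Python sketch of the dispatch rule (dict-based stand-ins, not Paddle's Program/op_desc API):

    def split_ops(ops, devices):
        sections = {d: [] for d in devices}   # eager: one section per device
        for op in ops:
            if op["role"] == "LRSched":
                # replicate LR-schedule ops into every section, retagged
                for d in devices:
                    sections[d].append(dict(op, device=d))
            else:
                sections[op["device"]].append(op)
        return sections

    sections = split_ops(
        [{"role": "LRSched", "device": "gpu:0"},
         {"role": "Forward", "device": "gpu:0"},
         {"role": "Forward", "device": "gpu:1"}],
        ["gpu:0", "gpu:1"])
    assert len(sections["gpu:1"]) == 2   # got the LRSched copy too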
@@ -3948,18 +3958,6 @@ class PipelineOptimizer(object):
         """
         return name + core.grad_var_suffix()

-    def _update_param_device_map(self, params_grads, block):
-        for param_grad in params_grads:
-            if not param_grad[0].trainable: continue
-            param_name = param_grad[0].name
-            ops = block.ops
-            for op in ops:
-                input_arg_names = op.input_arg_names
-                if param_name in input_arg_names:
-                    self._param_device_map[param_name] = op.attr(
-                        self._op_device_key)
-                    break
-
     def _add_opdevice_attr_for_regularization_clip(self, block):
         """
         Add op_device attribute for regulization and clip ops.

@@ -4043,6 +4041,8 @@ class PipelineOptimizer(object):
                     "{} has not been set.".format(op.type))
             if not dev_spec in device_specs:
                 device_specs.append(dev_spec)
+        sorted_device_specs = sorted(device_specs)
+        assert sorted_device_specs == device_specs
         return device_specs

     def _insert_enq_deq_ops_for_boundaries(self, block, origin_block,

@@ -4059,6 +4059,11 @@ class PipelineOptimizer(object):
         var_devspec = dict()

         for index, op in list(enumerate(origin_block.ops)):
+            # skips lr-related op and vars, as we will process them later.
+            if int(op.attr(self._op_role_key)) & int(self._op_role.LRSched):
+                continue
+            if self._is_update_op(op):
+                continue
             cur_device_spec = op.attr(self._op_device_key)
             for var_name in op.input_arg_names:
                 # i.e., lod_tensor_blocking_queue created by DataLoader,

@@ -4114,51 +4119,25 @@ class PipelineOptimizer(object):
                 })
             extra_index += 1

-    def _initialize_gradients(self, startup_block, main_block):
+    def _clear_gradients(self, main_block):
         """
-        Initialize gradients before run.
+        Clear gradients at the begining of each run of a minibatch.
         """
         for param_name in self._param_device_map:
             grad_name = self._append_grad_suffix(param_name)
-            param_var = startup_block.vars[param_name]
-            grad_var = self._create_var(startup_block, param_var, grad_name)
-            main_grad_var = self._create_var(main_block, param_var, grad_name)
-            grad_var.persistable = True
-            main_grad_var.persistable = True
-            startup_block.append_op(
+            param_var = main_block.vars[param_name]
+            grad_var = main_block.vars[grad_name]
+            device = self._param_device_map[param_name]
+            main_block._insert_op(
+                index=0,
                 type='fill_constant',
                 inputs={},
                 outputs={'Out': [grad_var]},
                 attrs={
                     'shape': grad_var.shape,
                     'dtype': grad_var.dtype,
-                    'value': float(0)
-                })
-
-    def _clear_gradients(self, block):
-        """
-        Clear gradients after update.
-        """
-        for index, op in reversed(list(enumerate(block.ops))):
-            device = op.attr(self._op_device_key)
-            if not self._is_update_op(op):
-                continue
-            assert self._op_role_var_key in op.attr_names
-            op_role_var = op.all_attrs()[self._op_role_var_key]
-            assert len(op_role_var) == 2
-            grad_name = op_role_var[1]
-            grad_var = block.var(grad_name)
-            block.append_op(
-                type='fill_constant',
-                inputs={},
-                outputs={'Out': [grad_var]},
-                attrs={
-                    'shape': grad_var.shape,
-                    'dtype': grad_var.dtype,
                     'value': float(0),
+                    'force_cpu': False,
                     self._op_device_key: device,
                     self._op_role_key: self._op_role.Optimize
                 })

     def _accumulate_gradients(self, block):
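The new _clear_gradients no longer walks update ops in reverse; it iterates _param_device_map and front-inserts one fill_constant per parameter gradient into the main block (_insert_op(index=0, ...)), so the zeroing executes before any forward or backward op of the run and carries the parameter's op_device plus the Optimize role. A minimal model of the emitted ops (plain dicts, not Paddle op descriptors):

    def clear_gradient_ops(param_device_map, grad_suffix="@GRAD"):
        ops = []
        for param, device in param_device_map.items():
            ops.append({
                "type": "fill_constant",
                "output": param + grad_suffix,   # e.g. fc_0.w_0@GRAD
                "value": 0.0,
                "force_cpu": False,
                "op_device": device,      # zero on the parameter's device
                "op_role": "Optimize",
            })
        return ops

    # Prepending models the index=0 insertion: clears execute first each run.
    body_ops = [{"type": "mul"}, {"type": "mul_grad"}]   # stand-in program
    main_ops = clear_gradient_ops({"fc_0.w_0": "gpu:0"}) + body_ops
    assert main_ops[0]["type"] == "fill_constant"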
@@ -4338,13 +4317,14 @@ class PipelineOptimizer(object):
         startup_program = default_startup_program()
         optimize_ops, params_grads = self._optimizer.minimize(
             loss, startup_program, parameter_list, no_grad_set)
-        self._update_param_device_map(params_grads, main_block)
+        self._param_device_map = self._optimizer._param_device_map

         # Step1: add default op_device attribute for regulization and clip ops
         self._add_opdevice_attr_for_regularization_clip(main_block)

         # Step2: add default op_device attribute for ops whose op_device
-        # attribute have not been set yet.
+        # attribute have not been set yet. Then check all ops have the
+        # op_device attribute.
         self._add_default_opdevice_attr(main_block)
         device_specs = self._check_validation(main_block)

@@ -4356,9 +4336,9 @@ class PipelineOptimizer(object):
         # Step4: accumulate gradients during backward
         # and clear them after update
-        self._initialize_gradients(startup_program.global_block(), main_block)
+        self._clear_gradients(main_block)
         self._accumulate_gradients(main_block)
-        self._clear_gradients(main_block)
+        #self._clear_gradients(main_block)

         main_program = main_block.program

@@ -4377,18 +4357,11 @@ class PipelineOptimizer(object):
         # Step5: split program into sections and add pairs of
         # enqueue and dequeue ops for data var.
-        if len(place_list) == 0:
-            program_list = []
-            ptmp = {
-                "program": main_program,
-                "input_set": set(),
-                "output_set": set()
-            }
-            program_list.append(ptmp)
-        else:
-            program_list = self._split_program(main_program)
-            for p in program_list:
-                self._create_vars(p["program"].block(0), main_program)
+        if len(place_list) <= 1:
+            raise ValueError("Run on one device, do not use pipeline.")
+        program_list = self._split_program(main_program, device_specs)
+        for p in program_list:
+            self._create_vars(p["program"].block(0), main_program)
         self._insert_enq_deq_for_data_var(main_block, program_list,
                                           startup_program, device_specs)
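Taken together, minimize() now reads the parameter-to-device map straight from the inner optimizer instead of rebuilding it from the block (the removed _update_param_device_map), zeroes gradients at the start of each run before accumulating them (the old clear-after-update call is left commented out), validation asserts that the collected device_specs are already sorted, and a single-device run is rejected with a ValueError rather than special-cased, so _split_program is always called with the validated device list.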