Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
131bd8e8
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
131bd8e8
编写于
8月 06, 2020
作者:
S
sandyhouse
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
accumulate gradients instead of dequeueN, test=develop
上级
1e4e23a8
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
184 addition
and
213 deletion
+184
-213
paddle/fluid/framework/pipeline_trainer.cc
paddle/fluid/framework/pipeline_trainer.cc
+30
-19
paddle/fluid/framework/section_worker.cc
paddle/fluid/framework/section_worker.cc
+84
-59
paddle/fluid/framework/trainer.h
paddle/fluid/framework/trainer.h
+0
-1
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+70
-134
未找到文件。
paddle/fluid/framework/pipeline_trainer.cc
浏览文件 @
131bd8e8
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#if defined(PADDLE_WITH_NCCL)
#include <map>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
...
...
@@ -44,7 +45,6 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
"must be 1 now, but the value you give is %d."
,
num_readers
));
auto
*
reader
=
readers
[
0
];
feed_var_names_
=
reader
->
GetUseSlotAlias
();
workers_
.
resize
(
section_num_
);
for
(
int
i
=
0
;
i
<
section_num_
;
++
i
)
{
...
...
@@ -123,27 +123,38 @@ void PipelineTrainer::CopyParameters(int section_id, int microbatch_id,
const
ProgramDesc
&
program
,
const
platform
::
Place
&
place
)
{
auto
&
global_block
=
program
.
Block
(
0
);
std
::
map
<
std
::
string
,
int
>
param_map
;
for
(
auto
&
var
:
global_block
.
AllVars
())
{
int
is_feed_var
=
std
::
count
(
feed_var_names_
.
begin
(),
feed_var_names_
.
end
(),
var
->
Name
());
if
(
var
->
Persistable
())
{
param_map
[
var
->
Name
()]
=
1
;
}
}
for
(
auto
&
var
:
global_block
.
AllVars
())
{
bool
is_grad
=
false
;
bool
is_param_grad
=
false
;
size_t
pos
=
0
;
if
((
pos
=
var
->
Name
().
find
(
kGradVarSuffix
))
!=
std
::
string
::
npos
)
{
is_grad
=
true
;
auto
prefix_name
=
var
->
Name
().
substr
(
0
,
pos
);
if
(
param_map
.
find
(
prefix_name
)
!=
param_map
.
end
())
{
is_param_grad
=
true
;
}
}
VLOG
(
3
)
<<
"Var name: "
<<
var
->
Name
();
if
((
var
->
Persistable
()
||
is_feed_var
)
&&
microbatch_id
==
0
)
{
if
(
is_feed_var
)
{
auto
*
new_ptr
=
minibatch_scopes_
[
section_id
]
->
Var
(
var
->
Name
());
VLOG
(
3
)
<<
"data name: "
<<
var
->
Name
()
<<
", ptr: "
<<
new_ptr
;
InitializeVariable
(
new_ptr
,
var
->
GetType
());
}
else
{
auto
*
ptr
=
root_scope_
->
FindVar
(
var
->
Name
());
auto
*
new_ptr
=
minibatch_scopes_
[
section_id
]
->
Var
(
var
->
Name
());
VLOG
(
3
)
<<
"Create persistable var "
<<
var
->
Name
()
<<
" for minibatch "
<<
section_id
<<
", which pointer is "
<<
new_ptr
;
InitializeVariable
(
new_ptr
,
var
->
GetType
());
const
LoDTensor
&
root_tensor
=
ptr
->
Get
<
LoDTensor
>
();
LoDTensor
*
minibatch_tensor
=
new_ptr
->
GetMutable
<
LoDTensor
>
();
TensorCopy
(
*
static_cast
<
const
Tensor
*>
(
&
root_tensor
),
place
,
static_cast
<
Tensor
*>
(
minibatch_tensor
));
if
((
var
->
Persistable
()
||
is_grad
)
&&
microbatch_id
==
0
)
{
auto
*
ptr
=
root_scope_
->
FindVar
(
var
->
Name
());
auto
*
new_ptr
=
minibatch_scopes_
[
section_id
]
->
Var
(
var
->
Name
());
VLOG
(
3
)
<<
"Create persistable var "
<<
var
->
Name
()
<<
" for minibatch "
<<
section_id
<<
", which pointer is "
<<
new_ptr
;
InitializeVariable
(
new_ptr
,
var
->
GetType
());
if
(
!
var
->
Persistable
()
&&
!
is_param_grad
)
{
continue
;
}
}
else
if
(
!
var
->
Persistable
()
&&
!
is_feed_var
)
{
const
LoDTensor
&
root_tensor
=
ptr
->
Get
<
LoDTensor
>
();
LoDTensor
*
minibatch_tensor
=
new_ptr
->
GetMutable
<
LoDTensor
>
();
TensorCopy
(
*
static_cast
<
const
Tensor
*>
(
&
root_tensor
),
place
,
static_cast
<
Tensor
*>
(
minibatch_tensor
));
}
else
if
(
!
var
->
Persistable
()
&&
!
is_grad
)
{
auto
*
ptr
=
microbatch_scopes_
[
section_id
][
microbatch_id
]
->
Var
(
var
->
Name
());
VLOG
(
3
)
<<
"Create variable "
<<
var
->
Name
()
<<
" for section "
...
...
paddle/fluid/framework/section_worker.cc
浏览文件 @
131bd8e8
...
...
@@ -109,6 +109,8 @@ void SectionWorker::TrainFiles() {
if
(
thread_id_
==
0
)
{
while
(
true
)
{
// Start a minibatch.
// real number of microbatches run
int
real_microbatch_num
=
0
;
batch_timer
.
Start
();
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
try
{
...
...
@@ -141,18 +143,20 @@ void SectionWorker::TrainFiles() {
VLOG
(
3
)
<<
"called notify all"
;
thread_condition
.
notify_all
();
VLOG
(
0
)
<<
"EOF encountered"
;
return
;
break
;
}
if
(
i
==
0
)
{
{
real_microbatch_num
+=
1
;
batch_id_
+=
1
;
VLOG
(
3
)
<<
"called notify all"
;
std
::
unique_lock
<
std
::
mutex
>
lk
(
thread_mutex
);
batch_id_
+=
1
;
thread_condition
.
notify_all
();
}
}
dev_ctx_
->
Wait
();
VLOG
(
0
)
<<
"real_microbatch_num for thread 0 "
<<
real_microbatch_num
;
// backward pass
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
real_microbatch_num
;
++
i
)
{
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kBackward
)
||
...
...
@@ -169,6 +173,11 @@ void SectionWorker::TrainFiles() {
}
}
dev_ctx_
->
Wait
();
if
(
real_microbatch_num
==
0
)
{
batch_timer
.
Pause
();
VLOG
(
0
)
<<
"batch time: "
<<
batch_timer
.
ElapsedUS
();
return
;
}
// update pass
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
...
...
@@ -188,30 +197,32 @@ void SectionWorker::TrainFiles() {
}
}
else
{
while
(
true
)
{
{
PADDLE_ENFORCE_LE
(
local_batch_id_
,
batch_id_
,
platform
::
errors
::
InvalidArgument
(
"local_batch_id_ (%d) must be less than or equal to "
"batch_id_ (%d)"
,
local_batch_id_
,
batch_id_
));
std
::
unique_lock
<
std
::
mutex
>
lk
(
thread_mutex
);
if
(
local_batch_id_
==
batch_id_
&&
!
threads_completed
)
{
thread_condition
.
wait
(
lk
);
}
VLOG
(
3
)
<<
"thread "
<<
thread_id_
<<
" local_batch_id_ "
<<
local_batch_id_
<<
" batch_id_ "
<<
batch_id_
;
if
(
threads_completed
)
{
VLOG
(
3
)
<<
"thread "
<<
thread_id_
<<
" completed."
;
lk
.
unlock
();
threads_completed
=
false
;
return
;
}
lk
.
unlock
();
local_batch_id_
+=
1
;
}
// forward pass:
int
real_microbatch_num
=
0
;
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
{
PADDLE_ENFORCE_LE
(
local_batch_id_
,
batch_id_
,
platform
::
errors
::
InvalidArgument
(
"local_batch_id_ (%d) must be less than or equal to "
"batch_id_ (%d)"
,
local_batch_id_
,
batch_id_
));
std
::
unique_lock
<
std
::
mutex
>
lk
(
thread_mutex
);
if
(
local_batch_id_
==
batch_id_
&&
!
threads_completed
)
{
thread_condition
.
wait
(
lk
);
}
VLOG
(
3
)
<<
"thread "
<<
thread_id_
<<
" local_batch_id_ "
<<
local_batch_id_
<<
" batch_id_ "
<<
batch_id_
;
if
(
threads_completed
)
{
VLOG
(
3
)
<<
"thread "
<<
thread_id_
<<
" completed."
;
lk
.
unlock
();
threads_completed
=
false
;
break
;
}
lk
.
unlock
();
local_batch_id_
+=
1
;
real_microbatch_num
+=
1
;
}
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
// We run op with op_role = kLRSched only for the first microbatch
...
...
@@ -237,7 +248,7 @@ void SectionWorker::TrainFiles() {
}
dev_ctx_
->
Wait
();
// backward pass
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
real_microbatch_num
;
++
i
)
{
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kBackward
)
||
...
...
@@ -255,6 +266,9 @@ void SectionWorker::TrainFiles() {
}
dev_ctx_
->
Wait
();
// update pass
if
(
real_microbatch_num
==
0
)
{
return
;
}
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kOptimize
))
{
...
...
@@ -324,6 +338,7 @@ void SectionWorker::TrainFilesWithProfiler() {
while
(
true
)
{
// Start a minibatch.
batch_timer
.
Start
();
int
real_microbatch_num
=
0
;
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
try
{
int
op_idx
=
0
;
...
...
@@ -397,18 +412,19 @@ void SectionWorker::TrainFilesWithProfiler() {
<<
", mean_time: "
<<
op_total_time
[
i
]
/
op_count
[
i
];
}
VLOG
(
0
)
<<
"================================"
;
return
;
break
;
}
if
(
i
==
0
)
{
{
VLOG
(
3
)
<<
"called notify all"
;
std
::
unique_lock
<
std
::
mutex
>
lk
(
thread_mutex
);
real_microbatch_num
+=
1
;
batch_id_
+=
1
;
thread_condition
.
notify_all
();
}
}
dev_ctx_
->
Wait
();
// backward pass
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
real_microbatch_num
;
++
i
)
{
int
op_idx
=
0
;
gettimeofday
(
&
micro_start
,
NULL
);
for
(
auto
&
op
:
ops_
)
{
...
...
@@ -460,6 +476,10 @@ void SectionWorker::TrainFilesWithProfiler() {
}
}
dev_ctx_
->
Wait
();
if
(
real_microbatch_num
==
0
)
{
batch_timer
.
Pause
();
VLOG
(
0
)
<<
"batch time: "
<<
batch_timer
.
ElapsedUS
();
}
// update pass
int
op_idx
=
0
;
gettimeofday
(
&
micro_start
,
NULL
);
...
...
@@ -526,36 +546,38 @@ void SectionWorker::TrainFilesWithProfiler() {
cudaEventCreate
(
&
cu_start
);
cudaEventCreate
(
&
cu_stop
);
while
(
true
)
{
{
PADDLE_ENFORCE_LE
(
local_batch_id_
,
batch_id_
,
platform
::
errors
::
InvalidArgument
(
"local_batch_id_ (%d) must be less than or equal to "
"batch_id_ (%d)"
,
local_batch_id_
,
batch_id_
));
std
::
unique_lock
<
std
::
mutex
>
lk
(
thread_mutex
);
if
(
local_batch_id_
==
batch_id_
&&
!
threads_completed
)
{
thread_condition
.
wait
(
lk
);
}
VLOG
(
3
)
<<
"thread "
<<
thread_id_
<<
" local_batch_id_ "
<<
local_batch_id_
<<
" batch_id_ "
<<
batch_id_
;
if
(
threads_completed
)
{
VLOG
(
3
)
<<
"thread "
<<
thread_id_
<<
" completed."
;
lk
.
unlock
();
VLOG
(
0
)
<<
"============timeline============"
;
for
(
size_t
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
VLOG
(
0
)
<<
"op: "
<<
op_name
[
i
]
<<
", max_time: "
<<
op_max_time
[
i
]
<<
", min_time: "
<<
op_min_time
[
i
]
<<
", mean_time: "
<<
op_total_time
[
i
]
/
op_count
[
i
];
}
VLOG
(
0
)
<<
"================================"
;
return
;
}
lk
.
unlock
();
local_batch_id_
+=
1
;
}
// forward pass:
int
real_microbatch_num
=
0
;
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
{
PADDLE_ENFORCE_LE
(
local_batch_id_
,
batch_id_
,
platform
::
errors
::
InvalidArgument
(
"local_batch_id_ (%d) must be less than or equal to "
"batch_id_ (%d)"
,
local_batch_id_
,
batch_id_
));
std
::
unique_lock
<
std
::
mutex
>
lk
(
thread_mutex
);
if
(
local_batch_id_
==
batch_id_
&&
!
threads_completed
)
{
thread_condition
.
wait
(
lk
);
}
VLOG
(
3
)
<<
"thread "
<<
thread_id_
<<
" local_batch_id_ "
<<
local_batch_id_
<<
" batch_id_ "
<<
batch_id_
;
if
(
threads_completed
)
{
VLOG
(
3
)
<<
"thread "
<<
thread_id_
<<
" completed."
;
lk
.
unlock
();
VLOG
(
0
)
<<
"============timeline============"
;
for
(
size_t
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
VLOG
(
0
)
<<
"op: "
<<
op_name
[
i
]
<<
", max_time: "
<<
op_max_time
[
i
]
<<
", min_time: "
<<
op_min_time
[
i
]
<<
", mean_time: "
<<
op_total_time
[
i
]
/
op_count
[
i
];
}
VLOG
(
0
)
<<
"================================"
;
break
;
}
lk
.
unlock
();
real_microbatch_num
+=
1
;
local_batch_id_
+=
1
;
}
int
op_idx
=
0
;
gettimeofday
(
&
micro_start
,
NULL
);
for
(
auto
&
op
:
ops_
)
{
...
...
@@ -616,7 +638,7 @@ void SectionWorker::TrainFilesWithProfiler() {
}
dev_ctx_
->
Wait
();
// backward pass
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
real_microbatch_num
;
++
i
)
{
int
op_idx
=
0
;
gettimeofday
(
&
micro_start
,
NULL
);
for
(
auto
&
op
:
ops_
)
{
...
...
@@ -668,6 +690,9 @@ void SectionWorker::TrainFilesWithProfiler() {
}
}
dev_ctx_
->
Wait
();
if
(
real_microbatch_num
==
0
)
{
return
;
}
// update pass
int
op_idx
=
0
;
gettimeofday
(
&
micro_start
,
NULL
);
...
...
paddle/fluid/framework/trainer.h
浏览文件 @
131bd8e8
...
...
@@ -143,7 +143,6 @@ class PipelineTrainer : public TrainerBase {
int
section_num_
;
int
num_microbatches_
;
int
start_cpu_core_id_
;
std
::
vector
<
std
::
string
>
feed_var_names_
;
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
std
::
vector
<
std
::
string
>>
skip_vars_
;
TrainerDesc
trainer_desc_
;
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
131bd8e8
...
...
@@ -3676,15 +3676,9 @@ class PipelineOptimizer(object):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch_size = 1
filelist = [] # you should set your own filelist, e.g. filelist = ["dataA.txt"]
dataset = fluid.DatasetFactory().create_dataset("FileInstantDataset")
dataset.set_use_var([x,y])
dataset.set_batch_size(batch_size)
dataset.set_filelist(filelist)
data_loader.start()
exe.train_from_dataset(
fluid.default_main_program(),
dataset)
fluid.default_main_program())
data_loader.reset()
"""
...
...
@@ -3701,7 +3695,7 @@ class PipelineOptimizer(object):
"num_microbatches must be a positive value."
)
self
.
_num_microbatches
=
num_microbatches
assert
start_cpu_core_id
>=
0
,
(
"start_cpu_core_id must be
greater than or equal to 0
."
)
"start_cpu_core_id must be
a non negative integer
."
)
self
.
_start_cpu_core_id
=
start_cpu_core_id
self
.
_place_list
=
None
op_maker
=
core
.
op_proto_and_checker_maker
...
...
@@ -3800,8 +3794,6 @@ class PipelineOptimizer(object):
if
in_var_name
==
var_name
:
post_op
.
append
(
op
)
if
post_op
:
if
not
len
(
post_op
)
==
1
:
raise
ValueError
(
"Each op can only have one post op."
)
return
post_op
[
0
]
return
None
...
...
@@ -3856,60 +3848,26 @@ class PipelineOptimizer(object):
def
_get_data_var_info
(
self
,
block
):
"""
Get all vars whose is_data attribute are true and then rename them.
For PipelineTrainer, all data vars are binded to
minibatch scope, so we have to feed them to the microbatch
to avoid conflicts. The vars feeded to microbatch have to
be renamed.
"""
# A map from var name to the renamed name.
raw_name_new_name_map
=
dict
()
# Because we will create vars in block, it is more safe
# to get all var_names before iteration.
var_names
=
list
(
block
.
vars
.
keys
())
for
var_name
in
var_names
:
var
=
block
.
var
(
var_name
)
if
not
var
.
is_data
:
continue
assert
var_name
not
in
raw_name_new_name_map
,
(
"{} has already been processed."
.
format
(
var_name
))
new_name
=
unique_name
.
generate
(
var_name
)
raw_name_new_name_map
[
var_name
]
=
new_name
new_var
=
self
.
_create_var
(
block
,
var
,
new_name
)
new_var
.
is_data
=
False
# map of data to devices that that data on
# map of data vars to devices that that data on
data_devices_map
=
dict
()
for
op
in
block
.
ops
:
dev_spec
=
op
.
attr
(
self
.
_op_device_key
)
for
var_name
in
op
.
input_arg_names
:
if
var_name
not
in
raw_name_new_name_map
:
if
"blocking_queue"
in
var_name
:
continue
var
=
block
.
var
(
var_name
)
if
not
var
.
is_data
:
continue
if
not
var_name
in
data_devices_map
:
data_devices_map
[
var_name
]
=
[]
if
not
dev_spec
in
data_devices_map
[
var_name
]:
data_devices_map
[
var_name
].
append
(
dev_spec
)
new_name
=
raw_name_new_name_map
[
var_name
]
#self._rename_arg(op, var_name, new_name)
return
data_devices_map
,
raw_name_new_name_map
def
_rename_var_in_block
(
self
,
block
,
raw_name_new_name_map
):
"""
Rename vars whose names in raw_name_new_name_map to the corresponding
new names.
"""
for
op
in
block
.
ops
:
if
op
.
type
==
"enqueue"
or
op
.
type
==
"dequeue"
:
continue
for
var_name
in
op
.
input_arg_names
:
if
var_name
in
raw_name_new_name_map
:
new_name
=
raw_name_new_name_map
[
var_name
]
self
.
_rename_arg
(
op
,
var_name
,
new_name
)
return
data_devices_map
def
_insert_enq_deq_for_data_var
(
self
,
main_block
,
programs
,
startup
,
devices
):
"""
Insert enqueue and dequeue ops for data var
Insert enqueue and dequeue ops for data var
that on other devices.
Args:
main_block (Block): Global block for main program
...
...
@@ -3918,22 +3876,19 @@ class PipelineOptimizer(object):
devices (list): List of devices in the format (dev:dev_index)
"""
main_program
=
main_block
.
program
data_devices_map
,
raw_name_new_name_map
=
self
.
_get_data_var_info
(
main_block
)
data_devices_map
=
self
.
_get_data_var_info
(
main_block
)
first_prog
=
programs
[
0
][
'program'
]
first_block
=
first_prog
.
block
(
0
)
enqueue_index
=
0
if
first_block
.
ops
[
0
].
type
==
"create_py_reader"
or
(
first_block
.
ops
[
1
].
type
==
"create_py_reader"
):
for
op
in
first_block
.
ops
:
if
op
.
type
==
"read"
:
enqueue_index
+=
1
break
enqueue_index
+=
1
for
op
in
first_block
.
ops
:
enqueue_index
+=
1
if
op
.
type
==
"read"
:
break
first_dev_spec
=
devices
[
0
]
for
var_name
in
data_devices_map
.
keys
():
for
device
in
data_devices_map
[
var_name
]:
if
device
==
first_dev_spec
:
continue
# step1: generate queue for each pair of data var and device
# that that data on
queue_name
=
var_name
+
"_blocking_queue"
...
...
@@ -3967,13 +3922,10 @@ class PipelineOptimizer(object):
prog
=
programs
[
prog_index
][
'program'
]
block
=
prog
.
block
(
0
)
index
=
0
if
device
==
first_dev_spec
:
index
=
enqueue_index
+
1
new_name
=
raw_name_new_name_map
[
var_name
]
source_var
=
main_program
.
block
(
0
).
var
(
var_name
)
new_var
=
self
.
_create_var
(
block
,
source_var
,
new
_name
)
new_var
=
self
.
_create_var
(
block
,
source_var
,
var
_name
)
block
.
_insert_op
(
index
=
index
,
index
=
0
,
type
=
'dequeue'
,
outputs
=
{
'Out'
:
[
new_var
]},
attrs
=
{
...
...
@@ -3981,7 +3933,6 @@ class PipelineOptimizer(object):
self
.
_op_role_key
:
self
.
_op_role
.
Forward
,
'queue_name'
:
queue_name
,
})
self
.
_rename_var_in_block
(
block
,
raw_name_new_name_map
)
def
_strip_grad_suffix
(
self
,
name
):
"""
...
...
@@ -4162,82 +4113,57 @@ class PipelineOptimizer(object):
})
extra_index
+=
1
def
_add_dequeue_ops_for_optimize
(
self
,
block
,
startup_program
):
startup_block
=
startup_program
.
global_block
()
grad_queue_map
=
dict
()
grad_device_map
=
dict
()
optimize_index
=
None
grad_names_to_dequeue
=
[]
def
_initialize_gradients
(
self
,
startup_block
,
main_block
):
"""
Initialize gradients before run.
"""
for
param_name
in
self
.
_param_device_map
:
grad_name
=
self
.
_append_grad_suffix
(
param_name
)
param_var
=
startup_block
.
vars
[
param_name
]
grad_var
=
self
.
_create_var
(
startup_block
,
param_var
,
grad_name
)
main_grad_var
=
self
.
_create_var
(
main_block
,
param_var
,
grad_name
)
grad_var
.
persistable
=
True
main_grad_var
.
persistable
=
True
startup_block
.
append_op
(
type
=
'fill_constant'
,
inputs
=
{},
outputs
=
{
'Out'
:[
grad_var
]},
attrs
=
{
'shape'
:
grad_var
.
shape
,
'dtype'
:
grad_var
.
dtype
,
'value'
:
float
(
0
)
})
def
_clear_gradients
(
self
,
block
):
"""
Clear gradients after update.
"""
for
index
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
device
=
op
.
attr
(
self
.
_op_device_key
)
# Optimizer pass
if
not
self
.
_is_optimize_op
(
op
):
optimize_index
=
index
+
1
break
if
not
self
.
_is_update_op
(
op
):
continue
assert
self
.
_op_role_var_key
in
op
.
attr_names
op_role_var
=
op
.
all_attrs
()[
self
.
_op_role_var_key
]
assert
len
(
op_role_var
)
==
2
grad_name
=
op_role_var
[
1
]
assert
grad_name
not
in
grad_device_map
assert
grad_name
not
in
grad_names_to_dequeue
grad_device_map
[
grad_name
]
=
device
grad_names_to_dequeue
.
append
(
grad_name
)
for
grad_name
in
grad_names_to_dequeue
:
device
=
grad_device_map
[
grad_name
]
grad_names
=
[]
grads
=
[]
queue_name
=
grad_name
+
"_blocking_queue"
queue_name
=
unique_name
.
generate
(
queue_name
)
grad_queue_map
[
grad_name
]
=
queue_name
ref_var
=
block
.
vars
[
grad_name
]
queue_var
=
startup_block
.
create_var
(
name
=
queue_name
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
startup_block
.
append_op
(
type
=
'queue_generator'
,
attrs
=
{
'names'
:
[
queue_name
],
'capacity'
:
self
.
_num_microbatches
})
orig_var_name
=
self
.
_strip_grad_suffix
(
grad_name
)
for
_
in
range
(
self
.
_num_microbatches
):
u_name
=
unique_name
.
generate
(
orig_var_name
)
u_grad_name
=
self
.
_append_grad_suffix
(
u_name
)
grad_var
=
self
.
_create_var
(
block
,
ref_var
,
u_grad_name
)
grad_names
.
append
(
u_grad_name
)
grads
.
append
(
grad_var
)
block
.
_insert_op
(
index
=
optimize_index
,
type
=
'dequeue'
,
outputs
=
{
'Out'
:
grads
},
attrs
=
{
self
.
_op_device_key
:
device
,
'queue_name'
:
queue_name
,
self
.
_op_role_key
:
self
.
_op_role
.
Optimize
})
block
.
_insert_op
(
index
=
optimize_index
+
1
,
type
=
'sum'
,
inputs
=
{
'X'
:
grad_names
},
outputs
=
{
'Out'
:
ref_var
},
grad_var
=
block
.
var
(
grad_name
)
block
.
append_op
(
type
=
'fill_constant'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
grad_var
]},
attrs
=
{
'shape'
:
grad_var
.
shape
,
'dtype'
:
grad_var
.
dtype
,
'value'
:
float
(
0
),
'force_cpu'
:
False
,
self
.
_op_device_key
:
device
,
self
.
_op_role_key
:
self
.
_op_role
.
Optimize
})
return
grad_queue_map
def
_
insert_enq_deq_ops_for_update
(
self
,
block
,
startup_program
):
def
_
accumulate_gradients
(
self
,
block
):
"""
Insert enqueue and dequeue ops for gradients of parameters
.
Accumulate the graident generated in microbatch to the one in mini-batch
.
"""
startup_block
=
startup_program
.
global_block
()
grad_queue_map
=
self
.
_add_dequeue_ops_for_optimize
(
block
,
startup_program
)
for
index
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
offset
=
index
device
=
op
.
attr
(
self
.
_op_device_key
)
...
...
@@ -4264,19 +4190,25 @@ class PipelineOptimizer(object):
if
len
(
op_role_var
)
==
0
:
continue
assert
len
(
op_role_var
)
%
2
==
0
offset
=
index
for
i
in
range
(
0
,
len
(
op_role_var
),
2
):
grad_name
=
op_role_var
[
i
+
1
]
grad_var
=
block
.
vars
[
grad_name
]
assert
grad_name
in
grad_queue_map
queue_name
=
grad_queue_map
[
grad_name
]
param_name
=
op_role_var
[
i
]
param_var
=
block
.
vars
[
param_name
]
new_var_name
=
unique_name
.
generate
(
param_name
)
new_var_name
=
self
.
_append_grad_suffix
(
new_var_name
)
new_var
=
self
.
_create_var
(
block
,
grad_var
,
new_var_name
)
self
.
_rename_arg
(
op
,
grad_name
,
new_var_name
)
block
.
_insert_op
(
index
=
offset
+
1
,
type
=
'enqueue'
,
inputs
=
{
'X'
:
block
.
vars
[
grad_name
]},
type
=
'sum'
,
inputs
=
{
'X'
:
[
grad_var
,
new_var
]},
outputs
=
{
'Out'
:
grad_var
},
attrs
=
{
'queue_name'
:
queue_name
,
self
.
_op_device_key
:
device
,
self
.
_op_role_key
:
self
.
_op_role
.
Backward
self
.
_op_role_key
:
self
.
_op_role
.
Backward
,
self
.
_op_role_var_key
:
op_role_var
})
offset
+=
1
...
...
@@ -4299,6 +4231,7 @@ class PipelineOptimizer(object):
def
_get_device_info
(
self
,
block
):
for
op
in
block
.
ops
:
if
not
op
.
_has_kernel
(
op
.
type
):
continue
op_device
=
op
.
attr
(
self
.
_op_device_key
)
return
op_device
...
...
@@ -4420,8 +4353,11 @@ class PipelineOptimizer(object):
self
.
_insert_enq_deq_ops_for_boundaries
(
main_block
,
origin_main_block
,
startup_program
)
# Step4: add a pair of enqueue and dequeueN for parameter gradients
self
.
_insert_enq_deq_ops_for_update
(
main_block
,
startup_program
)
# Step4: accumulate gradients during backward
# and clear them after update
self
.
_initialize_gradients
(
startup_program
.
global_block
(),
main_block
)
self
.
_accumulate_gradients
(
main_block
)
self
.
_clear_gradients
(
main_block
)
main_program
=
main_block
.
program
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录