PaddlePaddle / Paddle — Commit f71543ee

Merge branch 'add_timeline' into pipeline_exe_run

Authored Sep 07, 2020 by sandyhouse
Parents: 27f245cd, 0f752e89
Showing 5 changed files with 404 additions and 259 deletions (+404 -259):

paddle/fluid/framework/device_worker.h      +1    -0
paddle/fluid/framework/pipeline_trainer.cc  +30   -20
paddle/fluid/framework/section_worker.cc    +278  -55
paddle/fluid/framework/trainer.h            +0    -1
python/paddle/fluid/optimizer.py            +95   -183
paddle/fluid/framework/device_worker.h

@@ -455,6 +455,7 @@ class SectionWorker : public DeviceWorker {
   std::vector<std::unique_ptr<OperatorBase>> ops_;
   static std::mutex thread_mutex;
+  static std::mutex cout_mutex;
   static std::condition_variable thread_condition;
   static bool threads_completed;
   std::shared_ptr<framework::ProgramDesc> program_;
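The new static cout_mutex exists so that the per-op timeline records each section thread writes to std::cout in section_worker.cc (see below) cannot interleave mid-line. A minimal standalone sketch of that pattern; the worker count and message are invented for illustration:

    #include <iostream>
    #include <mutex>
    #include <thread>
    #include <vector>

    std::mutex cout_mutex;  // one lock shared by all workers, like SectionWorker::cout_mutex

    void Worker(int thread_id) {
      // Without the lock, fragments from different threads could interleave mid-line.
      std::unique_lock<std::mutex> lk(cout_mutex);
      std::cout << "::FWD:B[0]:SEC[" << thread_id << "]" << std::endl;
    }

    int main() {
      std::vector<std::thread> threads;
      for (int i = 0; i < 4; ++i) threads.emplace_back(Worker, i);
      for (auto& t : threads) t.join();
      return 0;
    }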
paddle/fluid/framework/pipeline_trainer.cc

@@ -13,6 +13,7 @@
 // limitations under the License.
 #if defined(PADDLE_WITH_NCCL)
+#include <map>
 #include "paddle/fluid/framework/data_feed_factory.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/framework/trainer.h"

@@ -44,7 +45,6 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
                           "must be 1 now, but the value you give is %d.",
                           num_readers));
   auto* reader = readers[0];
-  feed_var_names_ = reader->GetUseSlotAlias();
   workers_.resize(section_num_);
   for (int i = 0; i < section_num_; ++i) {
@@ -123,26 +123,36 @@ void PipelineTrainer::CopyParameters(int section_id, int microbatch_id,
                                      const ProgramDesc& program,
                                      const platform::Place& place) {
   auto& global_block = program.Block(0);
+  std::map<std::string, int> param_map;
   for (auto& var : global_block.AllVars()) {
-    int is_feed_var =
-        std::count(feed_var_names_.begin(), feed_var_names_.end(), var->Name());
-    if ((var->Persistable() || is_feed_var) && microbatch_id == 0) {
-      if (is_feed_var) {
-        auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name());
-        VLOG(3) << "data name: " << var->Name() << ", ptr: " << new_ptr;
-        InitializeVariable(new_ptr, var->GetType());
-      } else {
+    if (var->Persistable()) {
+      param_map[var->Name()] = 1;
+    }
+  }
+  for (auto& var : global_block.AllVars()) {
+    bool is_param_grad = false;
+    size_t pos = 0;
+    if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos) {
+      auto prefix_name = var->Name().substr(0, pos);
+      if (param_map.find(prefix_name) != param_map.end()) {
+        is_param_grad = true;
+      }
+    }
+    VLOG(3) << "Var name: " << var->Name();
+    if ((var->Persistable() || is_param_grad) && microbatch_id == 0) {
       auto* ptr = root_scope_->FindVar(var->Name());
       auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name());
       VLOG(3) << "Create persistable var " << var->Name() << " for minibatch "
               << section_id << ", which pointer is " << new_ptr;
       InitializeVariable(new_ptr, var->GetType());
+      if (is_param_grad) {
+        continue;
+      }
       const LoDTensor& root_tensor = ptr->Get<LoDTensor>();
       LoDTensor* minibatch_tensor = new_ptr->GetMutable<LoDTensor>();
       TensorCopy(*static_cast<const Tensor*>(&root_tensor), place,
                  static_cast<Tensor*>(minibatch_tensor));
-      }
-    } else if (!var->Persistable() && !is_feed_var) {
+    } else if (!var->Persistable() && !is_param_grad) {
       auto* ptr =
           microbatch_scopes_[section_id][microbatch_id]->Var(var->Name());
       VLOG(3) << "Create variable " << var->Name() << " for section "
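The rewritten CopyParameters no longer keys off the feed variable list; it first records every persistable variable in param_map, then treats a variable as a parameter gradient when stripping kGradVarSuffix from its name leaves the name of a recorded parameter. A self-contained sketch of just that string test; the suffix value and variable names here are assumptions for illustration:

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    const std::string kGradVarSuffix = "@GRAD";  // assumption: Paddle's gradient-name suffix

    int main() {
      // Pretend these persistable variables were collected from the global block.
      std::map<std::string, int> param_map = {{"fc_0.w_0", 1}, {"fc_0.b_0", 1}};
      std::vector<std::string> names = {"fc_0.w_0@GRAD", "tmp_1@GRAD", "fc_0.b_0"};
      for (const auto& name : names) {
        bool is_param_grad = false;
        size_t pos = name.find(kGradVarSuffix);
        if (pos != std::string::npos &&
            param_map.find(name.substr(0, pos)) != param_map.end()) {
          is_param_grad = true;
        }
        std::cout << name << " -> " << (is_param_grad ? "param grad" : "other") << "\n";
      }
      return 0;
    }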
@@ -244,7 +254,7 @@ void PipelineTrainer::Finalize() {
         const LoDTensor& minibatch_tensor = minibatch_ptr->Get<LoDTensor>();
         TensorCopy(*static_cast<const Tensor*>(&minibatch_tensor), places_[0],
                    static_cast<Tensor*>(root_tensor));
-        VLOG(4) << "Copy persitable var " << var->Name() << " to root scope";
+        VLOG(3) << "Copy persitable var " << var->Name() << " to root scope";
       }
     }
   }
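The only change in Finalize() is the verbosity of the copy-back message, which moves from level 4 to level 3. With glog-style VLOG, a message at level n is printed when the process verbosity is at least n, so the lower level makes the message visible one notch earlier. A dependency-free sketch of that gating rule; MY_VLOG and FLAGS_v are stand-ins, not glog itself:

    #include <iostream>

    int FLAGS_v = 3;  // pretend the process was started with verbosity --v=3

    // Simplified stand-in for glog's VLOG: emit only when verbosity >= n.
    #define MY_VLOG(n) \
      if (FLAGS_v >= (n)) std::cout << "[VLOG(" << (n) << ")] "

    int main() {
      MY_VLOG(3) << "Copy persitable var fc_0.w_0 to root scope" << "\n";  // printed
      MY_VLOG(4) << "Copy persitable var fc_0.w_0 to root scope" << "\n";  // suppressed
      return 0;
    }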
paddle/fluid/framework/section_worker.cc

@@ -32,6 +32,7 @@ namespace framework {
 std::atomic<int> SectionWorker::cpu_id_(0);
 std::mutex SectionWorker::thread_mutex;
+std::mutex SectionWorker::cout_mutex;
 std::condition_variable SectionWorker::thread_condition;
 bool SectionWorker::threads_completed = false;
 uint64_t SectionWorker::batch_id_(0);
@@ -103,9 +104,14 @@ void SectionWorker::TrainFiles() {
   }
 #endif
+  platform::Timer batch_timer;
   if (thread_id_ == 0) {
     while (true) {
       // Start a minibatch.
+      // real number of microbatches run
+      int real_microbatch_num = 0;
+      batch_timer.Start();
       for (int i = 0; i < num_microbatches_; ++i) {
         try {
           for (auto& op : ops_) {
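batch_timer follows a Start/Pause/ElapsedUS protocol: started before the first microbatch, paused after the update pass, then the elapsed time is logged in microseconds. platform::Timer is Paddle's own class; the sketch below only mirrors the three calls used in the diff, using std::chrono:

    #include <chrono>
    #include <iostream>
    #include <thread>

    class Timer {
     public:
      void Start() { start_ = std::chrono::steady_clock::now(); }
      void Pause() { elapsed_ += std::chrono::steady_clock::now() - start_; }
      double ElapsedUS() const {
        return std::chrono::duration<double, std::micro>(elapsed_).count();
      }

     private:
      std::chrono::steady_clock::time_point start_;
      std::chrono::steady_clock::duration elapsed_{0};
    };

    int main() {
      Timer batch_timer;
      batch_timer.Start();
      std::this_thread::sleep_for(std::chrono::milliseconds(5));  // fake minibatch work
      batch_timer.Pause();
      std::cout << "batch time: " << batch_timer.ElapsedUS() << "\n";
      return 0;
    }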
@@ -137,17 +143,21 @@ void SectionWorker::TrainFiles() {
             VLOG(3) << "called notify all";
             thread_condition.notify_all();
             VLOG(0) << "EOF encountered";
-            return;
+            break;
           }
-          if (i == 0) {
+          {
+            real_microbatch_num += 1;
+            batch_id_ += 1;
             VLOG(3) << "called notify all";
             std::unique_lock<std::mutex> lk(thread_mutex);
-            batch_id_ += 1;
             thread_condition.notify_all();
           }
         }
+        dev_ctx_->Wait();
+        VLOG(0) << "real_microbatch_num for thread 0 " << real_microbatch_num;
         // backward pass
-        for (int i = 0; i < num_microbatches_; ++i) {
+        for (int i = 0; i < real_microbatch_num; ++i) {
           for (auto& op : ops_) {
             int op_role = op->Attr<int>(std::string("op_role"));
             if (op_role == static_cast<int>(OpRole::kBackward) ||
@@ -163,6 +173,12 @@ void SectionWorker::TrainFiles() {
           }
         }
       }
+      dev_ctx_->Wait();
+      if (real_microbatch_num == 0) {
+        batch_timer.Pause();
+        VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
+        return;
+      }
       // update pass
       for (auto& op : ops_) {
         int op_role = op->Attr<int>(std::string("op_role"));
@@ -177,9 +193,21 @@ void SectionWorker::TrainFiles() {
         }
       }
       dev_ctx_->Wait();
+      batch_timer.Pause();
+      VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
+      {
+        std::unique_lock<std::mutex> lk(thread_mutex);
+        if (threads_completed) {
+          return;
+        }
+      }
     }
   } else {
     while (true) {
-      // forward pass:
+      bool local_completed = false;
+      int real_microbatch_num = 0;
       for (int i = 0; i < num_microbatches_; ++i) {
         {
           PADDLE_ENFORCE_LE(
               local_batch_id_, batch_id_,
@@ -197,13 +225,13 @@ void SectionWorker::TrainFiles() {
             VLOG(3) << "thread " << thread_id_ << " completed.";
             lk.unlock();
             threads_completed = false;
-            return;
+            local_completed = true;
+            break;
           }
           lk.unlock();
           local_batch_id_ += 1;
+          real_microbatch_num += 1;
         }
+      }
+      // forward pass:
+      for (int i = 0; i < num_microbatches_; ++i) {
         for (auto& op : ops_) {
           int op_role = op->Attr<int>(std::string("op_role"));
           // We run op with op_role = kLRSched only for the first microbatch
@@ -227,8 +255,9 @@ void SectionWorker::TrainFiles() {
           }
         }
       }
+      dev_ctx_->Wait();
       // backward pass
-      for (int i = 0; i < num_microbatches_; ++i) {
+      for (int i = 0; i < real_microbatch_num; ++i) {
         for (auto& op : ops_) {
           int op_role = op->Attr<int>(std::string("op_role"));
           if (op_role == static_cast<int>(OpRole::kBackward) ||
@@ -244,7 +273,11 @@ void SectionWorker::TrainFiles() {
           }
         }
       }
+      dev_ctx_->Wait();
       // update pass
+      if (real_microbatch_num == 0) {
+        return;
+      }
       for (auto& op : ops_) {
         int op_role = op->Attr<int>(std::string("op_role"));
         if (op_role == static_cast<int>(OpRole::kOptimize)) {
@@ -258,6 +291,9 @@ void SectionWorker::TrainFiles() {
         }
       }
       dev_ctx_->Wait();
+      if (local_completed) {
+        return;
+      }
     }
   }
 }
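Across both branches above, the scheduling contract is unchanged: thread 0 reads data, bumps the shared batch_id_ under thread_mutex, and calls notify_all; every other section thread sleeps on thread_condition until its local_batch_id_ falls behind batch_id_ or completion is signalled. A standalone toy version of that handshake; the names mirror the SectionWorker members, but the three-batch driver loop is invented:

    #include <condition_variable>
    #include <cstdint>
    #include <iostream>
    #include <mutex>
    #include <thread>

    std::mutex thread_mutex;
    std::condition_variable thread_condition;
    uint64_t batch_id_ = 0;
    bool threads_completed = false;

    void SectionThread() {
      uint64_t local_batch_id_ = 0;
      while (true) {
        std::unique_lock<std::mutex> lk(thread_mutex);
        // Sleep until thread 0 publishes a batch we have not consumed, or signals EOF.
        thread_condition.wait(
            lk, [&] { return local_batch_id_ < batch_id_ || threads_completed; });
        if (threads_completed) return;
        local_batch_id_ += 1;
        std::cout << "section thread consumed batch " << local_batch_id_ << "\n";
      }
    }

    int main() {
      std::thread worker(SectionThread);
      for (int i = 0; i < 3; ++i) {  // thread 0's role: publish three batches
        std::unique_lock<std::mutex> lk(thread_mutex);
        batch_id_ += 1;
        thread_condition.notify_all();
      }
      {
        std::unique_lock<std::mutex> lk(thread_mutex);
        threads_completed = true;
        thread_condition.notify_all();
      }
      worker.join();
      return 0;
    }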
@@ -307,14 +343,20 @@ void SectionWorker::TrainFilesWithProfiler() {
 #endif
   if (thread_id_ == 0) {
+    struct timeval start;
+    struct timeval end;
+    struct timeval micro_start;
+    struct timeval micro_end;
     while (true) {
       // Start a minibatch.
-      // int batch_size = 0;
       batch_timer.Start();
+      int real_microbatch_num = 0;
       for (int i = 0; i < num_microbatches_; ++i) {
         try {
           int op_idx = 0;
+          gettimeofday(&micro_start, NULL);
           for (auto& op : ops_) {
+            gettimeofday(&start, NULL);
             int op_role = op->Attr<int>(std::string("op_role"));
             // We run op with op_role = kLRSched only for the first microbatch
             // to avoid increasing the @LR_DECAY_STEP@ multiple times.
@@ -335,7 +377,9 @@ void SectionWorker::TrainFilesWithProfiler() {
               DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                   unused_vars_, gc.get());
             }
+            cudaDeviceSynchronize();
             timeline.Pause();
+            gettimeofday(&end, NULL);
             auto time = timeline.ElapsedUS();
             op_total_time[op_idx] += time;
             if (time > op_max_time[op_idx]) {
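Every timestamp in the timeline records below is derived the same way: a struct timeval filled by gettimeofday is flattened to microseconds as tv_sec * 1e6 + tv_usec. A small sketch of exactly that conversion:

    #include <sys/time.h>
    #include <cstdio>

    int main() {
      struct timeval start, end;
      gettimeofday(&start, NULL);
      // ... some work being timed ...
      gettimeofday(&end, NULL);
      double start_us = start.tv_sec * 1e6 + start.tv_usec;
      double end_us = end.tv_sec * 1e6 + end.tv_usec;
      std::printf("START[%.0f]:END[%.0f] (%.0f us)\n", start_us, end_us,
                  end_us - start_us);
      return 0;
    }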
@@ -346,9 +390,30 @@ void SectionWorker::TrainFilesWithProfiler() {
             }
             op_count[op_idx] += 1;
             op_total_time[op_idx] += time;
+            {
+              std::unique_lock<std::mutex> lk(cout_mutex);
+              std::cout << std::fixed;
+              std::cout.precision(0);
+              std::cout << "::FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
+                        << "]:SCOPE[" << i << "]:OP[" << op->Type()
+                        << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
+                        << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]"
+                        << std::endl;
+            }
           }
           op_idx++;
         }
+        gettimeofday(&micro_end, NULL);
+        {
+          std::unique_lock<std::mutex> lk(cout_mutex);
+          std::cout << std::fixed;
+          std::cout.precision(0);
+          std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
+                    << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
+                    << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec
+                    << "]" << std::endl;
+        }
       } catch (platform::EOFException&) {
         std::unique_lock<std::mutex> lk(thread_mutex);
         threads_completed = true;
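Each per-op record above has the fixed shape ::PHASE:B[batch]:SEC[section]:SCOPE[microbatch]:OP[type]:START[us]:END[us], while whole-pass records use the !!PHASE prefix and omit the SCOPE and OP fields. A sketch of a consumer for the per-op form; the parser and sample line are illustrative, not part of Paddle:

    #include <cstdio>

    int main() {
      const char* line =
          "::FWD:B[3]:SEC[1]:SCOPE[0]:OP[mul]:START[1599470000000000]:END[1599470000012345]";
      char phase[8] = {0};
      char op_type[64] = {0};
      long long batch, section, scope, start_us, end_us;
      // Field layout mirrors the std::cout chain in the diff above.
      if (std::sscanf(line,
                      "::%3[A-Z]:B[%lld]:SEC[%lld]:SCOPE[%lld]:OP[%63[^]]]"
                      ":START[%lld]:END[%lld]",
                      phase, &batch, &section, &scope, op_type, &start_us,
                      &end_us) == 7) {
        std::printf("%s op %s on section %lld took %lld us\n", phase, op_type,
                    section, end_us - start_us);
      }
      return 0;
    }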
@@ -363,19 +428,23 @@ void SectionWorker::TrainFilesWithProfiler() {
                   << ", mean_time: " << op_total_time[i] / op_count[i];
         }
         VLOG(0) << "================================";
-        return;
+        break;
       }
-      if (i == 0) {
+      {
         VLOG(3) << "called notify all";
         std::unique_lock<std::mutex> lk(thread_mutex);
+        real_microbatch_num += 1;
         batch_id_ += 1;
         thread_condition.notify_all();
       }
     }
+    dev_ctx_->Wait();
     // backward pass
-    for (int i = 0; i < num_microbatches_; ++i) {
+    for (int i = 0; i < real_microbatch_num; ++i) {
       int op_idx = 0;
+      gettimeofday(&micro_start, NULL);
       for (auto& op : ops_) {
+        gettimeofday(&start, NULL);
         int op_role = op->Attr<int>(std::string("op_role"));
         if (op_role == static_cast<int>(OpRole::kBackward) ||
             op_role == (static_cast<int>(OpRole::kBackward) |
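The backward filter accepts an op whose role is exactly kBackward, or whose role OR-combines kBackward with another flag (in Paddle the op that produces the loss gradient is tagged with both the backward and loss roles). OpRole is a bit-flag enum; the numeric values below are illustrative assumptions, not Paddle's actual constants:

    #include <iostream>

    enum class OpRole : int {
      kForward = 0x0,
      kBackward = 0x1,
      kOptimize = 0x2,
      kLoss = 0x100,  // assumption: a separate bit that can be OR-ed onto kBackward
    };

    bool IsBackwardOp(int op_role) {
      return op_role == static_cast<int>(OpRole::kBackward) ||
             op_role == (static_cast<int>(OpRole::kBackward) |
                         static_cast<int>(OpRole::kLoss));
    }

    int main() {
      std::cout << IsBackwardOp(0x001) << "\n";  // plain backward op -> 1
      std::cout << IsBackwardOp(0x101) << "\n";  // backward | loss   -> 1
      std::cout << IsBackwardOp(0x002) << "\n";  // optimize op       -> 0
      return 0;
    }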
@@ -388,6 +457,8 @@ void SectionWorker::TrainFilesWithProfiler() {
           DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                               unused_vars_, gc.get());
         }
+        cudaDeviceSynchronize();
+        gettimeofday(&end, NULL);
         timeline.Pause();
         auto time = timeline.ElapsedUS();
         op_total_time[op_idx] += time;
@@ -399,13 +470,42 @@ void SectionWorker::TrainFilesWithProfiler() {
           }
           op_count[op_idx] += 1;
           op_total_time[op_idx] += time;
+          {
+            std::unique_lock<std::mutex> lk(cout_mutex);
+            std::cout << std::fixed;
+            std::cout.precision(0);
+            std::cout << "::BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
+                      << "]:SCOPE[" << i << "]:OP[" << op->Type()
+                      << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
+                      << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]"
+                      << std::endl;
+          }
         }
         op_idx++;
       }
+      gettimeofday(&micro_end, NULL);
+      {
+        std::unique_lock<std::mutex> lk(cout_mutex);
+        std::cout << std::fixed;
+        std::cout.precision(0);
+        std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
+                  << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
+                  << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec
+                  << "]" << std::endl;
+      }
     }
+    dev_ctx_->Wait();
+    if (real_microbatch_num == 0) {
+      batch_timer.Pause();
+      VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
+      return;
+    }
     // update pass
     int op_idx = 0;
+    gettimeofday(&micro_start, NULL);
     for (auto& op : ops_) {
+      gettimeofday(&start, NULL);
       int op_role = op->Attr<int>(std::string("op_role"));
       if (op_role == static_cast<int>(OpRole::kOptimize)) {
         VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
@@ -416,6 +516,8 @@ void SectionWorker::TrainFilesWithProfiler() {
         DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                             op.get(), unused_vars_, gc.get());
       }
+      cudaDeviceSynchronize();
+      gettimeofday(&end, NULL);
       timeline.Pause();
       auto time = timeline.ElapsedUS();
       op_total_time[op_idx] += time;
@@ -427,15 +529,53 @@ void SectionWorker::TrainFilesWithProfiler() {
        }
        op_count[op_idx] += 1;
        op_total_time[op_idx] += time;
+       {
+         std::unique_lock<std::mutex> lk(cout_mutex);
+         std::cout << std::fixed;
+         std::cout.precision(0);
+         std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
+                   << "]:SCOPE[" << num_microbatches_ << "]:OP[" << op->Type()
+                   << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
+                   << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]"
+                   << std::endl;
+       }
      }
      op_idx++;
    }
+   gettimeofday(&micro_end, NULL);
+   {
+     std::unique_lock<std::mutex> lk(cout_mutex);
+     std::cout << std::fixed;
+     std::cout.precision(0);
+     std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
+               << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
+               << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec
+               << "]" << std::endl;
+   }
    dev_ctx_->Wait();
    batch_timer.Pause();
    VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
+   {
+     std::unique_lock<std::mutex> lk(thread_mutex);
+     if (threads_completed) {
+       return;
+     }
+   }
  }
} else {
+   struct timeval start;
+   struct timeval end;
+   struct timeval micro_start;
+   struct timeval micro_end;
+   cudaEvent_t cu_start, cu_stop;
+   cudaEventCreate(&cu_start);
+   cudaEventCreate(&cu_stop);
+   bool local_completed = false;
    while (true) {
-     // forward pass:
+     int real_microbatch_num = 0;
      for (int i = 0; i < num_microbatches_; ++i) {
        {
          PADDLE_ENFORCE_LE(
              local_batch_id_, batch_id_,
@@ -450,25 +590,27 @@ void SectionWorker::TrainFilesWithProfiler() {
          VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
                  << local_batch_id_ << " batch_id_ " << batch_id_;
          if (threads_completed) {
+           local_completed = true;
            VLOG(3) << "thread " << thread_id_ << " completed.";
            lk.unlock();
            VLOG(0) << "============timeline============";
            for (size_t i = 0; i < ops_.size(); ++i) {
              VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                      << ", min_time: " << op_min_time[i]
                      << ", mean_time: " << op_total_time[i] / op_count[i];
            }
            VLOG(0) << "================================";
-           threads_completed = false;
-           return;
+           break;
          }
          lk.unlock();
+         real_microbatch_num += 1;
          local_batch_id_ += 1;
        }
+     }
+     // forward pass:
+     for (int i = 0; i < num_microbatches_; ++i) {
        int op_idx = 0;
+       gettimeofday(&micro_start, NULL);
        for (auto& op : ops_) {
+         gettimeofday(&start, NULL);
          int op_role = op->Attr<int>(std::string("op_role"));
          // We run op with op_role = kLRSched only for the first microbatch
          // to avoid increasing the @LR_DECAY_STEP@ multiple times.
@@ -489,6 +631,8 @@ void SectionWorker::TrainFilesWithProfiler() {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                              unused_vars_, gc.get());
        }
+       cudaDeviceSynchronize();
+       gettimeofday(&end, NULL);
        timeline.Pause();
        auto time = timeline.ElapsedUS();
        op_total_time[op_idx] += time;
@@ -500,14 +644,38 @@ void SectionWorker::TrainFilesWithProfiler() {
          }
          op_count[op_idx] += 1;
          op_total_time[op_idx] += time;
+         {
+           std::unique_lock<std::mutex> lk(cout_mutex);
+           std::cout << std::fixed;
+           std::cout.precision(0);
+           std::cout << "::FWD:B[" << local_batch_id_ << "]:SEC[" << thread_id_
+                     << "]:SCOPE[" << i << "]:OP[" << op->Type()
+                     << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
+                     << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]"
+                     << std::endl;
+         }
        }
        op_idx++;
      }
+     gettimeofday(&micro_end, NULL);
+     {
+       std::unique_lock<std::mutex> lk(cout_mutex);
+       std::cout << std::fixed;
+       std::cout.precision(0);
+       std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
+                 << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
+                 << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec
+                 << "]" << std::endl;
+     }
    }
+   dev_ctx_->Wait();
    // backward pass
-   for (int i = 0; i < num_microbatches_; ++i) {
+   for (int i = 0; i < real_microbatch_num; ++i) {
      int op_idx = 0;
+     gettimeofday(&micro_start, NULL);
      for (auto& op : ops_) {
+       gettimeofday(&start, NULL);
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kBackward) ||
            op_role == (static_cast<int>(OpRole::kBackward) |
@@ -520,6 +688,8 @@ void SectionWorker::TrainFilesWithProfiler() {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                              unused_vars_, gc.get());
        }
+       cudaDeviceSynchronize();
+       gettimeofday(&end, NULL);
        timeline.Pause();
        auto time = timeline.ElapsedUS();
        op_total_time[op_idx] += time;
@@ -531,13 +701,40 @@ void SectionWorker::TrainFilesWithProfiler() {
          }
          op_count[op_idx] += 1;
          op_total_time[op_idx] += time;
+         {
+           std::unique_lock<std::mutex> lk(cout_mutex);
+           std::cout << std::fixed;
+           std::cout.precision(0);
+           std::cout << "::BWD:B[" << local_batch_id_ << "]:SEC[" << thread_id_
+                     << "]:SCOPE[" << i << "]:OP[" << op->Type()
+                     << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
+                     << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]"
+                     << std::endl;
+         }
        }
        op_idx++;
      }
+     gettimeofday(&micro_end, NULL);
+     {
+       std::unique_lock<std::mutex> lk(cout_mutex);
+       std::cout << std::fixed;
+       std::cout.precision(0);
+       std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
+                 << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
+                 << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec
+                 << "]" << std::endl;
+     }
    }
+   dev_ctx_->Wait();
+   if (real_microbatch_num == 0) {
+     return;
+   }
    // update pass
    int op_idx = 0;
+   gettimeofday(&micro_start, NULL);
    for (auto& op : ops_) {
+     gettimeofday(&start, NULL);
      int op_role = op->Attr<int>(std::string("op_role"));
      if (op_role == static_cast<int>(OpRole::kOptimize)) {
        VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
@@ -548,6 +745,8 @@ void SectionWorker::TrainFilesWithProfiler() {
        DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                            op.get(), unused_vars_, gc.get());
      }
+     cudaDeviceSynchronize();
+     gettimeofday(&end, NULL);
      timeline.Pause();
      auto time = timeline.ElapsedUS();
      op_total_time[op_idx] += time;
@@ -559,10 +758,34 @@ void SectionWorker::TrainFilesWithProfiler() {
        }
        op_count[op_idx] += 1;
        op_total_time[op_idx] += time;
+       {
+         std::unique_lock<std::mutex> lk(cout_mutex);
+         std::cout << std::fixed;
+         std::cout.precision(0);
+         std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
+                   << "]:SCOPE[" << num_microbatches_ << "]:OP[" << op->Type()
+                   << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
+                   << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]"
+                   << std::endl;
+       }
      }
      op_idx++;
    }
+   gettimeofday(&micro_end, NULL);
+   {
+     std::unique_lock<std::mutex> lk(cout_mutex);
+     std::cout << std::fixed;
+     std::cout.precision(0);
+     std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
+               << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
+               << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec
+               << "]" << std::endl;
+   }
    dev_ctx_->Wait();
+   if (local_completed) {
+     return;
+   }
  }
}
paddle/fluid/framework/trainer.h

@@ -223,7 +223,6 @@ class PipelineTrainer : public TrainerBase {
   int section_num_;
   int num_microbatches_;
   int start_cpu_core_id_;
-  std::vector<std::string> feed_var_names_;
   std::vector<platform::Place> places_;
   std::vector<std::vector<std::string>> skip_vars_;
   TrainerDesc trainer_desc_;
python/paddle/fluid/optimizer.py

The diff for this file (+95 -183) is collapsed in the original view and its contents are not included in this capture.