Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d2c81529
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d2c81529
编写于
3月 02, 2021
作者:
S
sandyhouse
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update fb_scheduler
上级
6760cbd9
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
121 addition
and
14 deletion
+121
-14
paddle/fluid/framework/section_worker.cc
paddle/fluid/framework/section_worker.cc
+121
-14
未找到文件。
paddle/fluid/framework/section_worker.cc
浏览文件 @
d2c81529
...
...
@@ -48,7 +48,18 @@ void SectionWorker::TrainFiles() {
#endif
}
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
auto
startup_steps
=
num_pipeline_stages_
-
pipeline_stage_
-
1
;
VLOG
(
3
)
<<
"startup_steps:"
<<
startup_steps
<<
", num_stages: "
<<
num_pipeline_stages_
<<
", stage:"
<<
pipeline_stage_
;
if
(
startup_steps
>
num_microbatches_
)
{
startup_steps
=
num_microbatches_
;
}
int
fw_step
=
0
;
int
bw_step
=
0
;
// startup phase
while
(
fw_step
<
startup_steps
)
{
VLOG
(
3
)
<<
"to run forward batch:"
<<
fw_step
;
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
// We run op with op_role = kLRSched only for the first microbatch
...
...
@@ -60,37 +71,129 @@ void SectionWorker::TrainFiles() {
bool
run_others
=
op_role
==
static_cast
<
int
>
(
OpRole
::
kForward
)
||
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kForward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
));
if
((
i
==
0
&&
run_first_mbatch
)
||
(
i
!=
0
&&
run_others
))
{
if
((
fw_step
==
0
&&
run_first_mbatch
)
||
(
fw_step
!=
0
&&
run_others
))
{
VLOG
(
3
)
<<
"Forward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
i
;
op
->
Run
(
*
microbatch_scopes_
[
i
],
place_
);
<<
fw_step
;
op
->
Run
(
*
microbatch_scopes_
[
fw_step
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
i
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
DeleteUnusedTensors
(
*
microbatch_scopes_
[
fw_step
],
op
.
get
()
,
unused_vars_
,
gc
.
get
());
}
}
}
cudaDeviceSynchronize
()
;
fw_step
+=
1
;
}
// backward pass
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
// 1f1b phase
while
(
fw_step
<
num_microbatches_
)
{
VLOG
(
3
)
<<
"to run forward batch:"
<<
fw_step
;
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool
run_first_mbatch
=
op_role
==
static_cast
<
int
>
(
OpRole
::
kForward
)
||
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kForward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
))
||
op_role
==
static_cast
<
int
>
(
OpRole
::
kLRSched
);
bool
run_others
=
op_role
==
static_cast
<
int
>
(
OpRole
::
kForward
)
||
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kForward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
));
if
((
fw_step
==
0
&&
run_first_mbatch
)
||
(
fw_step
!=
0
&&
run_others
))
{
VLOG
(
3
)
<<
"Forward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
fw_step
;
op
->
Run
(
*
microbatch_scopes_
[
fw_step
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
fw_step
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
}
}
}
fw_step
+=
1
;
VLOG
(
3
)
<<
"to run backward batch:"
<<
bw_step
;
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kBackward
)
||
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)))
{
VLOG
(
3
)
<<
"Backward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
i
;
op
->
Run
(
*
microbatch_scopes_
[
i
],
place_
);
<<
bw_step
;
op
->
Run
(
*
microbatch_scopes_
[
bw_step
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
i
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
DeleteUnusedTensors
(
*
microbatch_scopes_
[
bw_step
],
op
.
get
()
,
unused_vars_
,
gc
.
get
());
}
}
}
cudaDeviceSynchronize
()
;
bw_step
+=
1
;
}
// backward phase
while
(
bw_step
<
num_microbatches_
)
{
VLOG
(
3
)
<<
"to run backward batch:"
<<
bw_step
;
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kBackward
)
||
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)))
{
VLOG
(
3
)
<<
"Backward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
bw_step
;
op
->
Run
(
*
microbatch_scopes_
[
bw_step
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
bw_step
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
}
}
}
bw_step
+=
1
;
}
// for (int i = 0; i < num_microbatches_; ++i) {
// for (auto& op : ops_) {
// int op_role = op->Attr<int>(std::string("op_role"));
// // We run op with op_role = kLRSched only for the first microbatch
// // to avoid increasing the @LR_DECAY_STEP@ multiple times.
// bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward)
// ||
// op_role == (static_cast<int>(OpRole::kForward)
// |
// static_cast<int>(OpRole::kLoss)) ||
// op_role == static_cast<int>(OpRole::kLRSched);
// bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
// op_role == (static_cast<int>(OpRole::kForward) |
// static_cast<int>(OpRole::kLoss));
// if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
// VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch
// "
// << i;
// op->Run(*microbatch_scopes_[i], place_);
// if (gc) {
// DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
// gc.get());
// }
// }
// }
// cudaDeviceSynchronize();
// }
// // backward pass
// for (int i = 0; i < num_microbatches_; ++i) {
// for (auto& op : ops_) {
// int op_role = op->Attr<int>(std::string("op_role"));
// if (op_role == static_cast<int>(OpRole::kBackward) ||
// op_role == (static_cast<int>(OpRole::kBackward) |
// static_cast<int>(OpRole::kLoss))) {
// VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch
// "
// << i;
// op->Run(*microbatch_scopes_[i], place_);
// if (gc) {
// DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
// gc.get());
// }
// }
// }
// cudaDeviceSynchronize();
// }
// update pass
for
(
auto
&
op
:
ops_
)
{
...
...
@@ -99,6 +202,10 @@ void SectionWorker::TrainFiles() {
VLOG
(
3
)
<<
"Update: running op "
<<
op
->
Type
();
op
->
Run
(
*
microbatch_scopes_
[
num_microbatches_
-
1
],
place_
);
if
(
gc
)
{
// for (int i = 0; i < num_microbatches_; ++i) {
// DeleteUnusedTensors(*microbatch_scopes_[i],
// op.get(), unused_vars_, gc.get());
//}
DeleteUnusedTensors
(
*
microbatch_scopes_
[
num_microbatches_
-
1
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录