Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
04fdb10a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
04fdb10a
编写于
9月 14, 2021
作者:
Y
Yuang Liu
提交者:
GitHub
9月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[hybrid performance] Optimize Pipeline Scheduler (#35680)
上级
e46ffaf2
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
53 addition
and
40 deletion
+53
-40
paddle/fluid/framework/device_worker.h
paddle/fluid/framework/device_worker.h
+4
-0
paddle/fluid/framework/section_worker.cc
paddle/fluid/framework/section_worker.cc
+49
-40
未找到文件。
paddle/fluid/framework/device_worker.h
浏览文件 @
04fdb10a
...
...
@@ -601,6 +601,10 @@ class SectionWorker : public DeviceWorker {
std
::
vector
<
std
::
string
>
backward_send_vars_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_
;
std
::
vector
<
OperatorBase
*>
forward_and_lr_ops_
;
std
::
vector
<
OperatorBase
*>
forward_ops_
;
std
::
vector
<
OperatorBase
*>
backward_ops_
;
std
::
vector
<
OperatorBase
*>
optimizer_ops_
;
std
::
shared_ptr
<
framework
::
ProgramDesc
>
program_
;
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>
unused_vars_
;
...
...
paddle/fluid/framework/section_worker.cc
浏览文件 @
04fdb10a
...
...
@@ -31,6 +31,33 @@ void SectionWorker::Initialize(const TrainerDesc &desc) {
ops_
.
push_back
(
OpRegistry
::
CreateOp
(
*
op_desc
));
}
for
(
auto
&
op
:
ops_
)
{
// cache the op type during the init part
// reduce unnecessary op visit during running
int
op_role
=
op
->
Attr
<
int
>
(
"op_role"
);
if
((
op_role
==
static_cast
<
int
>
(
OpRole
::
kForward
))
||
(
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kForward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)))
||
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kLRSched
)))
{
// forward ops and lr schedule ops, used for first micro step
forward_and_lr_ops_
.
push_back
(
op
.
get
());
if
((
op_role
!=
static_cast
<
int
>
(
OpRole
::
kLRSched
)))
{
// only forward ops, used for second and later micro steps
forward_ops_
.
push_back
(
op
.
get
());
}
}
else
if
((
op_role
==
static_cast
<
int
>
(
OpRole
::
kBackward
))
||
(
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
))))
{
backward_ops_
.
push_back
(
op
.
get
());
}
else
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kOptimize
))
{
optimizer_ops_
.
push_back
(
op
.
get
());
}
else
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"The op %s is None of LRSched, Forward, Backward or Optimize."
,
op
->
Type
()));
}
}
// if not 1F1B scheduler
if
(
schedule_mode_
!=
1
)
return
;
...
...
@@ -66,25 +93,15 @@ void SectionWorker::RunForward(
int
micro_id
,
std
::
unique_ptr
<
GarbageCollector
>
&
gc
,
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>
&
unused_vars_
)
{
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool
run_first_mbatch
=
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kForward
))
||
(
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kForward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)))
||
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kLRSched
));
bool
run_others
=
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kForward
))
||
(
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kForward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)));
if
((
micro_id
==
0
&&
run_first_mbatch
)
||
(
micro_id
!=
0
&&
run_others
))
{
VLOG
(
3
)
<<
"Forward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
micro_id
;
op
->
Run
(
*
microbatch_scopes_
[
micro_id
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
micro_id
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
}
std
::
vector
<
OperatorBase
*>
&
forward_tmp
=
micro_id
==
0
?
forward_and_lr_ops_
:
forward_ops_
;
for
(
auto
&
op
:
forward_tmp
)
{
VLOG
(
3
)
<<
"Forward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
micro_id
;
op
->
Run
(
*
microbatch_scopes_
[
micro_id
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
micro_id
],
op
,
unused_vars_
,
gc
.
get
());
}
}
}
...
...
@@ -93,18 +110,13 @@ void SectionWorker::RunBackward(
int
micro_id
,
std
::
unique_ptr
<
GarbageCollector
>
&
gc
,
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>
&
unused_vars_
)
{
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
if
((
op_role
==
static_cast
<
int
>
(
OpRole
::
kBackward
))
||
(
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
))))
{
VLOG
(
3
)
<<
"Backward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
micro_id
;
op
->
Run
(
*
microbatch_scopes_
[
micro_id
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
micro_id
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
}
for
(
auto
&
op
:
backward_ops_
)
{
VLOG
(
3
)
<<
"Backward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
micro_id
;
op
->
Run
(
*
microbatch_scopes_
[
micro_id
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
micro_id
],
op
,
unused_vars_
,
gc
.
get
());
}
}
}
...
...
@@ -113,15 +125,12 @@ void SectionWorker::RunUpdate(
std
::
unique_ptr
<
GarbageCollector
>
&
gc
,
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>
&
unused_vars_
)
{
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kOptimize
))
{
VLOG
(
3
)
<<
"Update: running op "
<<
op
->
Type
();
op
->
Run
(
*
microbatch_scopes_
[
num_microbatches_
-
1
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
num_microbatches_
-
1
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
}
for
(
auto
&
op
:
optimizer_ops_
)
{
VLOG
(
3
)
<<
"Update: running op "
<<
op
->
Type
();
op
->
Run
(
*
microbatch_scopes_
[
num_microbatches_
-
1
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
num_microbatches_
-
1
],
op
,
unused_vars_
,
gc
.
get
());
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录