Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
3fd3b663
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3fd3b663
编写于
9月 18, 2019
作者:
Z
Zeng Jinle
提交者:
GitHub
9月 18, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix gc bug in controlflow ops, test=develop (#19827)
上级
982e61f5
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
48 addition
and
51 deletion
+48
-51
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+4
-3
paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc
...optimize_pass/conditional_block_op_eager_deletion_pass.cc
+1
-1
paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
.../memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
+2
-1
paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
...k/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
+1
-1
paddle/fluid/operators/controlflow/conditional_block_op_helper.cc
...luid/operators/controlflow/conditional_block_op_helper.cc
+13
-15
paddle/fluid/operators/controlflow/conditional_block_op_helper.h
...fluid/operators/controlflow/conditional_block_op_helper.h
+2
-1
paddle/fluid/operators/controlflow/recurrent_op_helper.cc
paddle/fluid/operators/controlflow/recurrent_op_helper.cc
+8
-13
paddle/fluid/operators/controlflow/recurrent_op_helper.h
paddle/fluid/operators/controlflow/recurrent_op_helper.h
+2
-2
paddle/fluid/operators/controlflow/while_op_helper.cc
paddle/fluid/operators/controlflow/while_op_helper.cc
+13
-13
paddle/fluid/operators/controlflow/while_op_helper.h
paddle/fluid/operators/controlflow/while_op_helper.h
+2
-1
未找到文件。
paddle/fluid/framework/executor.cc
浏览文件 @
3fd3b663
...
@@ -78,10 +78,11 @@ void ExecutorPrepareContext::PrepareUnusedVars(
...
@@ -78,10 +78,11 @@ void ExecutorPrepareContext::PrepareUnusedVars(
// If gc is enabled and block size > 1
// If gc is enabled and block size > 1
if
(
prog_
.
Size
()
>
1
)
{
if
(
prog_
.
Size
()
>
1
)
{
operators
::
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
operators
::
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
block_id_
,
ops_
);
prog_
,
block_id_
,
ops_
);
operators
::
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
block_id_
,
ops_
);
operators
::
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
prog_
,
block_id_
,
ops_
);
operators
::
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
operators
::
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
block_id_
,
ops_
);
prog_
,
block_id_
,
ops_
);
}
}
unused_vars_
=
GetUnusedVars
(
prog_
.
Block
(
block_id_
),
ops_
,
keep_vars
);
unused_vars_
=
GetUnusedVars
(
prog_
.
Block
(
block_id_
),
ops_
,
keep_vars
);
}
}
...
...
paddle/fluid/framework/ir/memory_optimize_pass/conditional_block_op_eager_deletion_pass.cc
浏览文件 @
3fd3b663
...
@@ -48,7 +48,7 @@ class ConditionalOpEagerDeletionPass : public Pass {
...
@@ -48,7 +48,7 @@ class ConditionalOpEagerDeletionPass : public Pass {
auto
&
ifelse_ops
=
ops_pair
.
second
.
first
;
auto
&
ifelse_ops
=
ops_pair
.
second
.
first
;
auto
&
ifelse_grad_ops
=
ops_pair
.
second
.
second
;
auto
&
ifelse_grad_ops
=
ops_pair
.
second
.
second
;
operators
::
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
operators
::
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
ifelse_ops
,
ifelse_grad_ops
);
graph
->
OriginProgram
(),
ifelse_ops
,
ifelse_grad_ops
);
}
}
}
}
};
};
...
...
paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
浏览文件 @
3fd3b663
...
@@ -40,7 +40,8 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
...
@@ -40,7 +40,8 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
// Prepare safe eager deletion on different devices because the garbage
// Prepare safe eager deletion on different devices because the garbage
// collection may be different across devices
// collection may be different across devices
OpAndGradOpPair
&
op_pair
=
entry
.
second
;
OpAndGradOpPair
&
op_pair
=
entry
.
second
;
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
&
op_pair
);
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
graph
->
OriginProgram
(),
&
op_pair
);
}
}
}
}
...
...
paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
浏览文件 @
3fd3b663
...
@@ -47,7 +47,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
...
@@ -47,7 +47,7 @@ class WhileOpEagerDeletionPass : public ir::Pass {
auto
&
while_ops
=
ops_pair
.
second
.
first
;
auto
&
while_ops
=
ops_pair
.
second
.
first
;
auto
&
while_grad_ops
=
ops_pair
.
second
.
second
;
auto
&
while_grad_ops
=
ops_pair
.
second
.
second
;
operators
::
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
operators
::
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
while_ops
,
while_grad_ops
);
graph
->
OriginProgram
(),
while_ops
,
while_grad_ops
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/controlflow/conditional_block_op_helper.cc
浏览文件 @
3fd3b663
...
@@ -29,16 +29,12 @@ static bool IsMatchedConditionalBlockOpAndConditionalBlockGradOp(
...
@@ -29,16 +29,12 @@ static bool IsMatchedConditionalBlockOpAndConditionalBlockGradOp(
}
}
static
void
FindAllConditionalBlockAndConditionalBlockGradOp
(
static
void
FindAllConditionalBlockAndConditionalBlockGradOp
(
std
::
vector
<
OpVariant
>
*
fwd_ops
,
std
::
vector
<
OpVariant
>
*
bwd_ops
)
{
const
framework
::
ProgramDesc
&
program
,
std
::
vector
<
OpVariant
>
*
fwd_ops
,
std
::
vector
<
OpVariant
>
*
bwd_ops
)
{
PADDLE_ENFORCE_GE
(
fwd_ops
->
size
(),
bwd_ops
->
size
());
PADDLE_ENFORCE_GE
(
fwd_ops
->
size
(),
bwd_ops
->
size
());
if
(
fwd_ops
->
empty
())
return
;
for
(
size_t
i
=
1
;
i
<
program
.
Size
();
++
i
)
{
auto
&
block
=
program
.
Block
(
i
);
const
auto
*
program
=
fwd_ops
->
front
().
Attr
<
framework
::
BlockDesc
*>
(
"sub_block"
)
->
Program
();
for
(
size_t
i
=
1
;
i
<
program
->
Size
();
++
i
)
{
auto
&
block
=
program
->
Block
(
i
);
for
(
size_t
j
=
0
;
j
<
block
.
OpSize
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
block
.
OpSize
();
++
j
)
{
auto
*
op
=
block
.
Op
(
j
);
auto
*
op
=
block
.
Op
(
j
);
if
(
op
->
Type
()
==
"conditional_block"
)
{
if
(
op
->
Type
()
==
"conditional_block"
)
{
...
@@ -86,9 +82,10 @@ static void SetSkipVarsForConditionalBlockOp(OpVariant *fwd_op,
...
@@ -86,9 +82,10 @@ static void SetSkipVarsForConditionalBlockOp(OpVariant *fwd_op,
}
}
static
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl
(
static
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl
(
std
::
vector
<
OpVariant
>
*
ifelse_ops
,
const
framework
::
ProgramDesc
&
program
,
std
::
vector
<
OpVariant
>
*
ifelse_ops
,
std
::
vector
<
OpVariant
>
*
ifelse_grad_ops
)
{
std
::
vector
<
OpVariant
>
*
ifelse_grad_ops
)
{
FindAllConditionalBlockAndConditionalBlockGradOp
(
ifelse_ops
,
ifelse_grad_ops
);
FindAllConditionalBlockAndConditionalBlockGradOp
(
program
,
ifelse_ops
,
ifelse_grad_ops
);
VLOG
(
2
)
<<
"Found conditional_block op num: "
<<
ifelse_ops
->
size
()
VLOG
(
2
)
<<
"Found conditional_block op num: "
<<
ifelse_ops
->
size
()
<<
", conditional_block_grad op num: "
<<
ifelse_grad_ops
->
size
();
<<
", conditional_block_grad op num: "
<<
ifelse_grad_ops
->
size
();
...
@@ -121,7 +118,7 @@ static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
...
@@ -121,7 +118,7 @@ static void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl(
}
}
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
int
block_id
,
const
framework
::
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
)
{
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
)
{
// If block_id is not 0, returns
// If block_id is not 0, returns
// This is because all conditional_block_ops and conditional_block_grad_ops
// This is because all conditional_block_ops and conditional_block_grad_ops
...
@@ -143,11 +140,12 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
...
@@ -143,11 +140,12 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
}
}
}
}
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl
(
&
fwd_ops
,
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl
(
&
bwd_ops
);
program
,
&
fwd_ops
,
&
bwd_ops
);
}
}
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
const
framework
::
ProgramDesc
&
program
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
ifelse_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
ifelse_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
ifelse_grad_ops
)
{
const
std
::
vector
<
framework
::
OperatorBase
*>
&
ifelse_grad_ops
)
{
std
::
vector
<
OpVariant
>
fwd_ops
,
bwd_ops
;
std
::
vector
<
OpVariant
>
fwd_ops
,
bwd_ops
;
...
@@ -161,8 +159,8 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
...
@@ -161,8 +159,8 @@ void PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
bwd_ops
.
emplace_back
(
op
);
bwd_ops
.
emplace_back
(
op
);
}
}
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl
(
&
fwd_ops
,
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOpImpl
(
&
bwd_ops
);
program
,
&
fwd_ops
,
&
bwd_ops
);
}
}
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/controlflow/conditional_block_op_helper.h
浏览文件 @
3fd3b663
...
@@ -23,10 +23,11 @@ namespace paddle {
...
@@ -23,10 +23,11 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
int
block_id
,
const
framework
::
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
);
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
);
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
void
PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp
(
const
framework
::
ProgramDesc
&
program
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
ifelse_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
ifelse_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
ifelse_grad_ops
);
const
std
::
vector
<
framework
::
OperatorBase
*>
&
ifelse_grad_ops
);
...
...
paddle/fluid/operators/controlflow/recurrent_op_helper.cc
浏览文件 @
3fd3b663
...
@@ -65,7 +65,8 @@ static void AddSkipVars(const OpVariant &op, const Container &skip_vars) {
...
@@ -65,7 +65,8 @@ static void AddSkipVars(const OpVariant &op, const Container &skip_vars) {
// Find all ops and grad ops with given type name. The ops and grad ops
// Find all ops and grad ops with given type name. The ops and grad ops
// may locate in different blocks so we should traverse all blocks in the
// may locate in different blocks so we should traverse all blocks in the
// program and find them out
// program and find them out
static
void
FindAllOpAndGradOp
(
OpAndGradOpPair
*
op_and_grad_op
,
static
void
FindAllOpAndGradOp
(
const
framework
::
ProgramDesc
&
program
,
OpAndGradOpPair
*
op_and_grad_op
,
const
std
::
string
&
type_name
,
const
std
::
string
&
type_name
,
const
std
::
string
&
backward_type_name
)
{
const
std
::
string
&
backward_type_name
)
{
OpVariantSet
&
ops
=
op_and_grad_op
->
first
;
OpVariantSet
&
ops
=
op_and_grad_op
->
first
;
...
@@ -74,14 +75,8 @@ static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op,
...
@@ -74,14 +75,8 @@ static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op,
PADDLE_ENFORCE_GE
(
ops
.
size
(),
grad_ops
.
size
(),
PADDLE_ENFORCE_GE
(
ops
.
size
(),
grad_ops
.
size
(),
"There are extra grad ops in the graph or program"
);
"There are extra grad ops in the graph or program"
);
if
(
ops
.
empty
())
return
;
for
(
size_t
i
=
1
;
i
<
program
.
Size
();
++
i
)
{
auto
&
block
=
program
.
Block
(
i
);
const
auto
*
program
=
ops
.
begin
()
->
Attr
<
framework
::
BlockDesc
*>
(
RecurrentBase
::
kStepBlock
)
->
Program
();
for
(
size_t
i
=
1
;
i
<
program
->
Size
();
++
i
)
{
auto
&
block
=
program
->
Block
(
i
);
for
(
size_t
j
=
0
;
j
<
block
.
OpSize
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
block
.
OpSize
();
++
j
)
{
auto
*
op
=
block
.
Op
(
j
);
auto
*
op
=
block
.
Op
(
j
);
if
(
op
->
Type
()
==
type_name
)
{
if
(
op
->
Type
()
==
type_name
)
{
...
@@ -201,7 +196,7 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
...
@@ -201,7 +196,7 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
}
}
void
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
void
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
int
block_id
,
const
framework
::
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>
const
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>
&
all_ops
)
{
&
all_ops
)
{
// If block_id is not 0, returns
// If block_id is not 0, returns
...
@@ -224,13 +219,13 @@ void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
...
@@ -224,13 +219,13 @@ void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
op_pair
.
second
.
emplace
(
op
.
get
());
op_pair
.
second
.
emplace
(
op
.
get
());
}
}
}
}
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
&
op_pair
);
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
program
,
&
op_pair
);
}
}
void
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
void
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
OpAndGradOpPair
*
op_pair
)
{
const
framework
::
ProgramDesc
&
program
,
OpAndGradOpPair
*
op_pair
)
{
// Find all ops and grad ops at all blocks
// Find all ops and grad ops at all blocks
FindAllOpAndGradOp
(
op_pair
,
"recurrent"
,
"recurrent_grad"
);
FindAllOpAndGradOp
(
program
,
op_pair
,
"recurrent"
,
"recurrent_grad"
);
OpVariantSet
&
recurrent_ops
=
op_pair
->
first
;
OpVariantSet
&
recurrent_ops
=
op_pair
->
first
;
OpVariantSet
&
recurrent_grad_ops
=
op_pair
->
second
;
OpVariantSet
&
recurrent_grad_ops
=
op_pair
->
second
;
...
...
paddle/fluid/operators/controlflow/recurrent_op_helper.h
浏览文件 @
3fd3b663
...
@@ -37,14 +37,14 @@ using OpAndGradOpPair = std::pair<OpVariantSet, OpVariantSet>;
...
@@ -37,14 +37,14 @@ using OpAndGradOpPair = std::pair<OpVariantSet, OpVariantSet>;
// recurrent_grad ops at block 0 and the function will find all recurrent and
// recurrent_grad ops at block 0 and the function will find all recurrent and
// recurrent_grad ops across blocks.
// recurrent_grad ops across blocks.
void
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
void
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
OpAndGradOpPair
*
op_pair
);
const
framework
::
ProgramDesc
&
program
,
OpAndGradOpPair
*
op_pair
);
// Set vars to skip eager deletion on input recurrent and recurrent_grad for
// Set vars to skip eager deletion on input recurrent and recurrent_grad for
// preparing safe eager deletion. The input block_id must be 0 and caller can
// preparing safe eager deletion. The input block_id must be 0 and caller can
// input all ops in the block. The function will find all recurrent and
// input all ops in the block. The function will find all recurrent and
// recurrent_grad ops across blocks.
// recurrent_grad ops across blocks.
void
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
void
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp
(
int
block_id
,
const
framework
::
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>
const
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>
&
all_ops
);
&
all_ops
);
...
...
paddle/fluid/operators/controlflow/while_op_helper.cc
浏览文件 @
3fd3b663
...
@@ -100,16 +100,12 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
...
@@ -100,16 +100,12 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
// Find all while_ops and while_grad_ops in the graph or program
// Find all while_ops and while_grad_ops in the graph or program
// The while_grad_op and while_op may located in different blocks
// The while_grad_op and while_op may located in different blocks
// So we should traverse all blocks in the program and find them out.
// So we should traverse all blocks in the program and find them out.
static
void
FindAllWhileAndWhileGradOp
(
std
::
vector
<
OpVariant
>
*
while_ops
,
static
void
FindAllWhileAndWhileGradOp
(
const
framework
::
ProgramDesc
&
program
,
std
::
vector
<
OpVariant
>
*
while_ops
,
std
::
vector
<
OpVariant
>
*
while_grad_ops
)
{
std
::
vector
<
OpVariant
>
*
while_grad_ops
)
{
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
());
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
());
for
(
size_t
i
=
1
;
i
<
program
.
Size
();
++
i
)
{
if
(
while_ops
->
empty
())
return
;
auto
&
block
=
program
.
Block
(
i
);
const
auto
*
program
=
while_ops
->
front
().
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
)
->
Program
();
for
(
size_t
i
=
1
;
i
<
program
->
Size
();
++
i
)
{
auto
&
block
=
program
->
Block
(
i
);
for
(
size_t
j
=
0
;
j
<
block
.
OpSize
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
block
.
OpSize
();
++
j
)
{
auto
*
op
=
block
.
Op
(
j
);
auto
*
op
=
block
.
Op
(
j
);
if
(
op
->
Type
()
==
"while"
)
{
if
(
op
->
Type
()
==
"while"
)
{
...
@@ -125,8 +121,9 @@ static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
...
@@ -125,8 +121,9 @@ static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
}
}
static
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
static
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
std
::
vector
<
OpVariant
>
*
while_ops
,
std
::
vector
<
OpVariant
>
*
while_grad_ops
)
{
const
framework
::
ProgramDesc
&
program
,
std
::
vector
<
OpVariant
>
*
while_ops
,
FindAllWhileAndWhileGradOp
(
while_ops
,
while_grad_ops
);
std
::
vector
<
OpVariant
>
*
while_grad_ops
)
{
FindAllWhileAndWhileGradOp
(
program
,
while_ops
,
while_grad_ops
);
VLOG
(
2
)
<<
"Found while op num: "
<<
while_ops
->
size
()
VLOG
(
2
)
<<
"Found while op num: "
<<
while_ops
->
size
()
<<
", while grad op num: "
<<
while_grad_ops
->
size
();
<<
", while grad op num: "
<<
while_grad_ops
->
size
();
...
@@ -155,7 +152,7 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
...
@@ -155,7 +152,7 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
}
}
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
int
block_id
,
const
framework
::
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
)
{
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
)
{
// If block_id is not 0, returns
// If block_id is not 0, returns
// This is because all while_ops and while_grad_ops in the whole program
// This is because all while_ops and while_grad_ops in the whole program
...
@@ -176,10 +173,12 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
...
@@ -176,10 +173,12 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
bwd_ops
.
emplace_back
(
op
.
get
());
bwd_ops
.
emplace_back
(
op
.
get
());
}
}
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
&
fwd_ops
,
&
bwd_ops
);
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
program
,
&
fwd_ops
,
&
bwd_ops
);
}
}
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
const
framework
::
ProgramDesc
&
program
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_grad_ops
)
{
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_grad_ops
)
{
std
::
vector
<
OpVariant
>
fwd_ops
,
bwd_ops
;
std
::
vector
<
OpVariant
>
fwd_ops
,
bwd_ops
;
...
@@ -193,7 +192,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
...
@@ -193,7 +192,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
bwd_ops
.
emplace_back
(
op
);
bwd_ops
.
emplace_back
(
op
);
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
&
fwd_ops
,
&
bwd_ops
);
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
program
,
&
fwd_ops
,
&
bwd_ops
);
}
}
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/controlflow/while_op_helper.h
浏览文件 @
3fd3b663
...
@@ -32,10 +32,11 @@ static constexpr char kOutputs[] = "Out";
...
@@ -32,10 +32,11 @@ static constexpr char kOutputs[] = "Out";
static
constexpr
char
kSkipEagerDeletionVars
[]
=
"skip_eager_deletion_vars"
;
static
constexpr
char
kSkipEagerDeletionVars
[]
=
"skip_eager_deletion_vars"
;
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
int
block_id
,
const
framework
::
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
);
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
);
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
const
framework
::
ProgramDesc
&
program
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_grad_ops
);
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_grad_ops
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录