Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
712bfb17
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
712bfb17
编写于
5月 16, 2019
作者:
Z
Zeng Jinle
提交者:
GitHub
5月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix recurrent_op,test=develop (#17433)
上级
5babcd02
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
6 addition
and
12 deletion
+6
-12
paddle/fluid/framework/ir/memory_optimize_pass/record_skip_memory_opt_vars_pass.cc
.../memory_optimize_pass/record_skip_memory_opt_vars_pass.cc
+3
-3
paddle/fluid/operators/recurrent_op.cc
paddle/fluid/operators/recurrent_op.cc
+1
-0
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+2
-9
未找到文件。
paddle/fluid/framework/ir/memory_optimize_pass/record_skip_memory_opt_vars_pass.cc
浏览文件 @
712bfb17
...
@@ -140,9 +140,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
...
@@ -140,9 +140,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
// fail since "states" and "ex_states" cannot be found in main block.
// fail since "states" and "ex_states" cannot be found in main block.
// When memory optimization is enabled, "states", "ex_states" and their
// When memory optimization is enabled, "states", "ex_states" and their
// gradient should be skipped.
// gradient should be skipped.
auto
&
ex_states
=
auto
ex_states
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op_desc
->
GetAttr
(
"ex_states"
));
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op_desc
->
GetAttr
(
"ex_states"
));
auto
&
states
=
auto
states
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op_desc
->
GetAttr
(
"states"
));
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op_desc
->
GetAttr
(
"states"
));
if
(
op_type
==
"recurrent"
)
{
if
(
op_type
==
"recurrent"
)
{
UpdateSkipVarSet
(
skip_vars
,
{
ex_states
,
states
});
UpdateSkipVarSet
(
skip_vars
,
{
ex_states
,
states
});
...
@@ -154,7 +154,7 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
...
@@ -154,7 +154,7 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
UpdateSkipVarSet
(
UpdateSkipVarSet
(
skip_vars
,
skip_vars
,
{
ToGradVarName
(
op_desc
->
Input
(
"parameters"
)),
{
ToGradVarName
(
op_desc
->
Input
(
"parameters"
)),
ToGradVarName
(
op_desc
->
Input
(
"input"
)),
ex_states
,
states
,
ToGradVarName
(
op_desc
->
Input
(
"input
s
"
)),
ex_states
,
states
,
ToGradVarName
(
ex_states
),
ToGradVarName
(
states
)});
ToGradVarName
(
ex_states
),
ToGradVarName
(
states
)});
}
}
}
}
...
...
paddle/fluid/operators/recurrent_op.cc
浏览文件 @
712bfb17
...
@@ -508,6 +508,7 @@ class RecurrentGradOp : public RecurrentBase {
...
@@ -508,6 +508,7 @@ class RecurrentGradOp : public RecurrentBase {
for
(
auto
*
sub_scope
:
*
step_scopes
)
{
for
(
auto
*
sub_scope
:
*
step_scopes
)
{
const_cast
<
framework
::
Scope
&>
(
scope
).
DeleteScope
(
sub_scope
);
const_cast
<
framework
::
Scope
&>
(
scope
).
DeleteScope
(
sub_scope
);
}
}
step_scopes
->
clear
();
}
}
private:
private:
...
...
python/paddle/fluid/backward.py
浏览文件 @
712bfb17
...
@@ -232,15 +232,8 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
...
@@ -232,15 +232,8 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
for
arg
in
op_desc
.
input_arg_names
():
for
arg
in
op_desc
.
input_arg_names
():
if
core
.
grad_var_suffix
()
in
arg
and
arg
in
no_grad_set
:
if
core
.
grad_var_suffix
()
in
arg
and
arg
in
no_grad_set
:
x_in
=
_strip_grad_suffix_
(
arg
)
x_in
=
_strip_grad_suffix_
(
arg
)
x_in_var_desc
=
op_desc
.
block
().
find_var_recursive
(
to_insert
.
append
((
_create_op_desc_
(
cpt
.
to_bytes
(
x_in
))
"fill_zeros_like"
,
{
"X"
:
[
x_in
]},
{
"Out"
:
[
arg
]},
{}),
idx
))
assert
x_in_var_desc
is
not
None
,
"Variable {} not found"
.
format
(
x_in
)
dtype
=
x_in_var_desc
.
dtype
()
to_insert
.
append
(
(
_create_op_desc_
(
"fill_zeros_like2"
,
{
"X"
:
[
x_in
]},
{
"Out"
:
[
arg
]},
{
"dtype"
:
dtype
}),
idx
))
list
([
op_descs
.
insert
(
p
[
1
],
p
[
0
])
for
p
in
reversed
(
to_insert
)])
list
([
op_descs
.
insert
(
p
[
1
],
p
[
0
])
for
p
in
reversed
(
to_insert
)])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录