Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
56dd7653
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
56dd7653
编写于
8月 27, 2019
作者:
H
Huihuang Zheng
提交者:
GitHub
8月 27, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Delete useless ex-scope in recurrent op (#19426)
上级
b8aa37d5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
56 addition
and
54 deletion
+56
-54
paddle/fluid/operators/recurrent_op.cc
paddle/fluid/operators/recurrent_op.cc
+40
-43
paddle/fluid/operators/recurrent_op.h
paddle/fluid/operators/recurrent_op.h
+16
-11
未找到文件。
paddle/fluid/operators/recurrent_op.cc
浏览文件 @
56dd7653
...
...
@@ -54,20 +54,6 @@ static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
step_scopes
->
clear
();
}
// StepScopes manages scopes inside RNN.
// StepScopes::CurScope() get the current scope
// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
// StepScopes::Next() move to next time step.
//
// if is_train = False, then
// there are two scopes for the RNN and just support forward.
// else
// the len(scopes) == seq_len
//
// if is_backward = True, then
// reversely access scopes
// else
// access scopes from begin to end.
StepScopes
::
StepScopes
(
const
platform
::
DeviceContext
&
dev_ctx
,
const
framework
::
Scope
&
parent
,
StepScopeVar
*
scopes
,
bool
is_train
,
size_t
seq_len
,
bool
is_backward
)
...
...
@@ -76,8 +62,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
is_train_
(
is_train
),
is_backward_
(
is_backward
)
{
size_t
num_step_scopes
=
is_train
?
seq_len
:
2
;
PADDLE_ENFORCE
(
is_train
||
!
is_backward
,
"Cannot backward when is not training"
);
PADDLE_ENFORCE
_EQ
(
is_train
||
!
is_backward
,
true
,
"Cannot backward when is not training"
);
if
(
!
is_backward_
)
{
ClearStepScopes
(
dev_ctx
,
const_cast
<
framework
::
Scope
*>
(
&
parent
),
scopes
);
scopes
->
reserve
(
static_cast
<
size_t
>
(
num_step_scopes
));
...
...
@@ -94,12 +80,22 @@ framework::Scope &StepScopes::ExScope() {
return
scope
;
}
void
StepScopes
::
Next
()
{
if
(
is_backward_
)
{
--
counter_
;
}
else
{
++
counter_
;
void
StepScopes
::
BackwardNext
(
const
platform
::
DeviceContext
&
dev_ctx
,
framework
::
Scope
*
parent_scope
)
{
PADDLE_ENFORCE_EQ
(
is_backward_
,
true
,
"Cannot get backward next scope when is forward"
);
if
(
counter_
+
2
==
scopes_
->
size
())
{
parent_scope
->
DeleteScope
((
*
scopes_
)[
counter_
+
1
]);
scopes_
->
pop_back
();
VLOG
(
3
)
<<
"Deleted scope at "
<<
counter_
+
1
;
}
--
counter_
;
}
void
StepScopes
::
ForwardNext
()
{
PADDLE_ENFORCE_EQ
(
is_backward_
,
false
,
"Cannot get forward next scope when is backward"
);
++
counter_
;
}
framework
::
Scope
&
StepScopes
::
GetScope
(
size_t
scope_id
)
const
{
...
...
@@ -125,11 +121,11 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
// Dim format SEQ_LEN, BATCH_SIZE, ...
int64_t
seq_len
=
-
1
;
auto
&
all_inputs
=
Inputs
(
kInputs
);
PADDLE_ENFORCE
(
!
all_inputs
.
empty
()
);
PADDLE_ENFORCE
_EQ
(
!
all_inputs
.
empty
(),
true
);
for
(
auto
&
iname
:
all_inputs
)
{
auto
*
var
=
scope
.
FindVar
(
iname
);
PADDLE_ENFORCE
(
var
!=
nullpt
r
);
PADDLE_ENFORCE
(
var
->
IsType
<
framework
::
LoDTensor
>
()
);
PADDLE_ENFORCE
_NOT_NULL
(
va
r
);
PADDLE_ENFORCE
_EQ
(
var
->
IsType
<
framework
::
LoDTensor
>
(),
true
);
auto
&
dim
=
var
->
Get
<
framework
::
LoDTensor
>
().
dims
();
if
(
seq_len
==
-
1
)
{
seq_len
=
dim
[
0
];
...
...
@@ -254,7 +250,7 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
});
}
scopes
.
Next
();
scopes
.
Forward
Next
();
}
}
...
...
@@ -262,7 +258,7 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
const
framework
::
Scope
&
scope
,
size_t
seq_len
)
const
{
auto
*
var
=
scope
.
FindVar
(
Output
(
kStepScopes
));
PADDLE_ENFORCE
(
var
!=
nullpt
r
);
PADDLE_ENFORCE
_NOT_NULL
(
va
r
);
return
StepScopes
(
dev_ctx
,
scope
,
var
->
GetMutable
<
StepScopeVar
>
(),
Attr
<
bool
>
(
kIsTrain
),
seq_len
);
}
...
...
@@ -459,11 +455,11 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
VLOG
(
5
)
<<
"Link initialize state gradient finished "
;
}
}
scopes
.
Next
(
);
scopes
.
BackwardNext
(
dev_ctx
,
const_cast
<
framework
::
Scope
*>
(
&
scope
)
);
}
// Delete the scope of StepScopes
auto
*
var
=
scope
.
FindVar
(
Input
(
kStepScopes
));
PADDLE_ENFORCE
(
var
!=
nullpt
r
);
PADDLE_ENFORCE
_NOT_NULL
(
va
r
);
auto
*
step_scopes
=
var
->
GetMutable
<
StepScopeVar
>
();
ClearStepScopes
(
dev_ctx
,
const_cast
<
framework
::
Scope
*>
(
&
scope
),
step_scopes
);
}
...
...
@@ -472,7 +468,7 @@ StepScopes RecurrentGradOp::CreateStepScopes(
const
platform
::
DeviceContext
&
dev_ctx
,
const
framework
::
Scope
&
scope
,
size_t
seq_len
)
const
{
auto
*
var
=
scope
.
FindVar
(
Input
(
kStepScopes
));
PADDLE_ENFORCE
(
var
!=
nullpt
r
);
PADDLE_ENFORCE
_NOT_NULL
(
va
r
);
return
StepScopes
(
dev_ctx
,
scope
,
var
->
GetMutable
<
StepScopeVar
>
(),
Attr
<
bool
>
(
kIsTrain
),
seq_len
,
true
/*is_backward*/
);
}
...
...
@@ -491,6 +487,7 @@ std::unordered_set<std::string> RecurrentGradOp::LocalVarNames(
const
framework
::
Scope
&
scope
)
const
{
return
this
->
List2Set
(
scope
.
LocalVarNames
());
}
std
::
vector
<
std
::
string
>
RecurrentGradOp
::
GradVarLists
(
const
std
::
vector
<
std
::
string
>
&
var_names
)
{
std
::
vector
<
std
::
string
>
retv
;
...
...
@@ -627,25 +624,25 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
0
,
"The Attr(%s) should be empty."
,
RecurrentBase
::
kStates
);
}
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
RecurrentBase
::
kInputs
)
,
"The input(%s) should not be empty."
,
RecurrentBase
::
kInputs
);
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
RecurrentBase
::
kOutputs
)
,
"The input(%s) should not be empty."
,
RecurrentBase
::
kOutputs
);
PADDLE_ENFORCE
_EQ
(
ctx
->
HasInputs
(
RecurrentBase
::
kInputs
),
true
,
"The input(%s) should not be empty."
,
RecurrentBase
::
kInputs
);
PADDLE_ENFORCE
_EQ
(
ctx
->
HasInputs
(
RecurrentBase
::
kOutputs
),
true
,
"The input(%s) should not be empty."
,
RecurrentBase
::
kOutputs
);
// In some case the kInitialStates is empty.
if
(
ctx
->
HasInputs
(
RecurrentBase
::
kInitialStates
))
{
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
framework
::
GradVarName
(
RecurrentBase
::
kInitialStates
)),
"The output of(%s) should not be empty."
,
framework
::
GradVarName
(
RecurrentBase
::
kInitialStates
));
PADDLE_ENFORCE
_EQ
(
ctx
->
HasOutputs
(
framework
::
GradVarName
(
RecurrentBase
::
kInitialStates
)),
true
,
"The output of(%s) should not be empty."
,
framework
::
GradVarName
(
RecurrentBase
::
kInitialStates
));
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
RecurrentBase
::
kInitialStates
),
ctx
->
GetInputsDim
(
RecurrentBase
::
kInitialStates
));
}
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
framework
::
GradVarName
(
RecurrentBase
::
kInputs
)),
PADDLE_ENFORCE
_EQ
(
ctx
->
HasOutputs
(
framework
::
GradVarName
(
RecurrentBase
::
kInputs
)),
true
,
"The output of(%s) should not be empty."
,
framework
::
GradVarName
(
RecurrentBase
::
kInputs
));
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
RecurrentBase
::
kInputs
),
...
...
@@ -653,9 +650,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
// In some case the kParameters is empty.
if
(
ctx
->
HasInputs
(
RecurrentBase
::
kParameters
))
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
_EQ
(
ctx
->
HasOutputs
(
framework
::
GradVarName
(
RecurrentBase
::
kParameters
)),
"The output of(%s) should not be empty."
,
true
,
"The output of(%s) should not be empty."
,
framework
::
GradVarName
(
RecurrentBase
::
kParameters
));
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
RecurrentBase
::
kParameters
),
ctx
->
GetInputsDim
(
RecurrentBase
::
kParameters
));
...
...
paddle/fluid/operators/recurrent_op.h
浏览文件 @
56dd7653
...
...
@@ -25,20 +25,17 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
// StepScopes manages scopes inside RNN.
// StepScopes::CurScope() get the current scope
// StepScopes::ExScope() get the ex-scope, or scope in previous time step.
// StepScopes::Next() move to next time step.
// StepScopes manages the scopes inside Recurrent Op.
//
// if is_train = False, then
// there are two scopes for the RNN and just support forward
.
// there are two scopes for the RNN and just support forward
// else
// the len(scopes) == seq_len
//
// if is_backward = True, then
// reversely access scopes
// reversely access scopes
, delete useless ex-scope
// else
// access scopes from begin
to end.
// access scopes from begin
ning to end
class
StepScopes
{
public:
StepScopes
(
const
platform
::
DeviceContext
&
dev_ctx
,
...
...
@@ -46,11 +43,19 @@ class StepScopes {
std
::
vector
<
framework
::
Scope
*>
*
scopes
,
bool
is_train
,
size_t
seq_len
,
bool
is_backward
=
false
);
// Get the current scope
framework
::
Scope
&
CurScope
();
// Get the ex-scope, which is the scope in previous time step
framework
::
Scope
&
ExScope
();
void
Next
();
// Move to next time step when forwarding
void
ForwardNext
();
// Delete ex-scope after using it, then move to next time step when
// backwarding
void
BackwardNext
(
const
platform
::
DeviceContext
&
dev_ctx
,
framework
::
Scope
*
parent_scope
);
private:
framework
::
Scope
&
GetScope
(
size_t
scope_id
)
const
;
...
...
@@ -154,7 +159,7 @@ class RecurrentBase : public framework::OperatorBase {
if
(
is_backward
&&
src_var
==
nullptr
)
{
return
;
}
PADDLE_ENFORCE
(
src_var
!=
nullpt
r
,
"%s is not found."
,
src_var_name
);
PADDLE_ENFORCE
_NOT_NULL
(
src_va
r
,
"%s is not found."
,
src_var_name
);
auto
&
src_tensor
=
src_var
->
Get
<
framework
::
LoDTensor
>
();
auto
*
dst_var
=
dst_scope
->
Var
(
dst_var_name
);
...
...
@@ -173,9 +178,9 @@ class RecurrentBase : public framework::OperatorBase {
return
;
}
auto
*
src_var
=
src_scope
.
FindVar
(
src_var_name
);
PADDLE_ENFORCE
(
src_var
!=
nullpt
r
,
"%s is not found."
,
src_var_name
);
PADDLE_ENFORCE
_NOT_NULL
(
src_va
r
,
"%s is not found."
,
src_var_name
);
auto
&
src_tensor
=
src_var
->
Get
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE
(
dst_var
!=
nullpt
r
,
"%s is not found."
,
dst_var_name
);
PADDLE_ENFORCE
_NOT_NULL
(
dst_va
r
,
"%s is not found."
,
dst_var_name
);
auto
*
dst_tensor
=
dst_var
->
GetMutable
<
framework
::
LoDTensor
>
();
callback
(
src_tensor
,
dst_tensor
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录