Crayon鑫 / Paddle — forked from PaddlePaddle / Paddle (in sync with the upstream project)
Commit 07ea9ade — feature/dynamic recurrent op forward and backward (#4799)
Authored on Oct 20, 2017 by Yan Chunwei; committed by GitHub on Oct 20, 2017
Parent: 5380a547
Showing 11 changed files with 478 additions and 283 deletions (+478, -283)
Changed files:
  doc/design/block.md                                               +1   -1
  paddle/framework/backward.cc                                      +14  -2
  paddle/operators/dynamic_recurrent_op.cc                          +194 -115
  paddle/operators/dynamic_recurrent_op.h                           +108 -57
  paddle/operators/dynamic_recurrent_op_test.cc                     +22  -26
  paddle/operators/recurrent_op.cc                                  +13  -13
  paddle/operators/rnn/recurrent_op_utils.cc                        +11  -11
  paddle/operators/rnn/recurrent_op_utils.h                         +6   -6
  paddle/pybind/pybind.cc                                           +5   -5
  python/paddle/v2/framework/tests/test_dynamic_recurrent_op.py     +94  -37
  python/paddle/v2/framework/tests/test_recurrent_op.py             +10  -10
doc/design/block.md

@@ -189,7 +189,7 @@ OpDesc {
   inputs = {0} // the index of x in vars of BlockDesc above
   outputs = {5, 3} // indices of act and hidden_out in vars of BlockDesc above
   attrs {
-    "memories" : {1} // the index of h
+    "states" : {1} // the index of h
     "step_net" : <above step net>
   }
 };
paddle/framework/backward.cc

@@ -21,6 +21,7 @@
 #include "paddle/framework/block_desc.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/dynamic_recurrent_op.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
@@ -220,8 +221,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
   // process recurrent gradient op as a special operator.
   if (forwardOp.Type() == "recurrent") {
     // NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
-    // or
-    // this will result in infinite loop.
+    // or this will result in infinite loop.
     const auto& rnnop =
         *static_cast<const operators::RecurrentOp*>(&forwardOp);
     auto rnn_grad_op =
@@ -231,6 +231,18 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
     // create stepnet's gradient op
     rnn_grad_op->set_stepnet(
         BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
+  } else if (forwardOp.Type() == "dynamic_recurrent") {
+    // NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
+    // or this will result in infinite loop.
+    const auto& rnnop =
+        *static_cast<const operators::DynamicRecurrentOp*>(&forwardOp);
+    auto rnn_grad_op =
+        static_cast<operators::DynamicRecurrentGradientOp*>(grad_op.get());
+    const auto& stepnet_op =
+        *static_cast<const OperatorBase*>(&rnnop.rnn.GetStepUnit());
+    // create stepnet's gradient op
+    rnn_grad_op->rnn.SetStepUnit(
+        BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
   }
   if (net->ops_.empty()) {  // Current no aux op is added to network
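The hunk above wires dynamic_recurrent into the same special case that recurrent already used in BackwardRecursive: the generic gradient op is cast to its concrete type, and the backward of the forward op's step unit is built recursively and attached to it, rather than cloning the forward op (whose step net contains itself). Below is a minimal, self-contained sketch of that pattern, not Paddle code; every name here (ToyOp, BuildBackward, the type strings used in main) is hypothetical and only illustrates the recursion-into-the-step-unit idea.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct ToyOp {
  std::string type;
  std::vector<std::unique_ptr<ToyOp>> step_unit;  // inner ops of a composite op
};

// Recursively build a backward op. Composite op types get the backward of
// their inner step unit rebuilt here, instead of re-entering the forward op
// and looping forever.
std::unique_ptr<ToyOp> BuildBackward(const ToyOp& fwd) {
  auto grad = std::make_unique<ToyOp>();
  grad->type = fwd.type + "_grad";
  if (fwd.type == "recurrent" || fwd.type == "dynamic_recurrent") {
    for (const auto& inner : fwd.step_unit) {
      grad->step_unit.push_back(BuildBackward(*inner));
    }
  }
  return grad;
}

int main() {
  ToyOp rnn;
  rnn.type = "dynamic_recurrent";
  rnn.step_unit.push_back(std::make_unique<ToyOp>());
  rnn.step_unit.back()->type = "fc";
  auto grad = BuildBackward(rnn);
  // prints: dynamic_recurrent_grad holds 1 inner gradient op(s)
  std::cout << grad->type << " holds " << grad->step_unit.size()
            << " inner gradient op(s)\n";
  return 0;
}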
paddle/operators/dynamic_recurrent_op.cc

@@ -23,6 +23,7 @@ using framework::Scope;
 using framework::TensorArray;
 using framework::LoDTensor;
 using framework::Variable;
+using framework::OperatorBase;
 using framework::DySeqMetaBatch;

 namespace detail {
@@ -43,8 +44,7 @@ inline void CreateVariables(Scope& scope,
  * be reordered, but the RNN op should not change the `boot_state` as an input
  * variable's content.
  */
-template <typename T>
-inline void ReorderBootState(const DySeqMetaBatch& metas,
+inline void ReorderInitialState(const DySeqMetaBatch& metas,
                              const LoDTensor& boot_state, LoDTensor* tensor,
                              const platform::Place& dst_place) {
   for (size_t seq_id = 0; seq_id < metas.size(); seq_id++) {
@@ -56,58 +56,60 @@ inline void ReorderBootState(const DySeqMetaBatch& metas,
   }
 }

-}  // namespace detail
-
-class DynamicRecurrentOpProtoAndCheckerMaker
-    : public framework::OpProtoAndCheckerMaker {
- public:
-  DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto,
-                                         framework::OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    const auto& name = DynamicRecurrentOp::kArgName;
-    // inputs and outputs stored in proto
-    AddInput(name.inlinks,
-             "the inputs that need to be segmented for each step.")
-        .AsDuplicable();
-    AddInput(name.boot_memories, "variables to initialize memories.")
-        .AsDuplicable();
-    AddOutput(name.outlinks,
-              "the outputs that need to concated for all steps.")
-        .AsDuplicable();
-    AddOutput(name.step_scopes, "step scopes");
-    // Attributes stored in AttributeMap
-    AddAttr<std::vector<std::string>>(name.pre_memories,
-                                      "names of pre-memories");
-    AddAttr<std::vector<std::string>>(name.memories, "names of memories");
-    AddComment("This is a RNN operator for varience-length sequences.");
-  }
-};
-
-void DynamicRecurrentOp::Run(const Scope& scope,
-                             const platform::DeviceContext& dev_ctx) const {
-  cache_.Init(kArgName, *this, scope, &arg_);
+inline void RestoreInitialState(const DySeqMetaBatch& metas,
+                                const LoDTensor& tensor, LoDTensor* boot_state,
+                                const platform::Place& dst_place) {
+  for (size_t seq_id = 0; seq_id < metas.size(); seq_id++) {
+    auto slice = tensor.Slice(seq_id, seq_id + 1);
+    auto boot_slice =
+        boot_state->Slice(metas[seq_id].ori_idx, metas[seq_id].ori_idx + 1);
+    boot_slice.CopyFrom(slice, dst_place, platform::CPUDeviceContext());
+  }
+}
+
+}  // namespace detail
+
+// Implementation for forward propagation.
+template <>
+void RNNAlgorithm::Run<RNNAlgorithm::ComputeMode::kForward>(
+    const framework::Scope& scope, const framework::OperatorBase& op,
+    const platform::DeviceContext& dev_ctx) {
+  SetComputeMode(ComputeMode::kForward);
+  cache_.Init(kArgNames[mode_], op, scope, &dev_ctx, &arg_);
   SplitInputs();
   CreateScopes();
   WriteStepInputs();
   InitStates();
   WriteStepOutputs();
+  RunSteps();
+  ConcatOutputs();
+}

-  // call stepnet in all the time steps
-  for (size_t step = 0; step < cache_.num_steps; step++) {
-    auto& step_scope = cache_.GetScope(step);
-    stepnet_->Run(step_scope, dev_ctx);
-  }
+// Implementation for backward propagation.
+template <>
+void RNNAlgorithm::Run<RNNAlgorithm::ComputeMode::kBackward>(
+    const framework::Scope& scope, const framework::OperatorBase& op,
+    const platform::DeviceContext& dev_ctx) {
+  SetComputeMode(ComputeMode::kBackward);
+  cache_.Init(kArgNames[mode_], op, scope, &dev_ctx, &arg_);
+  SplitInputs();
+  WriteStepInputs();
+  InitStates();
+  WriteStepOutputs();
+  RunSteps();
+  // copy boot-states' gradients back.
+  for (const auto& state : arg_.states) {
+    ExportInitialStateGradient(state);
+  }

   ConcatOutputs();
 }

-void DynamicRecurrentOp::SplitInputs() const {
+void RNNAlgorithm::SplitInputs() {
   // TODO(superjom) make level a config
   // TODO(superjom) check all the inputs has the same LoD
   int level = 0;
-  for (const auto& item : cache_.inlinks) {
+  for (const auto& item : cache_.inputs) {
     const auto& var = item.second;
     const auto& tensor = var->Get<LoDTensor>();
     TensorArray& ta = step_inputs_[item.first];
@@ -124,8 +126,8 @@ void DynamicRecurrentOp::SplitInputs() const {
   }
 }

-void DynamicRecurrentOp::WriteStepInputs() const {
-  for (const auto& item : cache_.inlinks) {
+void RNNAlgorithm::WriteStepInputs() {
+  for (const auto& item : cache_.inputs) {
     auto ta_it = step_inputs_.find(item.first);
     PADDLE_ENFORCE(ta_it != step_inputs_.end(),
                    "step_inputs_ not compatible with memory set");
@@ -142,15 +144,15 @@ void DynamicRecurrentOp::WriteStepInputs() const {
   }
 }

-void DynamicRecurrentOp::WriteStepOutputs() const {
+void RNNAlgorithm::WriteStepOutputs() {
   // initialize step outputs
-  for (const auto& item : cache_.outlinks) {
+  for (const auto& item : cache_.outputs) {
     step_outputs_.emplace(item.first, TensorArray());
   }
   PADDLE_ENFORCE_GT(step_outputs_.size(), 0UL);
 }

-void DynamicRecurrentOp::CreateScopes() const {
+void RNNAlgorithm::CreateScopes() {
   PADDLE_ENFORCE_GT(cache_.num_steps, 0);
   // resize scopes
   size_t num_scopes_need_create = cache_.num_steps - cache_.scopes->size();
@@ -159,19 +161,19 @@ void DynamicRecurrentOp::CreateScopes() const {
   }

   // init temporary inputs
-  PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first");
-  std::vector<std::string> memories;
-  std::vector<std::string> pre_memories;
-  std::vector<std::string> stepnet_outputs;
-  std::transform(arg_.memories.begin(), arg_.memories.end(),
-                 std::back_inserter(memories),
-                 [](const rnn::MemoryAttr& m) { return m.var; });
-  std::transform(arg_.memories.begin(), arg_.memories.end(),
-                 std::back_inserter(pre_memories),
-                 [](const rnn::MemoryAttr& m) { return m.pre_var; });
-  for (const auto& item : stepnet_->Outputs()) {
+  PADDLE_ENFORCE_NOT_NULL(step_unit_, "stepnet should be set first");
+  std::vector<std::string> states;
+  std::vector<std::string> ex_states;
+  std::vector<std::string> step_unit_outputs;
+  std::transform(arg_.states.begin(), arg_.states.end(),
+                 std::back_inserter(states),
+                 [](const rnn::StateAttr& m) { return m.var; });
+  std::transform(arg_.states.begin(), arg_.states.end(),
+                 std::back_inserter(ex_states),
+                 [](const rnn::StateAttr& m) { return m.pre_var; });
+  for (const auto& item : step_unit_->Outputs()) {
     for (const auto& var : item.second) {
-      stepnet_outputs.push_back(var);
+      step_unit_outputs.push_back(var);
     }
   }
@@ -179,13 +181,13 @@ void DynamicRecurrentOp::CreateScopes() const {
     auto& scope = cache_.GetScope(step);
     detail::CreateVariables(scope, arg_.inlinks);
     detail::CreateVariables(scope, arg_.outlinks);
-    detail::CreateVariables(scope, memories);
-    detail::CreateVariables(scope, pre_memories);
-    detail::CreateVariables(scope, stepnet_outputs);
+    detail::CreateVariables(scope, states);
+    detail::CreateVariables(scope, ex_states);
+    detail::CreateVariables(scope, step_unit_outputs);
   }
 }

-void DynamicRecurrentOp::ConcatOutputs() const {
+void RNNAlgorithm::ConcatOutputs() {
   // TODO(superjom) transform this to a config
   int level = 0;
   for (size_t step = 0; step < cache_.num_steps; step++) {
@@ -198,31 +200,45 @@ void DynamicRecurrentOp::ConcatOutputs() const {
       item.second.WriteShared(step, *tensor);
     }
   }
-  // the inlinks' lods should be the same, so randomly get one lod.
+  // the inputs' lods should be the same, so randomly get one lod.
   const auto& some_lod =
       cache_.scope->FindVar(arg_.inlinks.front())->Get<LoDTensor>().lod();
   const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
   for (auto& item : step_outputs_) {
     auto tensor = item.second.Pack(level, some_meta, some_lod);
-    auto* output = cache_.outlinks[item.first]->GetMutable<LoDTensor>();
+    auto* output = cache_.outputs[item.first]->GetMutable<LoDTensor>();
     const_cast<LoDTensor*>(output)->ShareDataWith(tensor);
   }
 }

-void DynamicRecurrentOp::InitStates() const {
+void RNNAlgorithm::RunSteps() {
+  if (IsBackward()) {
+    // call stepnet in all the time steps reversely
+    for (int step = cache_.num_steps - 1; step >= 0; step--) {
+      auto& step_scope = cache_.GetScope(step);
+      step_unit_->Run(step_scope, *cache_.dev_ctx);
+    }
+  } else {
+    for (size_t step = 0; step < cache_.num_steps; step++) {
+      auto& step_scope = cache_.GetScope(step);
+      step_unit_->Run(step_scope, *cache_.dev_ctx);
+    }
+  }
+}
+
+void RNNAlgorithm::InitStates() {
   for (size_t step = 0; step < cache_.num_steps; step++) {
-    for (const auto& memory : arg_.memories) {
-      CreateState(memory, step);
-      LinkState(memory, step);
+    for (const auto& state : arg_.states) {
+      CreateState(state, step);
+      LinkState(state, step);
     }
   }
 }

-void DynamicRecurrentOp::CreateState(const rnn::MemoryAttr& memory,
-                                     size_t step) const {
+void RNNAlgorithm::CreateState(const rnn::StateAttr& state_attr,
+                               size_t step) {
   auto& scope = cache_.GetScope(step);
-  auto& state = *cache_.GetTensor(scope, memory.var);
-  auto& boot_state = *cache_.GetTensor(*cache_.scope, memory.boot_var);
+  auto& state = *cache_.GetTensor(scope, state_attr.var);
+  auto& boot_state = *cache_.GetTensor(*cache_.scope, state_attr.boot_var);
   size_t num_instances =
       step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];
@@ -231,55 +247,78 @@ void DynamicRecurrentOp::CreateState(const rnn::MemoryAttr& memory,
   state.Resize(dims);
   state.mutable_data<value_type>(platform::CPUPlace());
-  states_[memory.var].WriteShared(step, state);
+  states_[state_attr.var].WriteShared(step, state);
 }

-void DynamicRecurrentOp::LinkState(const rnn::MemoryAttr& memory,
-                                   size_t step) const {
+void RNNAlgorithm::LinkState(const rnn::StateAttr& state,
+                             size_t step) {
   auto& scope = cache_.GetScope(step);
-  auto& state_pre = *cache_.GetTensor(scope, memory.pre_var);
+  auto& state_pre = *cache_.GetTensor(scope, state.pre_var);
+
+  // process the first state's boot-state (the 0-step in forward mode or the
+  // last step in backward mode).
+  // Only forward mode need to link the boot-state to the `pre-state` in first
+  // time step. In backward mode, need to copy the gradient of `pre-state` in
+  // first time step to the gradient of `boot-state`.
+  if (step == 0 && IsForward()) {
+    LinkInitialState(state);
+  } else {
+    size_t num_instances =
+        step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];
+    auto* pre_state = cache_.GetTensor(cache_.GetScope(step - 1), state.var);
+    // shink and share from previous state
+    auto shrinked_pre_state = pre_state->Slice(0, num_instances);
+    state_pre.ShareDataWith(shrinked_pre_state);
+  }
+}
+
+void RNNAlgorithm::LinkInitialState(const rnn::StateAttr& state) {
   // all the step_inputs' metas should be the same, just randomly select one
   // and get the dyseq meta.
   const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
-  size_t num_instances =
-      step_inputs_[arg_.inlinks.front()].Read(step).dims()[0];
-
-  LoDTensor* pre_state{nullptr};
-  if (step == 0) {
-    pre_state = cache_.GetTensor(*cache_.scope, memory.boot_var);
-    pre_state->mutable_data<float>(platform::CPUPlace());
-    // allocate memory
-    state_pre.Resize(pre_state->dims());
-    state_pre.mutable_data<value_type>(platform::CPUPlace());
-    detail::ReorderBootState<value_type>(some_meta, *pre_state, &state_pre,
-                                         pre_state->place());
-  } else {
-    pre_state = cache_.GetTensor(cache_.GetScope(step - 1), memory.var);
-  }
-
-  // shink and share from previous state
-  auto shrinked_pre_state = pre_state->Slice(0, num_instances);
-  state_pre.ShareDataWith(shrinked_pre_state);
+  auto& scope = cache_.GetScope(0);
+  auto& state_pre = *cache_.GetTensor(scope, state.pre_var);
+  auto* pre_state = cache_.GetTensor(*cache_.scope, state.boot_var);
+  pre_state->mutable_data<float>(platform::CPUPlace());
+  // allocate state
+  state_pre.Resize(pre_state->dims());
+  state_pre.mutable_data<value_type>(platform::CPUPlace());
+  detail::ReorderInitialState(some_meta, *pre_state, &state_pre,
+                              pre_state->place());
+}
+
+void RNNAlgorithm::ExportInitialStateGradient(const rnn::StateAttr& state) {
+  // all the step_inputs' metas should be the same, just randomly select one
+  // and get the dyseq meta.
+  const auto& some_meta = dy_seq_metas_[arg_.inlinks.front()];
+  auto& scope = cache_.GetScope(0);
+  auto& state_pre = *cache_.GetTensor(scope, state.pre_var);
+  auto& pre_state = *cache_.GetTensor(*cache_.scope, state.boot_var);
+  pre_state.Resize(state_pre.dims());
+  detail::RestoreInitialState(some_meta, state_pre, &pre_state,
+                              pre_state.place());
 }

-void DynamicRecurrentOp::ArgCache::Init(
-    const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
-    const paddle::framework::Scope& scope, rnn::Argument* arg) {
+void RNNAlgorithm::ArgCache::Init(
+    const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
+    const paddle::framework::Scope& scope,
+    platform::DeviceContext const* dev_ctx, rnn::Argument* arg) {
   this->scope = &scope;
   InitArgument(name, op, arg);
   CacheScopes(scope, *arg);
   CacheInlinks(scope, arg->inlinks);
   CacheOutlinks(scope, arg->outlinks);
+  this->dev_ctx = dev_ctx;
 }

-void DynamicRecurrentOp::ArgCache::InitArgument(const rnn::ArgumentName& name,
-                                                const OperatorBase& op,
-                                                rnn::Argument* arg) {
+void RNNAlgorithm::ArgCache::InitArgument(const rnn::ArgumentName& name,
                                          const OperatorBase& op,
                                          rnn::Argument* arg) {
   rnn::InitArgument(name, arg, op, false /*is_grad*/);
 }

-void DynamicRecurrentOp::ArgCache::CacheScopes(const Scope& scope,
-                                               const rnn::Argument& arg) {
+void RNNAlgorithm::ArgCache::CacheScopes(const Scope& scope,
                                         const rnn::Argument& arg) {
   auto scopes_var = scope.FindVar(arg.step_scopes);
   PADDLE_ENFORCE(scopes_var != nullptr,
@@ -289,45 +328,85 @@ void DynamicRecurrentOp::ArgCache::CacheScopes(const Scope& scope,
   this->scopes = scopes_var->GetMutable<std::vector<Scope*>>();
 }

-void DynamicRecurrentOp::ArgCache::CacheInlinks(
+void RNNAlgorithm::ArgCache::CacheInlinks(
     const Scope& scope, const std::vector<std::string>& names) {
   for (auto name : names) {
     auto* var = GetVariable(scope, name);
-    inlinks[name] = var;
+    inputs[name] = var;
   }
 }

-void DynamicRecurrentOp::ArgCache::CacheOutlinks(
+void RNNAlgorithm::ArgCache::CacheOutlinks(
     const Scope& scope, const std::vector<std::string>& names) {
   for (auto name : names) {
     auto* var = GetVariable(scope, name);
-    outlinks[name] = var;
+    outputs[name] = var;
   }
 }

-Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
-                                                    const std::string& name) {
+Variable* RNNAlgorithm::ArgCache::GetVariable(const Scope& scope,
                                              const std::string& name) {
   auto* var = scope.FindVar(name);
   PADDLE_ENFORCE_NOT_NULL(var, "variable [%s] not exist in scope", name);
   return var;
 }

-LoDTensor* DynamicRecurrentOp::ArgCache::GetTensor(
-    const framework::Scope& scope, const std::string& name) {
+LoDTensor* RNNAlgorithm::ArgCache::GetTensor(const framework::Scope& scope,
+                                             const std::string& name) {
   auto* var = GetVariable(scope, name);
   return var->GetMutable<LoDTensor>();
 }

-const rnn::ArgumentName DynamicRecurrentOp::kArgName{
-    "step_net", "step_scopes", "inlinks", "outlinks",
-    "memories", "pre_memories", "boot_memories"};
+const std::array<rnn::ArgumentName, 2> RNNAlgorithm::kArgNames{
+    rnn::ArgumentName{"step_unit", "step_scopes", "inputs", "outputs",
+                      "states", "ex_states", "initial_states"},
+    rnn::ArgumentName{"step_unit", "step_scopes@GRAD", "outputs@GRAD",
+                      "inputs@GRAD", "states", "ex_states",
+                      "initial_states@GRAD"}};
+
+void DynamicRecurrentOp::Run(const framework::Scope& scope,
+                             const platform::DeviceContext& dev_ctx) const {
+  rnn.Run<RNNAlgorithm::ComputeMode::kForward>(
+      scope, *dynamic_cast<const OperatorBase*>(this), dev_ctx);
+}

 void DynamicRecurrentGradientOp::Run(
-    const Scope& scope, const platform::DeviceContext& dev_ctx) const {}
+    const Scope& scope, const platform::DeviceContext& dev_ctx) const {
+  rnn.Run<RNNAlgorithm::ComputeMode::kBackward>(
+      scope, *dynamic_cast<const OperatorBase*>(this), dev_ctx);
+}
+
+class DynamicRecurrentOpProtoAndCheckerMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto,
+                                         framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    const auto& name =
+        RNNAlgorithm::kArgNames[RNNAlgorithm::ComputeMode::kForward];
+    // inputs and outputs stored in proto
+    AddInput(name.inlinks,
+             "the inputs that need to be segmented for each step.")
+        .AsDuplicable();
+    AddInput(name.initial_states, "variables to initialize states.")
+        .AsDuplicable();
+    AddOutput(name.outlinks,
+              "the outputs that need to concated for all steps.")
+        .AsDuplicable();
+    AddOutput(name.step_scopes, "step scopes");
+    // Attributes stored in AttributeMap
+    AddAttr<std::vector<std::string>>(name.ex_states, "names of ex_states");
+    AddAttr<std::vector<std::string>>(name.states, "names of states");
+    AddComment("This is a RNN operator for varience-length sequences.");
+  }
+};

 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP_WITHOUT_GRADIENT(
-    dynamic_recurrent, paddle::operators::DynamicRecurrentOp,
-    paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker);
+REGISTER_OP(dynamic_recurrent, paddle::operators::DynamicRecurrentOp,
+            paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker,
+            dynamic_recurrent_grad,
+            paddle::operators::DynamicRecurrentGradientOp);
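The rewrite above folds what used to live directly in DynamicRecurrentOp into a single RNNAlgorithm whose Run is a function template specialized on ComputeMode, so the forward and backward passes share SplitInputs, InitStates, RunSteps and the other helpers and only diverge where the mode matters (the kArgNames entry used, the step order, and exporting boot-state gradients). The following is a minimal sketch of that dispatch pattern, not the Paddle classes; ToyRNNAlgorithm, its members, and the argument-name strings are illustrative assumptions only.

#include <array>
#include <iostream>
#include <string>

class ToyRNNAlgorithm {
 public:
  enum ComputeMode { kForward = 0, kBackward = 1 };
  // Per-mode argument names, mirroring the idea behind RNNAlgorithm::kArgNames.
  static const std::array<std::string, 2> kArgNames;

  // One templated entry point; the mode picks the argument names and the
  // direction of the time loop, everything else is shared.
  template <ComputeMode Mode>
  void Run() {
    mode_ = Mode;
    std::cout << "init cache with " << kArgNames[mode_] << "\n";
    RunSteps();
  }

 private:
  void RunSteps() {
    // Backward walks the time steps in reverse; forward walks them in order.
    if (mode_ == kBackward) {
      for (int step = 2; step >= 0; --step) std::cout << "step " << step << "\n";
    } else {
      for (int step = 0; step < 3; ++step) std::cout << "step " << step << "\n";
    }
  }
  ComputeMode mode_{kForward};
};

const std::array<std::string, 2> ToyRNNAlgorithm::kArgNames = {
    "step_unit", "step_unit@GRAD"};

int main() {
  ToyRNNAlgorithm rnn;
  rnn.Run<ToyRNNAlgorithm::kForward>();
  rnn.Run<ToyRNNAlgorithm::kBackward>();
  return 0;
}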
paddle/operators/dynamic_recurrent_op.h

@@ -27,47 +27,39 @@
 namespace paddle {
 namespace operators {

-class DynamicRecurrentOp : public framework::OperatorBase {
+class RNNAlgorithm {
  public:
-  static const rnn::ArgumentName kArgName;
+  enum ComputeMode { kForward = 0, kBackward = 1 };
+  static const std::array<rnn::ArgumentName, 2> kArgNames;
   using value_type = float;

-  DynamicRecurrentOp(const std::string& type,
-                     const framework::VariableNameMap& inputs,
-                     const framework::VariableNameMap& outputs,
-                     const framework::AttributeMap& attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
-
-  DynamicRecurrentOp(const DynamicRecurrentOp& o)
-      : framework::OperatorBase(
-            static_cast<const framework::OperatorBase&>(o)) {
-    // TODO(yuyang18): Implement copy ctor well.
-    PADDLE_THROW("Not implemented");
-  }
-
-  void Run(const framework::Scope& scope,
-           const platform::DeviceContext& dev_ctx) const override;
+  /*
+   * Different `Run` method for forward and backward, `_` is just for template
+   * specifialization.
+   */
+  template <ComputeMode _>
+  void Run(const framework::Scope& scope, const framework::OperatorBase& op,
+           const platform::DeviceContext& dev_ctx);

   /*
    * Split the inputs(LoDTensors) to segments for each time step.
    */
-  void SplitInputs() const;
+  void SplitInputs();

   /*
    * Create step-scopes to store temporary outputs in each time steps.
    */
-  void CreateScopes() const;
+  void CreateScopes();

   /*
    * Link TensorArray steps to the corresponding variables located in
    * step-scopes.
    */
-  void WriteStepInputs() const;
+  void WriteStepInputs();

   /*
    * Write output of each step to the corresponding TensorArray.
    */
-  void WriteStepOutputs() const;
+  void WriteStepOutputs();

   /*
    * Initialize the states, each state will have a corresponding pre-state,
@@ -75,54 +67,83 @@ class DynamicRecurrentOp : public framework::OperatorBase {
    * pre-state in the first time step will be initialized with an zero tensor or
    * a tensor in parent scope if is provided.
    */
-  void InitStates() const;
+  void InitStates();

   /*
    * Create state variables for each time step.
    */
-  void CreateState(const rnn::MemoryAttr& memory, size_t step) const;
+  void CreateState(const rnn::StateAttr& state, size_t step);

   /*
    * Link pre-state variable in current scope to the state variable in the
-   * previous time step (scope).
+   * previous time step (scope) by reference.
+   */
+  void LinkState(const rnn::StateAttr& state, size_t step);
+
+  /*
+   * Link the pre-state of the first time step to the `boot-state` in parent's
+   * scope.
+   */
+  void LinkInitialState(const rnn::StateAttr& state);
+
+  /*
+   * Copy the gradient from `pre-state` in the first step-scope to the
+   * `boot-state` in parent's scope.
+   */
+  void ExportInitialStateGradient(const rnn::StateAttr& state);
+
+  /*
+   * Calculate time steps.
    */
-  void LinkState(const rnn::MemoryAttr& memory, size_t step) const;
+  void RunSteps();

   /*
    * Concatenate outputs in each time step and generate a LoDTensor.
    */
-  void ConcatOutputs() const;
+  void ConcatOutputs();
+
+  void SetComputeMode(ComputeMode mode) { mode_ = mode; }
+  bool IsForward() const { return mode_ == ComputeMode::kForward; }
+  bool IsBackward() const { return mode_ == ComputeMode::kBackward; }

   /*
-   * set a stepnet that is created according to a RecurrentOp's stepnet.
+   * set a step unit that is created according to a RecurrentOp's step unit.
    */
-  void SetStepNet(std::unique_ptr<OperatorBase> net) {
-    PADDLE_ENFORCE_NOT_NULL(net);
-    stepnet_ = std::move(net);
+  void SetStepUnit(std::unique_ptr<framework::OperatorBase> step_unit) {
+    PADDLE_ENFORCE_NOT_NULL(step_unit);
+    step_unit_ = std::move(step_unit);
   }
-  const OperatorBase& GetStepNet() const { return *stepnet_; }
+  const framework::OperatorBase& GetStepUnit() const { return *step_unit_; }

   const framework::TensorArray& state(const std::string& name) const {
-    return states_[name];
+    auto it = states_.find(name);
+    PADDLE_ENFORCE(it != states_.end());
+    return it->second;
   }
   const framework::TensorArray& step_input(const std::string& name) const {
-    return step_inputs_[name];
+    auto it = step_inputs_.find(name);
+    PADDLE_ENFORCE(it != step_inputs_.end());
+    return it->second;
   }
   const framework::TensorArray& step_output(const std::string& name) const {
-    return step_outputs_[name];
+    auto it = step_outputs_.find(name);
+    PADDLE_ENFORCE(it != step_outputs_.end());
+    return it->second;
   }

  protected:
   struct ArgCache {
     framework::Scope const* scope;
     std::vector<framework::Scope*>* scopes;
-    std::map<std::string, framework::Variable*> inlinks;
-    std::map<std::string, framework::Variable*> outlinks;
+    std::map<std::string, framework::Variable*> inputs;
+    std::map<std::string, framework::Variable*> outputs;
+    platform::DeviceContext const* dev_ctx;

     size_t num_steps{0};

-    void Init(const rnn::ArgumentName& name, const OperatorBase& op,
-              const framework::Scope& scope, rnn::Argument* arg);
+    void Init(const rnn::ArgumentName& name, const framework::OperatorBase& op,
+              const framework::Scope& scope,
+              platform::DeviceContext const* dev_ctx, rnn::Argument* arg);

     framework::Scope& GetScope(size_t index) {
       PADDLE_ENFORCE_LT(index, num_steps);
@@ -133,8 +154,8 @@ class DynamicRecurrentOp : public framework::OperatorBase {
                           const std::string& name);

    private:
-    void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op,
-                      rnn::Argument* arg);
+    void InitArgument(const rnn::ArgumentName& name,
+                      const framework::OperatorBase& op, rnn::Argument* arg);
     void CacheScopes(const framework::Scope& scope, const rnn::Argument& arg);
     void CacheInlinks(const framework::Scope& scope,
                       const std::vector<std::string>& names);
@@ -145,27 +166,49 @@ class DynamicRecurrentOp : public framework::OperatorBase {
   };

  private:
-  std::unique_ptr<OperatorBase> stepnet_;
-  mutable std::map<std::string, framework::TensorArray> states_;
-  mutable std::map<std::string, framework::TensorArray> step_inputs_;
-  mutable std::map<std::string, framework::TensorArray> step_outputs_;
-  mutable std::map<std::string, std::vector<framework::DySeqMeta>>
-      dy_seq_metas_;
-  mutable rnn::Argument arg_;
-  mutable ArgCache cache_;
+  std::unique_ptr<framework::OperatorBase> step_unit_;
+  std::map<std::string, framework::TensorArray> states_;
+  std::map<std::string, framework::TensorArray> step_inputs_;
+  std::map<std::string, framework::TensorArray> step_outputs_;
+  std::map<std::string, std::vector<framework::DySeqMeta>> dy_seq_metas_;
+  rnn::Argument arg_;
+  ArgCache cache_;
+  ComputeMode mode_{ComputeMode::kForward};

 #ifdef PADDLE_WITH_TESTING
-  friend class DynamicRecurrentOpTestHelper;
-  FRIEND_TEST(DynamicRecurrentOpTestHelper, SplitInputs);
-  FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateCache);
-  FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateScopes);
-  FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepInputs);
-  FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepOutputs);
-  FRIEND_TEST(DynamicRecurrentOpTestHelper, InitStates);
-  FRIEND_TEST(DynamicRecurrentOpTestHelper, ConcatOutputs);
+  // test forward
+  friend class RNNAlgorithmTestHelper;
+  FRIEND_TEST(RNNAlgorithmTestHelper, SplitInputs);
+  FRIEND_TEST(RNNAlgorithmTestHelper, CreateCache);
+  FRIEND_TEST(RNNAlgorithmTestHelper, CreateScopes);
+  FRIEND_TEST(RNNAlgorithmTestHelper, WriteStepInputs);
+  FRIEND_TEST(RNNAlgorithmTestHelper, WriteStepOutputs);
+  FRIEND_TEST(RNNAlgorithmTestHelper, InitStates);
+  FRIEND_TEST(RNNAlgorithmTestHelper, ConcatOutputs);
+  // TODO(superjom) test backward
 #endif
 };

+class DynamicRecurrentOp : public framework::OperatorBase {
+ public:
+  DynamicRecurrentOp(const std::string& type,
+                     const framework::VariableNameMap& inputs,
+                     const framework::VariableNameMap& outputs,
+                     const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  DynamicRecurrentOp(const DynamicRecurrentOp& o)
+      : framework::OperatorBase(
+            static_cast<const framework::OperatorBase&>(o)) {
+    PADDLE_THROW("Not implemented");
+  }
+
+  void Run(const framework::Scope& scope,
+           const platform::DeviceContext& dev_ctx) const override;
+
+  mutable RNNAlgorithm rnn;
+};
+
 class DynamicRecurrentGradientOp : public framework::OperatorBase {
  public:
   DynamicRecurrentGradientOp(const std::string& type,
@@ -174,8 +217,16 @@ class DynamicRecurrentGradientOp : public framework::OperatorBase {
                              const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}

+  DynamicRecurrentGradientOp(const DynamicRecurrentGradientOp& o)
+      : framework::OperatorBase(
+            static_cast<const framework::OperatorBase&>(o)) {
+    PADDLE_THROW("Not implemented");
+  }
+
   void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const override;
+
+  mutable RNNAlgorithm rnn;
 };

 }  // namespace operators
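As the new header lays out, DynamicRecurrentOp and DynamicRecurrentGradientOp become thin framework::OperatorBase shells that each own a mutable RNNAlgorithm member named rnn and select the compute mode inside their const Run. A tiny illustrative sketch of that layout follows; the Toy* names are hypothetical, not the real operators, and the mutable member is shown only to mirror why the real one is mutable (the operator's Run is const).

#include <iostream>

struct ToyEngine {
  enum Mode { kForward, kBackward };
  template <Mode M>
  void Run() { std::cout << (M == kForward ? "forward\n" : "backward\n"); }
};

struct ToyForwardOp {
  // const interface, so the shared engine is marked mutable
  void Run() const { rnn.Run<ToyEngine::kForward>(); }
  mutable ToyEngine rnn;
};

struct ToyGradientOp {
  void Run() const { rnn.Run<ToyEngine::kBackward>(); }
  mutable ToyEngine rnn;
};

int main() {
  ToyForwardOp fwd;
  ToyGradientOp bwd;
  fwd.Run();
  bwd.Run();
  return 0;
}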
paddle/operators/dynamic_recurrent_op_test.cc

@@ -43,16 +43,16 @@ LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims,
   return tensor;
 }

-class DynamicRecurrentOpTestHelper : public ::testing::Test {
+class RNNAlgorithmTestHelper : public ::testing::Test {
  protected:
-  const rnn::ArgumentName argname = DynamicRecurrentOp::kArgName;
+  const rnn::ArgumentName argname = RNNAlgorithm::kArgNames[0];

   virtual void SetUp() override {
     CreateGlobalVariables();

     auto op_desc = CreateOpDesc();
     op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
-    dop = dynamic_cast<DynamicRecurrentOp*>(op.get());
+    dop = &(dynamic_cast<DynamicRecurrentOp*>(op.get())->rnn);
     InitCacheManually();
     InitStepNet();
   }
@@ -63,20 +63,20 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
     op_desc.set_type("dynamic_recurrent");

     OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs());
-    OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs());
+    OpDescNewVar(argname.initial_states, {"boot_mem"}, op_desc.add_inputs());
     OpDescNewVar(argname.step_scopes, {"step_scopes"}, op_desc.add_outputs());
     OpDescNewVar(argname.outlinks, {"out0"}, op_desc.add_outputs());

-    // set pre-memories
+    // set pre-states
     auto pre_memories = op_desc.mutable_attrs()->Add();
-    pre_memories->set_name(argname.pre_memories);
+    pre_memories->set_name(argname.ex_states);
     pre_memories->set_type(paddle::framework::AttrType::STRINGS);
     auto pre_memories_item = pre_memories->add_strings();
     *pre_memories_item = "mem@pre";

-    // set memories
+    // set states
     auto memories = op_desc.mutable_attrs()->Add();
-    memories->set_name(argname.memories);
+    memories->set_name(argname.states);
     memories->set_type(paddle::framework::AttrType::STRINGS);
     auto memories_item = memories->add_strings();
     *memories_item = "mem";
@@ -113,32 +113,33 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
   }

   void InitCacheManually() {
-    dop->cache_.Init(DynamicRecurrentOp::kArgName, *dop, scope, &dop->arg_);
+    dop->cache_.Init(RNNAlgorithm::kArgNames[0], *op, scope, &device_context,
+                     &dop->arg_);
   }

   void InitStepNet() {
     std::unique_ptr<framework::OperatorBase> stepnet{new NetOp};
     dynamic_cast<NetOp*>(stepnet.get())
         ->AppendOp(std::unique_ptr<TestOp>(new TestOp(
-            "test", {{"inlinks", {"in0"}}, {"boot_memories", {"boot_mem"}}},
-            {{"outlinks", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
-    dop->SetStepNet(std::move(stepnet));
+            "test", {{"inputs", {"in0"}}, {"initial_states", {"boot_mem"}}},
+            {{"outputs", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
+    dop->SetStepUnit(std::move(stepnet));
   }

  protected:
-  DynamicRecurrentOp* dop;
+  RNNAlgorithm* dop;
   std::unique_ptr<framework::OperatorBase> op;
   paddle::platform::CPUDeviceContext device_context;
   paddle::framework::Scope scope;
 };

-TEST_F(DynamicRecurrentOpTestHelper, CreateCache) {
+TEST_F(RNNAlgorithmTestHelper, CreateCache) {
   const rnn::Argument& arg = dop->arg_;
   ASSERT_EQ(arg.inlinks.size(), 1UL);
   ASSERT_EQ(arg.outlinks.size(), 1UL);
 }

-TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) {
+TEST_F(RNNAlgorithmTestHelper, SplitInputs) {
   dop->SplitInputs();
   auto& in0_ta = dop->step_inputs_["in0"];
   ASSERT_EQ(in0_ta.size(), 4UL);
@@ -153,14 +154,14 @@ TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) {
   EXPECT_EQ(batch3.dims()[0], 1);
 }

-TEST_F(DynamicRecurrentOpTestHelper, CreateScopes) {
+TEST_F(RNNAlgorithmTestHelper, CreateScopes) {
   dop->SplitInputs();
   dop->CreateScopes();
   ASSERT_EQ(dop->cache_.num_steps, 4UL);
   ASSERT_EQ(dop->cache_.scopes->size(), 4UL);
 }

-TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) {
+TEST_F(RNNAlgorithmTestHelper, WriteStepInputs) {
   dop->SplitInputs();
   dop->CreateScopes();
   dop->WriteStepInputs();
@@ -173,7 +174,7 @@ TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) {
   }
 }

-TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) {
+TEST_F(RNNAlgorithmTestHelper, WriteStepOutputs) {
   dop->SplitInputs();
   dop->CreateScopes();
   dop->WriteStepInputs();
@@ -187,11 +188,12 @@ TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) {
   }
 }

-TEST_F(DynamicRecurrentOpTestHelper, ConcatOutputs) {
+TEST_F(RNNAlgorithmTestHelper, ConcatOutputs) {
   // Let's leave this test to python unittest.
 }

-TEST_F(DynamicRecurrentOpTestHelper, InitStates) {
+TEST_F(RNNAlgorithmTestHelper, InitStates) {
+  dop->SetComputeMode(RNNAlgorithm::ComputeMode::kForward);
   dop->SplitInputs();
   dop->CreateScopes();
   dop->WriteStepInputs();
@@ -208,12 +210,6 @@ TEST_F(DynamicRecurrentOpTestHelper, InitStates) {
     auto* boot_state = scope.FindVar("boot_mem");
     ASSERT_TRUE(boot_state != nullptr);
-
-    if (step == 0) {
-      // check pre_state is a reference of boot_state
-      ASSERT_EQ(boot_state->Get<LoDTensor>().data<float>(),
-                pre_state->Get<LoDTensor>().data<float>());
-    }
   }
 }
paddle/operators/recurrent_op.cc

@@ -42,7 +42,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
   for (size_t step_id = 0; step_id < seq_len; step_id++) {
     if (step_id > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
+      rnn::LinkMemories(step_scopes, arg_->states, step_id, -1);
     }
     (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
@@ -59,7 +59,8 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope,
   // Now all variables in scope must be created outside of op.
   PADDLE_ENFORCE_NOT_NULL(stepnet_);
-  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
+  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(),
+                 "step_unit_ op has no outputs");

   if (seq_len > step_scopes->size()) {
     for (size_t i = step_scopes->size(); i < seq_len; ++i) {
@@ -86,7 +87,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope,
 }

 void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
-  for (auto& attr : arg_->memories) {
+  for (auto& attr : arg_->states) {
     auto* pre_mem = step_scope->Var(attr.pre_var)->GetMutable<LoDTensor>();
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
                    "memory [%s]'s boot variable [%s] not exists", attr.var,
@@ -100,12 +101,12 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
 }

 const rnn::ArgumentName RecurrentOp::kArgName{
-    "step_net", "step_scopes", "inlinks", "outlinks",
-    "memories", "pre_memories", "boot_memories"};
+    "step_net", "step_scopes", "inputs", "outputs",
+    "states",   "ex_states",   "initial_states"};

 const rnn::ArgumentName RecurrentGradientOp::kArgName{
-    "step_net", "step_scopes@GRAD", "outlinks@GRAD", "inlinks@GRAD",
-    "memories", "pre_memories",     "boot_memories@GRAD"};
+    "step_net", "step_scopes@GRAD", "outputs@GRAD", "inputs@GRAD",
+    "states",   "ex_states",        "initial_states@GRAD"};

 RecurrentOp::RecurrentOp(const std::string& type,
                          const framework::VariableNameMap& inputs,
@@ -127,7 +128,7 @@ class RecurrentAlgorithmProtoAndCheckerMaker
     AddInput(name.inlinks,
              "the inputs that need to be segmented for each step.")
         .AsDuplicable();
-    AddInput(name.boot_memories, "variables to initialize memories.")
+    AddInput(name.initial_states, "variables to initialize states.")
         .AsDuplicable();

     AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
@@ -135,9 +136,8 @@ class RecurrentAlgorithmProtoAndCheckerMaker
     AddOutput(name.step_scopes, "step scopes");

     // Attributes stored in AttributeMap
-    AddAttr<std::vector<std::string>>(name.pre_memories,
-                                      "names of pre-memories");
-    AddAttr<std::vector<std::string>>(name.memories, "names of memories");
+    AddAttr<std::vector<std::string>>(name.ex_states, "names of pre-states");
+    AddAttr<std::vector<std::string>>(name.states, "names of states");
     AddComment("This is a recurrent group operator.");
   }
@@ -152,7 +152,7 @@ void RecurrentGradientAlgorithm::Run(
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
   for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
+      rnn::LinkMemories(step_scopes, arg_->states, step_id, 1);
     }
     (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
@@ -162,7 +162,7 @@ void RecurrentGradientAlgorithm::Run(
 void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
     Scope* step_scope) const {
-  for (auto& attr : arg_->memories) {
+  for (auto& attr : arg_->states) {
     PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
                    "memory variable [%s] does not exists", attr.var);
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
paddle/operators/rnn/recurrent_op_utils.cc

@@ -36,7 +36,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
     LoDTensor* input = input_var->GetMutable<LoDTensor>();
     f::DDim dims = input->dims();
     PADDLE_ENFORCE_EQ(static_cast<size_t>(dims[0]), seq_len,
-                      "all the inlinks be the same length");
+                      "all the inputs be the same length");
     f::DDim step_dims = slice_ddim(dims, 1, dims.size());
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_input =
@@ -78,7 +78,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
 }

 void LinkMemories(const std::vector<Scope*>& scopes,
-                  const std::vector<rnn::MemoryAttr>& memories,
+                  const std::vector<rnn::StateAttr>& memories,
                   const size_t step_id, const int offset) {
   PADDLE_ENFORCE_LT(step_id, scopes.size(),
                     "step [%d] is out of range of step scopes' size [%d]",
@@ -106,26 +106,26 @@ void InitArgument(const ArgumentName& name, Argument* arg,
   arg->inlinks = op.Inputs(name.inlinks);
   arg->outlinks = op.Outputs(name.outlinks);

   auto& boot_memories =
-      is_grad ? op.Outputs(name.boot_memories)
-              : op.Inputs(name.boot_memories);
+      is_grad ? op.Outputs(name.initial_states)
+              : op.Inputs(name.initial_states);
   // attributes
-  auto& memories = op.Attr<std::vector<std::string>>(name.memories);
-  auto& pre_memories = op.Attr<std::vector<std::string>>(name.pre_memories);
+  auto& memories = op.Attr<std::vector<std::string>>(name.states);
+  auto& pre_memories = op.Attr<std::vector<std::string>>(name.ex_states);

   PADDLE_ENFORCE(memories.size() == boot_memories.size(),
-                 "the size of memories, boot_memories don't match:%d,%d",
+                 "the size of states, initial_states don't match:%d,%d",
                  memories.size(), boot_memories.size());
   PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(),
-                 "the size of pre_memories, boot_memories don't match:%d,%d",
+                 "the size of ex_states, initial_states don't match:%d,%d",
                  pre_memories.size(), boot_memories.size());
-  PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set");
+  PADDLE_ENFORCE(memories.size() > 0, "more than 1 states should be set");

   for (size_t i = 0; i < memories.size(); ++i) {
-    rnn::MemoryAttr mem_attr;
+    rnn::StateAttr mem_attr;
     mem_attr.var = memories[i];
     mem_attr.pre_var = pre_memories[i];
     mem_attr.boot_var = boot_memories[i];
-    (arg->memories).push_back(mem_attr);
+    (arg->states).push_back(mem_attr);
   }
 }
paddle/operators/rnn/recurrent_op_utils.h (view file @ 07ea9ade)
@@ -31,7 +31,7 @@ using Scope = framework::Scope;
  * boot memories in father scope. Other attributes are copied from Op's proto
  * attributes.
  */
-struct MemoryAttr {
+struct StateAttr {
   // name of current state variable
   std::string var;
   // name of previous step's state variable
@@ -46,7 +46,7 @@ struct Argument {
   std::string step_scopes;
   std::vector<std::string> inlinks;
   std::vector<std::string> outlinks;
-  std::vector<rnn::MemoryAttr> memories;
+  std::vector<rnn::StateAttr> states;
 };
 
 struct ArgumentName {
@@ -54,9 +54,9 @@ struct ArgumentName {
   std::string step_scopes;
   std::string inlinks;
   std::string outlinks;
-  std::string memories;       // the memory name
-  std::string pre_memories;   // the previous memory name
-  std::string boot_memories;  // the boot memory name
+  std::string states;          // the memory name
+  std::string ex_states;       // the previous memory name
+  std::string initial_states;  // the boot memory name
 };
 
 /**
@@ -74,7 +74,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const size_t seq_len, const platform::DeviceContext& ctx);
 
 void LinkMemories(const std::vector<Scope*>& step_scopes,
-                  const std::vector<MemoryAttr>& memories, const size_t step_id,
+                  const std::vector<StateAttr>& memories, const size_t step_id,
                   const int offset);
 
 void InitArgument(const ArgumentName& name, Argument* arg,
...
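Taken together, the hunks in recurrent_op_utils.cc and recurrent_op_utils.h above are one systematic rename of the RNN "memory" vocabulary to "state" vocabulary. As a purely illustrative summary (a Python dict of our own, not code from the repository; the inputs/outputs keywords are the ones the updated Python tests further below pass to the op constructors):

RNN_RENAMES = {
    "MemoryAttr": "StateAttr",          # struct rnn::MemoryAttr -> rnn::StateAttr
    "memories": "states",               # current-step state names
    "pre_memories": "ex_states",        # previous-step state names
    "boot_memories": "initial_states",  # initial states taken from the parent scope
    "inlinks": "inputs",                # constructor keyword in the Python tests
    "outlinks": "outputs",              # constructor keyword in the Python tests
}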
paddle/pybind/pybind.cc (view file @ 07ea9ade)
@@ -413,18 +413,18 @@ All parameter, weight, gradient are variables in Paddle.
             return static_cast<operators::DynamicRecurrentOp*>(
                 rnn_op.release());
           })
-      .def("set_stepnet",
+      .def("set_step_unit",
           [](operators::DynamicRecurrentOp& self, const operators::NetOp& net)
-               -> void { self.SetStepNet(net.Clone()); })
+               -> void { self.rnn.SetStepUnit(net.Clone()); })
       .def("get_state",
           [](operators::DynamicRecurrentOp& self, const std::string& name)
-               -> const TensorArray& { return self.state(name); })
+               -> const TensorArray& { return self.rnn.state(name); })
       .def("get_step_input",
          [](operators::DynamicRecurrentOp& self, const std::string& name)
-               -> const TensorArray& { return self.step_input(name); })
+               -> const TensorArray& { return self.rnn.step_input(name); })
       .def("get_step_output",
          [](operators::DynamicRecurrentOp& self, const std::string& name)
-               -> const TensorArray& { return self.step_output(name); });
+               -> const TensorArray& { return self.rnn.step_output(name); });
 
   // cond_op
   py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
...
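As a usage note, here is a minimal sketch of how the renamed bindings are driven from Python. It is an illustration rather than code from the commit: rnn_op stands for an already-constructed DynamicRecurrentOp, and core is assumed to be paddle.v2.framework.core as imported in the tests below.

import paddle.v2.framework.core as core

step_unit = core.Net.create()             # the step network, built as in the tests below
# ... append the mul/sum/sigmoid step operators here ...
step_unit.complete_add_op(True)
rnn_op.set_step_unit(step_unit)           # formerly set_stepnet
h = rnn_op.get_state("h@state")           # TensorArray of per-step states
xs = rnn_op.get_step_input("x")           # TensorArray of segmented inputs
hs = rnn_op.get_step_output("h@state")    # TensorArray of per-step outputs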
python/paddle/v2/framework/tests/test_dynamic_recurrent_op.py (view file @ 07ea9ade)
@@ -4,6 +4,12 @@ import unittest
 from paddle.v2.framework.op import Operator, DynamicRecurrentOp
 import numpy as np
 
+# for siplicity, just one level LoD
+lod_py = [[0, 4, 7, 9, 10]]
+input_dim = 30
+num_sents = len(lod_py[0]) - 1
+weight_dim = 15
+
 
 def create_tensor(scope, name, shape, np_data):
     tensor = scope.var(name).get_tensor()
@@ -12,6 +18,17 @@ def create_tensor(scope, name, shape, np_data):
     return tensor
 
 
+class PyRNNStep(object):
+    def __init__(self):
+        self.x = np.random.normal(size=(lod_py[0][-1],
+                                        input_dim)).astype("float32")
+        self.W = np.random.normal(size=(input_dim, input_dim)).astype("float32")
+        self.U = np.random.normal(size=(input_dim, input_dim)).astype("float32")
+        self.h_boot = np.random.normal(size=(num_sents,
+                                             input_dim)).astype("float32")
+
+
 class DynamicRecurrentOpTest(unittest.TestCase):
     '''
     Test RNNOp
@@ -23,17 +40,13 @@ class DynamicRecurrentOpTest(unittest.TestCase):
         - U
     vars:
         - x
-    memories:
+    states:
         - h
     outputs:
         - h
     '''
 
-    # for siplicity, just one level LoD
-    lod_py = [[0, 4, 7, 9, 10]]
-    input_dim = 30
-    num_sents = len(lod_py[0]) - 1
-    weight_dim = 15
+    py = PyRNNStep()
 
     def forward(self):
         self.scope = core.Scope()
@@ -42,64 +55,55 @@ class DynamicRecurrentOpTest(unittest.TestCase):
         self.create_step_net()
         ctx = core.DeviceContext.create(core.CPUPlace())
         self.rnnop.run(self.scope, ctx)
-        state = self.rnnop.get_state("h@mem")
+        state = self.rnnop.get_state("h@state")
         print 'state size: ', state.size()
 
         step_inputs = self.rnnop.get_step_input("x")
         print "x size ", step_inputs.size()
         for i in range(step_inputs.size()):
             print "x %d" % i, np.array(step_inputs.read(i).get_dims())
-        step_outputs = self.rnnop.get_step_output('h@mem')
+        step_outputs = self.rnnop.get_step_output('h@state')
         print 'step_outputs.size ', step_outputs.size()
-        output = self.scope.find_var("h@mem").get_tensor()
+        output = self.scope.find_var("h@state").get_tensor()
         print 'output', np.array(output).shape
 
     def create_global_variables(self):
-        x = np.random.normal(size=(self.lod_py[0][-1],
-                                   self.input_dim)).astype("float32")
-        W = np.random.normal(size=(self.input_dim,
-                                   self.input_dim)).astype("float32")
-        U = np.random.normal(size=(self.input_dim,
-                                   self.input_dim)).astype("float32")
-        h_boot = np.random.normal(size=(self.num_sents,
-                                        self.input_dim)).astype("float32")
         # create inlink
         x_tensor = create_tensor(self.scope, "x",
-                                 [self.num_sents, self.input_dim], x)
-        x_tensor.set_lod(self.lod_py)
-        create_tensor(self.scope, "W", [self.input_dim, self.input_dim], W)
-        create_tensor(self.scope, "U", [self.input_dim, self.input_dim], U)
-        create_tensor(self.scope, "h_boot", [self.num_sents, self.input_dim],
-                      h_boot)
+                                 [num_sents, input_dim], self.py.x)
+        x_tensor.set_lod(lod_py)
+        create_tensor(self.scope, "W", [input_dim, input_dim], self.py.W)
+        create_tensor(self.scope, "U", [input_dim, input_dim], self.py.U)
+        create_tensor(self.scope, "h_boot", [num_sents, input_dim],
+                      self.py.h_boot)
         self.scope.var("step_scopes")
-        self.scope.var("h@mem")
+        self.scope.var("h@state")
 
     def create_rnn_op(self):
         # create RNNOp
         self.rnnop = DynamicRecurrentOp(
             # inputs
-            inlinks=["x"],
-            boot_memories=["h_boot"],
-            step_net="stepnet",
+            inputs=["x"],
+            initial_states=["h_boot"],
+            step_net="step_unit",
             # outputs
-            outlinks=["h@mem"],
+            outputs=["h@state"],
             step_scopes="step_scopes",
             # attributes
-            pre_memories=["h@pre"],
-            memories=["h@mem"])
+            ex_states=["h@pre"],
+            states=["h@state"])
 
     def create_step_net(self):
-        stepnet = core.Net.create()
+        step_unit = core.Net.create()
         x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
         h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
         sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
-        sig_op = Operator("sigmoid", X="sum", Y="h@mem")
+        sig_op = Operator("sigmoid", X="sum", Y="h@state")
 
         for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
-            stepnet.append_op(op)
-        stepnet.complete_add_op(True)
-        self.rnnop.set_stepnet(stepnet)
+            step_unit.append_op(op)
+        step_unit.complete_add_op(True)
+        self.rnnop.set_step_unit(step_unit)
 
     def test_forward(self):
         print 'test recurrent op forward'
@@ -107,5 +111,58 @@ class DynamicRecurrentOpTest(unittest.TestCase):
         print 'pd_output', pd_output
 
 
+class RecurrentGradientOpTest(unittest.TestCase):
+    py = PyRNNStep()
+
+    def create_forward_op(self):
+        # create RNNOp
+        self.forward_op = DynamicRecurrentOp(
+            # inputs
+            inputs=["x"],
+            initial_states=["h_boot"],
+            step_net="step_unit",
+            # outputs
+            outputs=["h@state"],
+            step_scopes="step_scopes",
+            # attributes
+            ex_states=["h@pre"],
+            states=["h@state"])
+
+    def create_gradient_op(self):
+        a = set()
+        backward_op = core.DynamicRecurrentOp.backward(self.forward_op, a)
+
+    def create_step_net(self):
+        step_unit = core.Net.create()
+        x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
+        h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
+        sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
+        sig_op = Operator("sigmoid", X="sum", Y="h@state")
+
+        for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
+            step_unit.append_op(op)
+        step_unit.complete_add_op(True)
+        self.forward_op.set_step_unit(step_unit)
+
+    def create_global_variables(self):
+        # create inlink
+        x_tensor = create_tensor(self.scope, "x", [num_sents, input_dim],
+                                 self.py.x)
+        x_tensor.set_lod(lod_py)
+        create_tensor(self.scope, "W", [input_dim, input_dim], self.py.W)
+        create_tensor(self.scope, "U", [input_dim, input_dim], self.py.U)
+        create_tensor(self.scope, "h_boot", [num_sents, input_dim],
+                      self.py.h_boot)
+        self.scope.var("step_scopes")
+        self.scope.var("h@state")
+
+    def test_grad(self):
+        self.scope = core.Scope()
+        self.create_forward_op()
+        self.create_global_variables()
+        self.create_step_net()
+        self.create_gradient_op()
+
 
 if __name__ == '__main__':
     unittest.main()
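The new RecurrentGradientOpTest only constructs the backward operator; it does not yet check gradients numerically. A condensed sketch of that step, with the forward op configured as in create_forward_op above (the core import and the no_grad_vars name are our additions for illustration; the test itself simply passes an empty set):

import paddle.v2.framework.core as core
from paddle.v2.framework.op import DynamicRecurrentOp

forward_op = DynamicRecurrentOp(
    inputs=["x"], initial_states=["h_boot"], step_net="step_unit",
    outputs=["h@state"], step_scopes="step_scopes",
    ex_states=["h@pre"], states=["h@state"])
no_grad_vars = set()  # variable names excluded from the gradient; empty here
backward_op = core.DynamicRecurrentOp.backward(forward_op, no_grad_vars)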
python/paddle/v2/framework/tests/test_recurrent_op.py (view file @ 07ea9ade)
@@ -132,15 +132,15 @@ class RecurrentOpTest(unittest.TestCase):
         # create RNNOp
         self.rnnop = RecurrentOp(
             # inputs
-            inlinks=["x"],
-            boot_memories=["h_boot"],
+            inputs=["x"],
+            initial_states=["h_boot"],
             step_net="stepnet",
             # outputs
-            outlinks=["h@mem"],
+            outputs=["h@mem"],
             step_scopes="step_scopes",
             # attributes
-            pre_memories=["h@pre"],
-            memories=["h@mem"])
+            ex_states=["h@pre"],
+            states=["h@mem"])
 
     def create_step_net(self):
         stepnet = core.Net.create()
@@ -169,15 +169,15 @@ class RecurrentGradientOpTest(unittest.TestCase):
     def create_forward_op(self):
         self.forward_op = RecurrentOp(
             # inputs
-            inlinks=["x"],
-            boot_memories=["h_boot"],
+            inputs=["x"],
+            initial_states=["h_boot"],
             step_net="stepnet",
             # outputs
-            outlinks=["h"],
+            outputs=["h"],
             step_scopes="step_scopes",
             # attributes
-            pre_memories=["h@pre"],
-            memories=["h@alias"])
+            ex_states=["h@pre"],
+            states=["h@alias"])
 
         # create a stepnet for RNN
         stepnet = core.Net.create()
...