Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
843ed8e3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
843ed8e3
编写于
10月 10, 2017
作者:
Y
Yan Chunwei
提交者:
GitHub
10月 10, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dynamic recurrent op forward c++ implentation (#4597)
上级
7506e481
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
666 addition
and
5 deletion
+666
-5
cmake/configure.cmake
cmake/configure.cmake
+4
-0
paddle/framework/operator.h
paddle/framework/operator.h
+3
-3
paddle/framework/tensor_array.h
paddle/framework/tensor_array.h
+2
-2
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+1
-0
paddle/operators/dynamic_recurrent_op.cc
paddle/operators/dynamic_recurrent_op.cc
+276
-0
paddle/operators/dynamic_recurrent_op.h
paddle/operators/dynamic_recurrent_op.h
+158
-0
paddle/operators/dynamic_recurrent_op_test.cc
paddle/operators/dynamic_recurrent_op_test.cc
+222
-0
未找到文件。
cmake/configure.cmake
浏览文件 @
843ed8e3
...
...
@@ -24,6 +24,10 @@ if(WITH_DOUBLE)
add_definitions
(
-DPADDLE_TYPE_DOUBLE
)
endif
(
WITH_DOUBLE
)
if
(
WITH_TESTING
)
add_definitions
(
-DPADDLE_WITH_TESTING
)
endif
(
WITH_TESTING
)
if
(
NOT WITH_TIMER
)
add_definitions
(
-DPADDLE_DISABLE_TIMER
)
endif
(
NOT WITH_TIMER
)
...
...
paddle/framework/operator.h
浏览文件 @
843ed8e3
...
...
@@ -142,9 +142,9 @@ class OperatorBase {
// Macro for define a clone method.
// If you are writing an kernel operator, `Clone` will be defined when you
// register it. i.e. `Clone` method is not needed to define by yourself.
#define DEFINE_OP_CLONE_METHOD(cls) \
std::unique_ptr<OperatorBase> Clone() const final { \
return std::unique_ptr<OperatorBase>(new cls(*this)); \
#define DEFINE_OP_CLONE_METHOD(cls)
\
std::unique_ptr<
::paddle::framework::
OperatorBase> Clone() const final { \
return std::unique_ptr<
::paddle::framework::
OperatorBase>(new cls(*this)); \
}
// Macro for define a default constructor for Operator.
...
...
paddle/framework/tensor_array.h
浏览文件 @
843ed8e3
...
...
@@ -87,12 +87,12 @@ class TensorArray {
LoDTensor
Stack
()
const
;
/*
* Un
p
acks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
* Un
st
acks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
*/
void
Unstack
(
const
LoDTensor
&
source
)
const
;
/*
* Un
p
acks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
* Un
st
acks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
* with memory of tensors shared.
*/
void
UnstackShared
(
const
LoDTensor
&
source
)
const
;
...
...
paddle/operators/CMakeLists.txt
浏览文件 @
843ed8e3
...
...
@@ -133,3 +133,4 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net_op
)
cc_test
(
scatter_test SRCS scatter_test.cc DEPS tensor
)
cc_test
(
strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory
)
cc_test
(
dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array
)
paddle/operators/dynamic_recurrent_op.cc
0 → 100644
浏览文件 @
843ed8e3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve .
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Scope
;
using
framework
::
TensorArray
;
using
framework
::
LoDTensor
;
using
framework
::
Variable
;
namespace
detail
{
inline
void
CreateVariables
(
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
var_names
)
{
for
(
const
auto
&
name
:
var_names
)
{
scope
.
NewVar
(
name
);
}
}
}
// namespace detail
class
DynamicRecurrentOpProtoAndCheckerMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
DynamicRecurrentOpProtoAndCheckerMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
const
auto
&
name
=
DynamicRecurrentOp
::
kArgName
;
// inputs and outputs stored in proto
AddInput
(
name
.
inlinks
,
"the inputs that need to be segmented for each step."
)
.
AsDuplicable
();
AddInput
(
name
.
boot_memories
,
"variables to initialize memories."
)
.
AsDuplicable
();
AddOutput
(
name
.
outlinks
,
"the outputs that need to concated for all steps."
)
.
AsDuplicable
();
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
// Attributes stored in AttributeMap
AddAttr
<
std
::
vector
<
std
::
string
>>
(
name
.
pre_memories
,
"names of pre-memories"
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
name
.
memories
,
"names of memories"
);
AddComment
(
"This is a RNN operator for varience-length sequences."
);
}
};
void
DynamicRecurrentOp
::
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
cache_
.
Init
(
kArgName
,
*
this
,
scope
,
&
arg_
);
SplitInputs
();
CreateScopes
();
WriteStepInputs
();
InitStates
();
// call stepnet in all the time steps
for
(
size_t
step
=
0
;
step
<
cache_
.
num_steps
;
step
++
)
{
auto
&
step_scope
=
cache_
.
GetScope
(
step
);
stepnet_
->
Run
(
step_scope
,
dev_ctx
);
}
WriteStepOutputs
();
ConcatOutputs
();
}
void
DynamicRecurrentOp
::
SplitInputs
()
const
{
// TODO(superjom) make level a config
// TODO(superjom) check all the inputs has the same LoD
int
level
=
0
;
const
auto
&
inlinks
=
cache_
.
inlinks
;
for
(
const
auto
&
item
:
inlinks
)
{
const
auto
&
var
=
item
.
second
;
const
auto
&
tensor
=
var
->
Get
<
LoDTensor
>
();
TensorArray
&
ta
=
step_inputs_
[
item
.
first
];
dy_seq_metas_
[
item
.
first
]
=
ta
.
Unpack
(
tensor
,
level
,
true
/*length_descend*/
);
if
(
cache_
.
num_steps
)
{
PADDLE_ENFORCE_EQ
(
ta
.
size
(),
cache_
.
num_steps
,
"inputs should have the same steps"
);
}
else
{
cache_
.
num_steps
=
ta
.
size
();
}
}
}
void
DynamicRecurrentOp
::
WriteStepInputs
()
const
{
for
(
const
auto
&
item
:
cache_
.
inlinks
)
{
auto
ta_it
=
step_inputs_
.
find
(
item
.
first
);
PADDLE_ENFORCE
(
ta_it
!=
step_inputs_
.
end
(),
"step_inputs_ not compatible with memory set"
);
TensorArray
&
ta
=
ta_it
->
second
;
for
(
size_t
step
=
0
;
step
<
ta
.
size
();
step
++
)
{
auto
tensor
=
ta
.
Read
(
step
);
auto
&
step_scope
=
cache_
.
GetScope
(
step
);
Variable
*
var
=
step_scope
.
FindVar
(
item
.
first
);
if
(
var
==
nullptr
)
{
var
=
step_scope
.
NewVar
(
item
.
first
);
}
var
->
GetMutable
<
LoDTensor
>
()
->
ShareDataWith
<
value_type
>
(
tensor
);
}
}
}
void
DynamicRecurrentOp
::
WriteStepOutputs
()
const
{
for
(
size_t
step
=
0
;
step
<
cache_
.
scopes
->
size
();
step
++
)
{
auto
&
scope
=
cache_
.
GetScope
(
step
);
for
(
auto
&
item
:
step_outputs_
)
{
auto
*
var
=
scope
.
FindVar
(
item
.
first
);
if
(
var
==
nullptr
)
{
var
=
scope
.
NewVar
(
item
.
first
);
}
auto
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
item
.
second
.
WriteShared
(
step
,
*
tensor
);
}
}
}
void
DynamicRecurrentOp
::
CreateScopes
()
const
{
PADDLE_ENFORCE_GT
(
cache_
.
num_steps
,
0
);
// resize scopes
size_t
num_scopes_need_create
=
cache_
.
num_steps
-
cache_
.
scopes
->
size
();
for
(
size_t
i
=
0
;
i
<
num_scopes_need_create
;
i
++
)
{
cache_
.
scopes
->
emplace_back
(
&
cache_
.
scope
->
NewScope
());
}
// init temporary inputs
PADDLE_ENFORCE_NOT_NULL
(
stepnet_
,
"stepnet should be set first"
);
std
::
vector
<
std
::
string
>
memories
;
std
::
vector
<
std
::
string
>
pre_memories
;
std
::
transform
(
arg_
.
memories
.
begin
(),
arg_
.
memories
.
end
(),
std
::
back_inserter
(
memories
),
[](
const
rnn
::
MemoryAttr
&
m
)
{
return
m
.
var
;
});
std
::
transform
(
arg_
.
memories
.
begin
(),
arg_
.
memories
.
end
(),
std
::
back_inserter
(
pre_memories
),
[](
const
rnn
::
MemoryAttr
&
m
)
{
return
m
.
pre_var
;
});
for
(
size_t
step
=
0
;
step
<
cache_
.
num_steps
;
step
++
)
{
auto
&
scope
=
cache_
.
GetScope
(
step
);
detail
::
CreateVariables
(
scope
,
arg_
.
inlinks
);
detail
::
CreateVariables
(
scope
,
arg_
.
outlinks
);
detail
::
CreateVariables
(
scope
,
memories
);
detail
::
CreateVariables
(
scope
,
pre_memories
);
}
}
void
DynamicRecurrentOp
::
ConcatOutputs
()
const
{
// TODO(superjom) transform this to a config
int
level
=
0
;
// TODO(superjom) pass in some lod
// just a placeholder
framework
::
LoD
lod
;
for
(
auto
&
item
:
step_outputs_
)
{
auto
tensor
=
item
.
second
.
Pack
(
level
,
dy_seq_metas_
[
item
.
first
],
lod
);
auto
&
output
=
cache_
.
outlinks
[
item
.
first
]
->
Get
<
LoDTensor
>
();
const_cast
<
LoDTensor
*>
(
&
output
)
->
ShareDataWith
<
value_type
>
(
tensor
);
}
}
void
DynamicRecurrentOp
::
InitStates
()
const
{
// init the first state
// TODO(superjom) parepare the scenerio that boot state not exists
for
(
auto
memory
:
arg_
.
memories
)
{
auto
*
boot_state_var
=
cache_
.
scope
->
FindVar
(
memory
.
boot_var
);
PADDLE_ENFORCE_NOT_NULL
(
boot_state_var
);
auto
&
boot_state
=
boot_state_var
->
Get
<
LoDTensor
>
();
const
auto
&
dims
=
boot_state
.
dims
();
for
(
size_t
step
=
0
;
step
<
cache_
.
num_steps
;
step
++
)
{
auto
&
cur_scope
=
cache_
.
GetScope
(
step
);
// link pre-state to boot_state
// init state and pre-state
auto
*
pre_state
=
cur_scope
.
FindVar
(
memory
.
pre_var
);
PADDLE_ENFORCE_NOT_NULL
(
pre_state
);
pre_state
->
GetMutable
<
LoDTensor
>
();
auto
*
state
=
cur_scope
.
FindVar
(
memory
.
var
);
PADDLE_ENFORCE_NOT_NULL
(
state
);
state
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
dims
);
state
->
GetMutable
<
LoDTensor
>
()
->
mutable_data
<
value_type
>
(
platform
::
CPUPlace
());
if
(
step
==
0
)
{
auto
*
pre_state_tensor
=
pre_state
->
GetMutable
<
LoDTensor
>
();
pre_state_tensor
->
Resize
(
boot_state
.
dims
());
pre_state_tensor
->
ShareDataWith
<
value_type
>
(
boot_state
);
}
else
{
auto
&
pre_scope
=
cache_
.
GetScope
(
step
-
1
);
auto
*
state_pre
=
pre_scope
.
FindVar
(
memory
.
var
);
PADDLE_ENFORCE_NOT_NULL
(
state_pre
);
pre_state
->
GetMutable
<
LoDTensor
>
()
->
ShareDataWith
<
value_type
>
(
*
state_pre
->
GetMutable
<
LoDTensor
>
());
}
}
}
}
void
DynamicRecurrentOp
::
ArgCache
::
Init
(
const
rnn
::
ArgumentName
&
name
,
const
paddle
::
framework
::
OperatorBase
&
op
,
const
paddle
::
framework
::
Scope
&
scope
,
rnn
::
Argument
*
arg
)
{
this
->
scope
=
&
scope
;
InitArgument
(
name
,
op
,
arg
);
CacheScopes
(
scope
,
*
arg
);
CacheInlinks
(
scope
,
arg
->
inlinks
);
CacheOutlinks
(
scope
,
arg
->
outlinks
);
}
void
DynamicRecurrentOp
::
ArgCache
::
InitArgument
(
const
rnn
::
ArgumentName
&
name
,
const
OperatorBase
&
op
,
rnn
::
Argument
*
arg
)
{
rnn
::
InitArgument
(
name
,
arg
,
op
,
false
/*is_grad*/
);
}
void
DynamicRecurrentOp
::
ArgCache
::
CacheScopes
(
const
Scope
&
scope
,
const
rnn
::
Argument
&
arg
)
{
auto
scopes_var
=
scope
.
FindVar
(
arg
.
step_scopes
);
PADDLE_ENFORCE
(
scopes_var
!=
nullptr
,
"the step_scopes output argument [%s] should be created first "
"by framework."
,
arg
.
step_scopes
);
this
->
scopes
=
scopes_var
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
}
void
DynamicRecurrentOp
::
ArgCache
::
CacheInlinks
(
const
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
names
)
{
for
(
auto
name
:
names
)
{
auto
*
var
=
GetVariable
(
scope
,
name
);
inlinks
[
name
]
=
var
;
}
}
void
DynamicRecurrentOp
::
ArgCache
::
CacheOutlinks
(
const
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
names
)
{
for
(
auto
name
:
names
)
{
auto
*
var
=
GetVariable
(
scope
,
name
);
outlinks
[
name
]
=
var
;
}
}
Variable
*
DynamicRecurrentOp
::
ArgCache
::
GetVariable
(
const
Scope
&
scope
,
const
std
::
string
&
name
)
{
auto
*
var
=
scope
.
FindVar
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"variable [%s] not exist in scope"
,
name
);
return
var
;
}
const
rnn
::
ArgumentName
DynamicRecurrentOp
::
kArgName
{
"step_net"
,
"step_scopes"
,
"inlinks"
,
"outlinks"
,
"memories"
,
"pre_memories"
,
"boot_memories"
};
void
DynamicRecurrentGradientOp
::
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{}
}
// namespace operators
}
// namespace paddle
REGISTER_OP_WITHOUT_GRADIENT
(
dynamic_recurrent
,
paddle
::
operators
::
DynamicRecurrentOp
,
paddle
::
operators
::
DynamicRecurrentOpProtoAndCheckerMaker
);
paddle/operators/dynamic_recurrent_op.h
0 → 100644
浏览文件 @
843ed8e3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_TESTING
#include "gtest/gtest.h"
#endif
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor_array.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"
namespace
paddle
{
namespace
operators
{
class
DynamicRecurrentOp
:
public
framework
::
OperatorBase
{
public:
static
const
rnn
::
ArgumentName
kArgName
;
using
value_type
=
float
;
DynamicRecurrentOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
DynamicRecurrentOp
(
const
DynamicRecurrentOp
&
o
)
:
framework
::
OperatorBase
(
static_cast
<
const
framework
::
OperatorBase
&>
(
o
))
{
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW
(
"Not implemented"
);
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
;
/*
* Split the inputs(LoDTensors) to segments for each time step.
*/
void
SplitInputs
()
const
;
/*
* Create step-scopes to store temporary outputs in each time steps.
*/
void
CreateScopes
()
const
;
/*
* Link TensorArray steps to the corresponding variables located in
* step-scopes.
*/
void
WriteStepInputs
()
const
;
/*
* Write output of each step to the corresponding TensorArray.
*/
void
WriteStepOutputs
()
const
;
/*
* Initialize the states, each state will have a corresponding pre-state,
* which share the memory with the state in the previous time state. The
* pre-state in the first time step will be initialized with an zero tensor or
* a tensor in parent scope if is provided.
*/
void
InitStates
()
const
;
/*
* Concatenate outputs in each time step and generate a LoDTensor.
*/
void
ConcatOutputs
()
const
;
/*
* set a stepnet that is created according to a RecurrentOp's stepnet.
*/
void
SetStepNet
(
std
::
unique_ptr
<
OperatorBase
>
net
)
{
PADDLE_ENFORCE_NOT_NULL
(
net
);
stepnet_
=
std
::
move
(
net
);
}
const
OperatorBase
&
GetStepNet
()
const
{
return
*
stepnet_
;
}
protected:
struct
ArgCache
{
framework
::
Scope
const
*
scope
;
std
::
vector
<
framework
::
Scope
*>*
scopes
;
std
::
map
<
std
::
string
,
framework
::
Variable
*>
inlinks
;
std
::
map
<
std
::
string
,
framework
::
Variable
*>
outlinks
;
size_t
num_steps
{
0
};
void
Init
(
const
rnn
::
ArgumentName
&
name
,
const
OperatorBase
&
op
,
const
framework
::
Scope
&
scope
,
rnn
::
Argument
*
arg
);
framework
::
Scope
&
GetScope
(
size_t
index
)
{
PADDLE_ENFORCE_LT
(
index
,
num_steps
);
return
*
scopes
->
at
(
index
);
}
private:
void
InitArgument
(
const
rnn
::
ArgumentName
&
name
,
const
OperatorBase
&
op
,
rnn
::
Argument
*
arg
);
void
CacheScopes
(
const
framework
::
Scope
&
scope
,
const
rnn
::
Argument
&
arg
);
void
CacheInlinks
(
const
framework
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
names
);
void
CacheOutlinks
(
const
framework
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
names
);
framework
::
Variable
*
GetVariable
(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
name
);
};
private:
std
::
unique_ptr
<
OperatorBase
>
stepnet_
;
mutable
framework
::
TensorArray
states_
;
mutable
std
::
map
<
std
::
string
,
framework
::
TensorArray
>
step_inputs_
;
mutable
std
::
map
<
std
::
string
,
framework
::
TensorArray
>
step_outputs_
;
mutable
std
::
map
<
std
::
string
,
std
::
vector
<
framework
::
DySeqMeta
>>
dy_seq_metas_
;
mutable
rnn
::
Argument
arg_
;
mutable
ArgCache
cache_
;
#ifdef PADDLE_WITH_TESTING
friend
class
DynamicRecurrentOpTestHelper
;
FRIEND_TEST
(
DynamicRecurrentOpTestHelper
,
SplitInputs
);
FRIEND_TEST
(
DynamicRecurrentOpTestHelper
,
CreateCache
);
FRIEND_TEST
(
DynamicRecurrentOpTestHelper
,
CreateScopes
);
FRIEND_TEST
(
DynamicRecurrentOpTestHelper
,
WriteStepInputs
);
FRIEND_TEST
(
DynamicRecurrentOpTestHelper
,
WriteStepOutputs
);
FRIEND_TEST
(
DynamicRecurrentOpTestHelper
,
InitStates
);
FRIEND_TEST
(
DynamicRecurrentOpTestHelper
,
ConcatOutputs
);
#endif
};
class
DynamicRecurrentGradientOp
:
public
framework
::
OperatorBase
{
public:
DynamicRecurrentGradientOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
;
};
}
// namespace operators
}
// namespace paddle
paddle/operators/dynamic_recurrent_op_test.cc
0 → 100644
浏览文件 @
843ed8e3
#include "paddle/operators/dynamic_recurrent_op.h"
#include <gtest/gtest.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Scope
;
using
framework
::
TensorArray
;
using
framework
::
LoDTensor
;
using
framework
::
Variable
;
class
TestOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
DEFINE_OP_CLONE_METHOD
(
TestOp
);
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
};
void
OpDescNewVar
(
const
std
::
string
&
param_name
,
std
::
initializer_list
<
const
char
*>
arguments
,
paddle
::
framework
::
OpDesc
::
Var
*
var
)
{
var
->
set_parameter
(
param_name
);
for
(
auto
&
arg_name
:
arguments
)
{
var
->
add_arguments
(
arg_name
);
}
}
// create a LoD tensor in scope with specific dims
LoDTensor
*
CreateVar
(
Scope
&
scope
,
std
::
string
name
,
framework
::
DDim
dims
,
const
platform
::
Place
&
place
)
{
auto
*
var
=
scope
.
NewVar
(
name
);
auto
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
tensor
->
Resize
(
dims
);
tensor
->
mutable_data
<
float
>
(
place
);
return
tensor
;
}
class
DynamicRecurrentOpTestHelper
:
public
::
testing
::
Test
{
protected:
const
rnn
::
ArgumentName
argname
=
DynamicRecurrentOp
::
kArgName
;
virtual
void
SetUp
()
override
{
CreateGlobalVariables
();
auto
op_desc
=
CreateOpDesc
();
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
dop
=
dynamic_cast
<
DynamicRecurrentOp
*>
(
op
.
get
());
InitCacheManually
();
InitStepNet
();
}
framework
::
OpDesc
CreateOpDesc
()
{
// create op
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"dynamic_recurrent"
);
OpDescNewVar
(
argname
.
inlinks
,
{
"in0"
},
op_desc
.
add_inputs
());
OpDescNewVar
(
argname
.
boot_memories
,
{
"boot_mem"
},
op_desc
.
add_inputs
());
OpDescNewVar
(
argname
.
step_scopes
,
{
"step_scopes"
},
op_desc
.
add_outputs
());
OpDescNewVar
(
argname
.
outlinks
,
{
"out0"
},
op_desc
.
add_outputs
());
// set pre-memories
auto
pre_memories
=
op_desc
.
mutable_attrs
()
->
Add
();
pre_memories
->
set_name
(
argname
.
pre_memories
);
pre_memories
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
auto
pre_memories_item
=
pre_memories
->
add_strings
();
*
pre_memories_item
=
"mem@pre"
;
// set memories
auto
memories
=
op_desc
.
mutable_attrs
()
->
Add
();
memories
->
set_name
(
argname
.
memories
);
memories
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
auto
memories_item
=
memories
->
add_strings
();
*
memories_item
=
"mem"
;
return
op_desc
;
}
void
CreateGlobalVariables
()
{
platform
::
CPUPlace
place
;
scope
.
NewVar
(
"step_scopes"
);
CreateVar
(
scope
,
"boot_mem"
,
framework
::
make_ddim
({
10
,
20
}),
place
);
// auto* out0 =
CreateVar
(
scope
,
"out0"
,
framework
::
make_ddim
({
10
,
20
}),
place
);
auto
*
in0
=
CreateVar
(
scope
,
"in0"
,
framework
::
make_ddim
({
10
,
8
}),
place
);
// 10 instanes with 4 sentences, length is 4, 3, 2, 1 respectively.
framework
::
LoD
in0_lod
(
1
);
for
(
int
x
:
std
::
vector
<
int
>
{
0
,
4
,
7
,
9
,
10
})
{
in0_lod
[
0
].
push_back
(
x
);
}
in0
->
set_lod
(
in0_lod
);
in0
->
Resize
(
framework
::
make_ddim
({
10
,
8
}));
// set the content, each sentence content is seqid.batchid
// the seqid starts from 0
int
start
=
0
;
for
(
size_t
seqid
=
0
;
seqid
<
in0_lod
.
size
()
-
1
;
seqid
++
)
{
for
(
size_t
batchid
=
0
;
batchid
<
in0_lod
[
0
][
seqid
+
1
]
-
in0_lod
[
0
][
seqid
];
batchid
++
)
{
float
v
=
seqid
+
batchid
*
0.1
;
for
(
size_t
dim
=
0
;
dim
<
8
;
dim
++
)
{
in0
->
data
<
float
>
()[
start
*
8
+
dim
]
=
v
;
}
start
++
;
}
}
}
void
InitCacheManually
()
{
dop
->
cache_
.
Init
(
DynamicRecurrentOp
::
kArgName
,
*
dop
,
scope
,
&
dop
->
arg_
);
}
void
InitStepNet
()
{
std
::
unique_ptr
<
framework
::
OperatorBase
>
stepnet
{
new
NetOp
};
dynamic_cast
<
NetOp
*>
(
stepnet
.
get
())
->
AppendOp
(
std
::
unique_ptr
<
TestOp
>
(
new
TestOp
(
"test"
,
{{
"inlinks"
,
{
"in0"
}},
{
"boot_memories"
,
{
"boot_mem"
}}},
{{
"outlinks"
,
{
"out0"
}},
{
"step_scopes"
,
{
"step_scopes"
}}},
{})));
dop
->
SetStepNet
(
std
::
move
(
stepnet
));
}
protected:
DynamicRecurrentOp
*
dop
;
std
::
unique_ptr
<
framework
::
OperatorBase
>
op
;
paddle
::
platform
::
CPUDeviceContext
device_context
;
paddle
::
framework
::
Scope
scope
;
};
TEST_F
(
DynamicRecurrentOpTestHelper
,
CreateCache
)
{
const
rnn
::
Argument
&
arg
=
dop
->
arg_
;
ASSERT_EQ
(
arg
.
inlinks
.
size
(),
1UL
);
ASSERT_EQ
(
arg
.
outlinks
.
size
(),
1UL
);
}
TEST_F
(
DynamicRecurrentOpTestHelper
,
SplitInputs
)
{
dop
->
SplitInputs
();
auto
&
in0_ta
=
dop
->
step_inputs_
[
"in0"
];
ASSERT_EQ
(
in0_ta
.
size
(),
4UL
);
const
auto
&
batch0
=
in0_ta
.
Read
(
0
);
const
auto
&
batch1
=
in0_ta
.
Read
(
1
);
const
auto
&
batch2
=
in0_ta
.
Read
(
2
);
const
auto
&
batch3
=
in0_ta
.
Read
(
3
);
EXPECT_EQ
(
batch0
.
dims
()[
0
],
4
);
EXPECT_EQ
(
batch1
.
dims
()[
0
],
3
);
EXPECT_EQ
(
batch2
.
dims
()[
0
],
2
);
EXPECT_EQ
(
batch3
.
dims
()[
0
],
1
);
}
TEST_F
(
DynamicRecurrentOpTestHelper
,
CreateScopes
)
{
dop
->
SplitInputs
();
dop
->
CreateScopes
();
ASSERT_EQ
(
dop
->
cache_
.
num_steps
,
4UL
);
ASSERT_EQ
(
dop
->
cache_
.
scopes
->
size
(),
4UL
);
}
TEST_F
(
DynamicRecurrentOpTestHelper
,
WriteStepInputs
)
{
dop
->
SplitInputs
();
dop
->
CreateScopes
();
dop
->
WriteStepInputs
();
for
(
size_t
step
=
0
;
step
<
dop
->
cache_
.
num_steps
;
step
++
)
{
auto
&
scope
=
dop
->
cache_
.
GetScope
(
step
);
for
(
auto
name
:
std
::
vector
<
std
::
string
>
({
"in0"
}))
{
ASSERT_TRUE
(
scope
.
FindVar
(
name
)
!=
nullptr
);
}
}
}
TEST_F
(
DynamicRecurrentOpTestHelper
,
WriteStepOutputs
)
{
dop
->
SplitInputs
();
dop
->
CreateScopes
();
dop
->
WriteStepInputs
();
dop
->
WriteStepOutputs
();
for
(
size_t
step
=
0
;
step
<
dop
->
cache_
.
num_steps
;
step
++
)
{
auto
&
scope
=
dop
->
cache_
.
GetScope
(
step
);
for
(
auto
name
:
std
::
vector
<
std
::
string
>
({
"out0"
}))
{
ASSERT_TRUE
(
scope
.
FindVar
(
name
));
}
}
}
TEST_F
(
DynamicRecurrentOpTestHelper
,
ConcatOutputs
)
{
// Let's leave this test to python unittest.
}
TEST_F
(
DynamicRecurrentOpTestHelper
,
InitStates
)
{
dop
->
SplitInputs
();
dop
->
CreateScopes
();
dop
->
WriteStepInputs
();
dop
->
WriteStepOutputs
();
dop
->
InitStates
();
for
(
size_t
step
=
0
;
step
<
dop
->
cache_
.
num_steps
;
step
++
)
{
auto
&
scope
=
dop
->
cache_
.
GetScope
(
step
);
auto
state
=
scope
.
FindVar
(
"mem"
);
ASSERT_TRUE
(
state
!=
nullptr
);
auto
*
pre_state
=
scope
.
FindVar
(
"mem@pre"
);
ASSERT_TRUE
(
pre_state
!=
nullptr
);
auto
*
boot_state
=
scope
.
FindVar
(
"boot_mem"
);
ASSERT_TRUE
(
boot_state
!=
nullptr
);
if
(
step
==
0
)
{
// check pre_state is a reference of boot_state
ASSERT_EQ
(
boot_state
->
Get
<
LoDTensor
>
().
data
<
float
>
(),
pre_state
->
Get
<
LoDTensor
>
().
data
<
float
>
());
}
}
}
}
// operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录