Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
301a21d8
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
301a21d8
编写于
8月 04, 2017
作者:
Y
Yi Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cpplint recurrent_op*
上级
5ae7a5f1
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
43 addition
and
28 deletion
+43
-28
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+7
-7
paddle/operators/recurrent_op.h
paddle/operators/recurrent_op.h
+31
-20
paddle/operators/recurrent_op_test.cc
paddle/operators/recurrent_op_test.cc
+5
-1
未找到文件。
paddle/operators/recurrent_op.cc
浏览文件 @
301a21d8
...
...
@@ -38,10 +38,10 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
"input link [%s] is not in scope."
,
inlinks
[
i
].
external
);
Tensor
*
input
=
input_var
->
GetMutable
<
Tensor
>
();
DDim
dims
=
input
->
dims
();
framework
::
DDim
dims
=
input
->
dims
();
PADDLE_ENFORCE
(
static_cast
<
size_t
>
(
dims
[
0
])
==
seq_len
,
"all the inlinks must have same length"
);
DDim
step_dims
=
slice_ddim
(
dims
,
1
,
dims
.
size
());
framework
::
DDim
step_dims
=
slice_ddim
(
dims
,
1
,
dims
.
size
());
for
(
size_t
j
=
0
;
j
<
seq_len
;
j
++
)
{
Tensor
*
step_input
=
step_scopes
[
j
]
->
NewVar
(
inlinks
[
i
].
internal
)
->
GetMutable
<
Tensor
>
();
...
...
@@ -64,13 +64,13 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
outlinks
[
i
].
external
);
Tensor
*
output
=
output_var
->
GetMutable
<
Tensor
>
();
if
(
infer_shape_mode
)
{
DDim
step_dims
=
step_scopes
[
0
]
framework
::
DDim
step_dims
=
step_scopes
[
0
]
->
FindVar
(
outlinks
[
i
].
internal
)
->
GetMutable
<
Tensor
>
()
->
dims
();
std
::
vector
<
int
>
dims_vec
=
vectorize
(
step_dims
);
dims_vec
.
insert
(
dims_vec
.
begin
(),
seq_len
);
output
->
Resize
(
make_ddim
(
dims_vec
));
output
->
Resize
(
framework
::
make_ddim
(
dims_vec
));
}
else
{
output
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
size_t
j
=
0
;
j
<
seq_len
;
j
++
)
{
...
...
paddle/operators/recurrent_op.h
浏览文件 @
301a21d8
...
...
@@ -68,7 +68,7 @@ struct ArgumentName {
/**
* Prepare inputs for each step net.
*/
void
SegmentInputs
(
const
std
::
vector
<
Scope
*>&
step_scopes
,
void
SegmentInputs
(
const
std
::
vector
<
framework
::
Scope
*>&
step_scopes
,
const
std
::
vector
<
Link
>&
inlinks
,
const
size_t
seq_len
,
bool
infer_shape_mode
);
...
...
@@ -76,12 +76,12 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
/**
* Process outputs of step nets and merge to variables.
*/
void
ConcatOutputs
(
const
std
::
vector
<
Scope
*>&
step_scopes
,
void
ConcatOutputs
(
const
std
::
vector
<
framework
::
Scope
*>&
step_scopes
,
const
std
::
vector
<
Link
>&
outlinks
,
const
size_t
seq_len
,
bool
infer_shape_mode
);
void
LinkMemories
(
const
std
::
vector
<
Scope
*>&
step_scopes
,
void
LinkMemories
(
const
std
::
vector
<
framework
::
Scope
*>&
step_scopes
,
const
std
::
vector
<
MemoryAttr
>&
memories
,
const
size_t
step_id
,
const
int
offset
,
...
...
@@ -101,14 +101,15 @@ void InitArgument(const ArgumentName& name, Argument* arg);
class
RecurrentAlgorithm
{
public:
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
;
void
Init
(
std
::
unique_ptr
<
rnn
::
Argument
>
arg
)
{
arg_
=
std
::
move
(
arg
);
}
/**
* InferShape must be called before Run.
*/
void
InferShape
(
const
Scope
&
scope
)
const
;
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
;
protected:
/*
...
...
@@ -117,13 +118,15 @@ protected:
* NOTE the scopes are reused in both the forward and backward, so just
* create once and expand its size if more steps need.
*/
void
CreateScopes
(
const
Scope
&
scope
)
const
;
void
CreateScopes
(
const
framework
::
Scope
&
scope
)
const
;
const
std
::
vector
<
Scope
*>&
GetStepScopes
(
const
Scope
&
scope
)
const
{
return
*
scope
.
FindVar
(
arg_
->
step_scopes
)
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
const
std
::
vector
<
framework
::
Scope
*>&
GetStepScopes
(
const
framework
::
Scope
&
scope
)
const
{
return
*
scope
.
FindVar
(
arg_
->
step_scopes
)
->
GetMutable
<
std
::
vector
<
framework
::
Scope
*>>
();
}
void
InitMemories
(
Scope
*
step_scopes
,
bool
infer_shape_mode
)
const
;
void
InitMemories
(
framework
::
Scope
*
step_scopes
,
bool
infer_shape_mode
)
const
;
private:
std
::
unique_ptr
<
rnn
::
Argument
>
arg_
;
...
...
@@ -144,18 +147,22 @@ class RecurrentGradientAlgorithm {
public:
void
Init
(
std
::
unique_ptr
<
rnn
::
Argument
>
arg
)
{
arg_
=
std
::
move
(
arg
);
}
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
;
void
LinkBootMemoryGradients
(
Scope
*
step_scopes
,
bool
infer_shape_mode
)
const
;
void
LinkBootMemoryGradients
(
framework
::
Scope
*
step_scopes
,
bool
infer_shape_mode
)
const
;
/**
* InferShape must be called before Run.
*/
void
InferShape
(
const
Scope
&
scope
)
const
;
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
;
protected:
inline
const
std
::
vector
<
Scope
*>&
GetStepScopes
(
const
Scope
&
scope
)
const
{
return
*
scope
.
FindVar
(
arg_
->
step_scopes
)
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
inline
const
std
::
vector
<
framework
::
Scope
*>&
GetStepScopes
(
const
framework
::
Scope
&
scope
)
const
{
return
*
scope
.
FindVar
(
arg_
->
step_scopes
)
->
GetMutable
<
std
::
vector
<
framework
::
Scope
*>>
();
}
private:
...
...
@@ -163,16 +170,18 @@ private:
mutable
size_t
seq_len_
;
};
class
RecurrentOp
final
:
public
OperatorBase
{
class
RecurrentOp
final
:
public
framework
::
OperatorBase
{
public:
void
Init
()
override
;
/**
* InferShape must be called before Run.
*/
void
InferShape
(
const
Scope
&
scope
)
const
override
{
alg_
.
InferShape
(
scope
);
}
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
{
alg_
.
InferShape
(
scope
);
}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
alg_
.
Run
(
scope
,
dev_ctx
);
}
...
...
@@ -183,16 +192,18 @@ private:
RecurrentAlgorithm
alg_
;
};
class
RecurrentGradientOp
final
:
public
OperatorBase
{
class
RecurrentGradientOp
final
:
public
framework
::
OperatorBase
{
public:
void
Init
()
override
;
/**
* InferShape must be called before Run.
*/
void
InferShape
(
const
Scope
&
scope
)
const
override
{
alg_
.
InferShape
(
scope
);
}
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
{
alg_
.
InferShape
(
scope
);
}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
alg_
.
Run
(
scope
,
dev_ctx
);
}
...
...
paddle/operators/recurrent_op_test.cc
浏览文件 @
301a21d8
...
...
@@ -16,6 +16,7 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
...
...
@@ -24,6 +25,9 @@
namespace
paddle
{
namespace
operators
{
using
framework
::
make_ddim
;
using
framework
::
DDim
;
class
RecurrentOpTest
:
public
::
testing
::
Test
{
protected:
virtual
void
SetUp
()
override
{
...
...
@@ -72,7 +76,7 @@ protected:
}
void
CreateRNNOp
()
{
OpDesc
op_desc
;
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"recurrent_op"
);
// inlinks 0
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录