Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8c653ba7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8c653ba7
编写于
8月 16, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Complete remove std::shared_ptr
上级
c7f25325
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
107 addition
and
95 deletion
+107
-95
paddle/framework/backward.cc
paddle/framework/backward.cc
+19
-21
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+5
-6
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+2
-4
paddle/framework/pybind.cc
paddle/framework/pybind.cc
+16
-21
paddle/operators/net_op.h
paddle/operators/net_op.h
+31
-10
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+10
-13
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+10
-10
paddle/operators/recurrent_op.h
paddle/operators/recurrent_op.h
+14
-10
未找到文件。
paddle/framework/backward.cc
浏览文件 @
8c653ba7
...
...
@@ -15,6 +15,8 @@
#include "paddle/framework/backward.h"
#include <list>
#include <memory>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
...
...
@@ -43,11 +45,11 @@ static bool AllInSet(
return
all_in_set
;
}
static
std
::
shared
_ptr
<
OperatorBase
>
NOP
()
{
auto
net_op
=
std
::
make_shared
<
operators
::
NetOp
>
();
static
std
::
unique
_ptr
<
OperatorBase
>
NOP
()
{
auto
net_op
=
new
operators
::
NetOp
();
net_op
->
SetType
(
"@NOP@"
);
net_op
->
CompleteAddOp
();
return
net_op
;
return
std
::
unique_ptr
<
OperatorBase
>
(
net_op
)
;
}
// Get backward operator from a forward operator, a recursive implementation.
...
...
@@ -62,11 +64,7 @@ static std::shared_ptr<OperatorBase> NOP() {
// operator, in a complex situation, it maybe a NetOp.
//
// See Backward.h for details
static
std
::
shared_ptr
<
OperatorBase
>
BackwardRecursive
(
const
OperatorBase
&
forwardOp
,
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
);
std
::
shared_ptr
<
OperatorBase
>
BackwardRecursive
(
static
std
::
unique_ptr
<
OperatorBase
>
BackwardRecursive
(
const
OperatorBase
&
forwardOp
,
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
)
{
// If all input gradients of forwarding operator do not need to calculate,
...
...
@@ -91,7 +89,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
}
// Returned gradient network
auto
net
=
std
::
make_shared
<
operators
::
NetOp
>
();
auto
net
=
std
::
unique_ptr
<
operators
::
NetOp
>
();
if
(
forwardOp
.
IsNetOp
())
{
// Because forwardOp is a net op, it can static_cast.
...
...
@@ -105,14 +103,14 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// reversely travel forwardNet and collect all duplicate outputs.
for
(
auto
it
=
forwardNet
.
ops_
.
rbegin
();
it
!=
forwardNet
.
ops_
.
rend
();
++
it
,
++
local_op_id
)
{
auto
fwd
=
*
it
;
auto
&
fwd
=
*
it
;
auto
bwd
=
BackwardRecursive
(
*
fwd
,
no_grad_names
,
uniq_id
);
net
->
AddOp
(
bwd
);
ForEachVarName
(
bwd
->
Outputs
(),
[
&
dup_output_ops
,
local_op_id
](
const
std
::
string
&
out
)
{
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
return
false
;
});
net
->
AddOp
(
std
::
move
(
bwd
));
}
// Get unique ID for this method.
auto
uid
=
uniq_id
++
;
...
...
@@ -122,7 +120,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// to handle this case. For each duplicate output, rename it to an alias
// (original name with a offset), append an `add` op for its operator,
// and finally sum all the alias variable to the final output variable y.
using
Pos
=
std
::
pair
<
size_t
,
std
::
shared
_ptr
<
OperatorBase
>>
;
using
Pos
=
std
::
pair
<
size_t
,
std
::
unique
_ptr
<
OperatorBase
>>
;
std
::
list
<
Pos
>
insert_position
;
for
(
auto
&
dup_output_op
:
dup_output_ops
)
{
const
std
::
string
&
name
=
dup_output_op
.
first
;
...
...
@@ -150,13 +148,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
[](
const
Pos
&
l
,
const
Pos
&
r
)
{
return
l
.
first
>
r
.
first
;
});
for
(
auto
&
pos
:
insert_position
)
{
net
->
InsertOp
(
pos
.
first
+
1
,
pos
.
second
);
net
->
InsertOp
(
pos
.
first
+
1
,
std
::
move
(
pos
.
second
)
);
}
}
else
{
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
std
::
unique_ptr
<
OperatorBase
>
grad_op
(
OpRegistry
::
CreateGradOp
(
forwardOp
)
);
ForEachVarName
(
grad_op
->
Inputs
(),
[
&
no_grad_names
,
&
net
,
grad_op
](
const
std
::
string
&
grad_input
)
{
ForEachVarName
(
grad_op
->
Inputs
(),
[
&
no_grad_names
,
&
net
,
&
grad_op
](
const
std
::
string
&
grad_input
)
{
if
(
no_grad_names
.
count
(
grad_input
))
{
// +1 for \0
std
::
string
prefix
=
grad_input
.
substr
(
...
...
@@ -190,20 +188,20 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
const
auto
&
stepnet_op
=
*
static_cast
<
const
OperatorBase
*>
(
&
rnnop
.
stepnet
());
// create stepnet's gradient op
auto
grad_stepnet
=
BackwardRecursive
(
stepnet_op
,
no_grad_names
,
uniq_id
);
rnn_grad_op
->
set_stepnet
(
std
::
static_pointer_cast
<
operators
::
NetOp
>
(
grad_stepnet
));
BackwardRecursive
(
stepnet_op
,
no_grad_names
,
uniq_id
));
}
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
return
grad_op
;
}
net
->
AddOp
(
grad_op
);
net
->
AddOp
(
std
::
move
(
grad_op
)
);
}
net
->
SetType
(
"@GENERATED_BACKWARD@"
);
net
->
CompleteAddOp
();
return
net
;
}
// namespace framework
return
std
::
unique_ptr
<
OperatorBase
>
(
static_cast
<
OperatorBase
*>
(
net
.
release
()));
}
// See header for comments
std
::
shared_ptr
<
OperatorBase
>
Backward
(
...
...
paddle/framework/op_registry.h
浏览文件 @
8c653ba7
...
...
@@ -174,7 +174,7 @@ class OpRegistry {
}
}
static
std
::
shared
_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
static
std
::
unique
_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
AttributeMap
attrs
)
{
...
...
@@ -183,7 +183,7 @@ class OpRegistry {
"Operator '%s' has not been registered."
,
type
);
it
->
second
.
checker_
->
Check
(
attrs
);
auto
op
=
it
->
second
.
creator_
(
type
,
inputs
,
outputs
,
attrs
);
return
std
::
shared
_ptr
<
OperatorBase
>
(
op
);
return
std
::
unique
_ptr
<
OperatorBase
>
(
op
);
}
static
VarNameMap
ConvertOpDescVarsToVarNameMap
(
...
...
@@ -199,7 +199,7 @@ class OpRegistry {
return
ret_val
;
}
static
std
::
shared
_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
std
::
unique
_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
)
{
VarNameMap
inputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
inputs
());
VarNameMap
outputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
outputs
());
AttributeMap
attrs
;
...
...
@@ -210,11 +210,10 @@ class OpRegistry {
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
static
std
::
shared
_ptr
<
OperatorBase
>
CreateGradOp
(
const
OperatorBase
&
op
)
{
static
std
::
unique
_ptr
<
OperatorBase
>
CreateGradOp
(
const
OperatorBase
&
op
)
{
PADDLE_ENFORCE
(
!
op
.
IsNetOp
(),
"Use framework::Backward to get backward ops"
);
std
::
shared_ptr
<
OperatorBase
>
grad_op
(
BuildGradOp
(
&
op
));
return
grad_op
;
return
std
::
unique_ptr
<
OperatorBase
>
(
BuildGradOp
(
&
op
));
}
static
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>&
op_info_map
()
{
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
8c653ba7
...
...
@@ -76,8 +76,7 @@ TEST(OpRegistry, CreateOp) {
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
scale
);
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
framework
::
Scope
scope
;
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
op
->
Run
(
scope
,
dev_ctx
);
...
...
@@ -118,8 +117,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
framework
::
Scope
scope
;
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
op
->
Run
(
scope
,
dev_ctx
);
...
...
paddle/framework/pybind.cc
浏览文件 @
8c653ba7
...
...
@@ -207,8 +207,7 @@ All parameter, weight, gradient are variables in Paddle.
.
def
(
py
::
init
<>
())
.
def
(
"__str__"
,
string
::
to_string
<
const
platform
::
CPUPlace
&>
);
py
::
class_
<
OperatorBase
,
std
::
shared_ptr
<
OperatorBase
>>
operator_base
(
m
,
"Operator"
);
py
::
class_
<
OperatorBase
>
operator_base
(
m
,
"Operator"
);
operator_base
.
def_static
(
"create"
,
[](
py
::
bytes
protobin
)
{
OpDesc
desc
;
...
...
@@ -228,25 +227,23 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator
(
operator_base
);
py
::
class_
<
operators
::
NetOp
,
std
::
shared_ptr
<
operators
::
NetOp
>
>
net
(
m
,
"Net"
);
py
::
class_
<
operators
::
NetOp
>
net
(
m
,
"Net"
);
net
.
def_static
(
"create"
,
[]()
->
std
::
shared_ptr
<
operators
::
NetOp
>
{
auto
retv
=
std
::
make_shared
<
operators
::
NetOp
>
()
;
[]()
->
operators
::
NetOp
*
{
auto
*
retv
=
new
operators
::
NetOp
;
retv
->
SetType
(
"plain_net"
);
return
retv
;
})
.
def
(
"add_op"
,
&
operators
::
NetOp
::
AddOp
)
.
def
(
"add_op"
,
[](
operators
::
NetOp
&
self
,
const
OperatorBase
&
op
)
{
self
.
AddOp
(
op
);
})
.
def
(
"add_op"
,
[](
operators
::
NetOp
&
self
,
const
std
::
shared_ptr
<
operators
::
NetOp
>
&
net
)
->
void
{
self
.
AddOp
(
std
::
static_pointer_cast
<
OperatorBase
>
(
net
));
[](
operators
::
NetOp
&
self
,
const
operators
::
NetOp
&
net
)
->
void
{
self
.
AddOp
(
net
);
})
.
def
(
"add_op"
,
[](
operators
::
NetOp
&
self
,
const
std
::
shared_ptr
<
operators
::
RecurrentOp
>
&
rnn
)
->
void
{
self
.
AddOp
(
std
::
static_pointer_cast
<
OperatorBase
>
(
rnn
));
})
const
operators
::
RecurrentOp
&
rnn
)
->
void
{
self
.
AddOp
(
rnn
);
})
.
def
(
"complete_add_op"
,
&
operators
::
NetOp
::
CompleteAddOp
)
.
def
(
"complete_add_op"
,
[](
std
::
shared_ptr
<
operators
::
NetOp
>
&
self
)
{
self
->
CompleteAddOp
();
...
...
@@ -255,12 +252,11 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator
(
net
);
// recurrent_op
py
::
class_
<
operators
::
RecurrentOp
,
std
::
shared_ptr
<
operators
::
RecurrentOp
>>
rnn
(
m
,
"RecurrentOp"
);
py
::
class_
<
operators
::
RecurrentOp
>
rnn
(
m
,
"RecurrentOp"
);
rnn
.
def_static
(
"create"
,
[](
py
::
bytes
protobin
)
->
std
::
shared_ptr
<
operators
::
RecurrentOp
>
{
[](
py
::
bytes
protobin
)
->
operators
::
RecurrentOp
*
{
OpDesc
desc
;
PADDLE_ENFORCE
(
desc
.
ParsePartialFromString
(
protobin
),
"Cannot parse user input to OpDesc"
);
...
...
@@ -268,13 +264,12 @@ All parameter, weight, gradient are variables in Paddle.
"User OpDesc is not initialized, reason %s"
,
desc
.
InitializationErrorString
());
auto
rnn_op
=
OpRegistry
::
CreateOp
(
desc
);
return
st
d
::
dynamic_pointer_cast
<
operators
::
RecurrentOp
>
(
rnn_op
);
return
st
atic_cast
<
operators
::
RecurrentOp
*>
(
rnn_op
.
release
()
);
})
.
def
(
"set_stepnet"
,
[](
operators
::
RecurrentOp
&
self
,
const
std
::
shared_ptr
<
operators
::
NetOp
>
&
net
)
->
void
{
self
.
set_stepnet
(
net
);
});
.
def
(
"set_stepnet"
,
[](
operators
::
RecurrentOp
&
self
,
const
operators
::
NetOp
&
net
)
->
void
{
self
.
set_stepnet
(
net
.
Clone
());
});
ExposeOperator
(
rnn
);
m
.
def
(
"unique_integer"
,
UniqueIntegerGenerator
);
...
...
paddle/operators/net_op.h
浏览文件 @
8c653ba7
...
...
@@ -45,11 +45,11 @@ class NetOp : public framework::OperatorBase {
:
framework
::
OperatorBase
(
static_cast
<
const
framework
::
OperatorBase
&>
(
o
))
{
this
->
ops_
.
reserve
(
o
.
ops_
.
size
());
std
::
transform
(
o
.
ops_
.
begin
(),
o
.
ops_
.
end
(),
std
::
back_inserter
(
this
->
ops_
),
[](
const
std
::
shared_ptr
<
OperatorBase
>&
op
)
->
std
::
shared_ptr
<
OperatorBase
>
{
return
std
::
shared_ptr
<
OperatorBase
>
(
op
->
Clone
());
});
std
::
transform
(
o
.
ops_
.
begin
(),
o
.
ops_
.
end
(),
std
::
back_inserter
(
this
->
ops_
),
[](
const
std
::
unique_ptr
<
framework
::
OperatorBase
>&
op
)
{
return
std
::
unique_ptr
<
framework
::
OperatorBase
>
(
op
->
Clone
());
});
this
->
CompleteAddOp
();
}
...
...
@@ -86,21 +86,42 @@ class NetOp : public framework::OperatorBase {
return
true
;
}
void
AddOp
(
const
framework
::
OperatorBase
&
op
)
{
AddOp
(
op
.
Clone
());
}
/**
* @brief Add an operator by ptr
*/
void
AddOp
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
)
{
void
AddOp
(
framework
::
OperatorBase
*
op
,
bool
own
)
{
PADDLE_ENFORCE
(
!
add_op_done_
,
"Cannot AddOp when this network is sealed"
);
PADDLE_ENFORCE_NOT_NULL
(
op
,
"Cannot Insert Null op"
);
ops_
.
push_back
(
op
);
if
(
!
own
)
{
op
=
op
->
Clone
().
release
();
}
ops_
.
emplace_back
(
op
);
}
void
InsertOp
(
size_t
pos
,
const
std
::
shared_ptr
<
OperatorBase
>&
op
)
{
void
AddOp
(
std
::
unique_ptr
<
framework
::
OperatorBase
>&&
op
)
{
AddOp
(
op
.
release
(),
true
);
}
void
InsertOp
(
size_t
pos
,
framework
::
OperatorBase
*
op
,
bool
own
)
{
PADDLE_ENFORCE
(
!
add_op_done_
,
"Cannot InsertOp when this network is sealed"
);
PADDLE_ENFORCE_NOT_NULL
(
op
,
"Cannot Insert Null op"
);
PADDLE_ENFORCE_LE
(
pos
,
ops_
.
size
(),
"Out of range"
);
ops_
.
insert
(
ops_
.
begin
()
+
pos
,
op
);
if
(
!
own
)
{
op
=
op
->
Clone
().
release
();
}
ops_
.
insert
(
ops_
.
begin
()
+
pos
,
std
::
unique_ptr
<
framework
::
OperatorBase
>
(
op
));
}
void
InsertOp
(
size_t
pos
,
std
::
unique_ptr
<
framework
::
OperatorBase
>&&
op
)
{
InsertOp
(
pos
,
op
.
release
(),
true
);
}
void
InsertOp
(
size_t
pos
,
const
framework
::
OperatorBase
&
op
)
{
InsertOp
(
pos
,
op
.
Clone
());
}
void
CompleteAddOp
(
bool
calculate
=
true
);
...
...
@@ -112,7 +133,7 @@ class NetOp : public framework::OperatorBase {
std
::
unique_ptr
<
framework
::
OperatorBase
>
Clone
()
const
override
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
ops_
;
private:
bool
add_op_done_
{
false
};
...
...
paddle/operators/net_op_test.cc
浏览文件 @
8c653ba7
...
...
@@ -38,15 +38,12 @@ TEST(OpKernel, all) {
auto
net
=
std
::
make_shared
<
NetOp
>
();
ASSERT_NE
(
net
,
nullptr
);
auto
op1
=
std
::
shared
_ptr
<
TestOp
>
(
net
->
AddOp
(
std
::
unique
_ptr
<
TestOp
>
(
new
TestOp
(
"test"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
{{
"Out"
,
{
"y"
}}},
{}));
net
->
AddOp
(
op1
);
auto
op2
=
std
::
shared_ptr
<
TestOp
>
(
{{
"Out"
,
{
"y"
}}},
{})));
net
->
AddOp
(
std
::
unique_ptr
<
TestOp
>
(
new
TestOp
(
"test"
,
{{
"X"
,
{
"y"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
{{
"Out"
,
{
"z"
}}},
{}));
net
->
AddOp
(
op2
);
{{
"Out"
,
{
"z"
}}},
{})));
net
->
CompleteAddOp
();
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
...
...
@@ -61,21 +58,21 @@ TEST(OpKernel, all) {
TEST
(
NetOp
,
insert_op
)
{
NetOp
net
;
auto
op1
=
std
::
shared
_ptr
<
framework
::
NOP
>
(
auto
op1
=
std
::
unique
_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
(
"empty"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
{{
"Out"
,
{
"y"
}}},
{}));
net
.
AddOp
(
op1
);
net
.
InsertOp
(
0
,
op1
);
net
.
AddOp
(
*
op1
);
net
.
InsertOp
(
0
,
*
op1
);
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
net
.
InsertOp
(
2
,
op1
);
net
.
InsertOp
(
2
,
std
::
move
(
op1
)
);
ASSERT_EQ
(
3UL
,
net
.
ops_
.
size
());
}
TEST
(
NetOp
,
Clone
)
{
NetOp
net
;
net
.
AddOp
(
std
::
shared
_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty"
,
{},
{},
{}}));
net
.
AddOp
(
std
::
shared
_ptr
<
framework
::
NOP
>
(
std
::
unique
_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty"
,
{},
{},
{}}));
net
.
AddOp
(
std
::
unique
_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty2"
,
{},
{},
{}}));
net
.
CompleteAddOp
(
true
);
auto
new_net_op
=
net
.
Clone
();
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
8c653ba7
...
...
@@ -42,7 +42,7 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
i
,
-
1
,
true
/*infer_shape_mode*/
);
}
(
*
stepnet_
)
->
InferShape
(
*
step_scopes
[
i
]);
stepnet_
->
InferShape
(
*
step_scopes
[
i
]);
}
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
...
...
@@ -61,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
step_id
,
-
1
,
false
/*infer_shape_mode*/
);
}
(
*
stepnet_
)
->
Run
(
*
step_scopes
[
step_id
],
dev_ctx
);
stepnet_
->
Run
(
*
step_scopes
[
step_id
],
dev_ctx
);
}
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
false
/*infer_shape_mode*/
);
...
...
@@ -76,15 +76,15 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// Now all variables in scope must be created outside of op.
PADDLE_ENFORCE_NOT_NULL
(
stepnet_
);
PADDLE_ENFORCE
(
!
(
*
stepnet_
)
->
Outputs
().
empty
(),
"stepnet_ op has no outputs"
);
PADDLE_ENFORCE
(
!
(
*
stepnet_
)
->
Outputs
().
empty
(),
"net_op has no outputs"
);
PADDLE_ENFORCE
(
!
stepnet_
->
Outputs
().
empty
(),
"stepnet_ op has no outputs"
);
PADDLE_ENFORCE
(
!
stepnet_
->
Outputs
().
empty
(),
"net_op has no outputs"
);
if
(
seq_len_
>
step_scopes
->
size
())
{
for
(
size_t
i
=
step_scopes
->
size
();
i
<
seq_len_
;
++
i
)
{
auto
&
step_scope
=
scope
.
NewScope
();
// create step net's temp inputs
for
(
auto
&
input
:
(
*
stepnet_
)
->
Inputs
())
{
for
(
auto
&
input
:
stepnet_
->
Inputs
())
{
// the weight are located in parent scope
for
(
auto
&
var_name
:
input
.
second
)
{
if
(
!
step_scope
.
FindVar
(
var_name
))
{
...
...
@@ -93,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
}
}
// create stepnet's outputs
for
(
const
auto
&
output
:
(
*
stepnet_
)
->
Outputs
())
{
for
(
const
auto
&
output
:
stepnet_
->
Outputs
())
{
for
(
auto
&
var_name
:
output
.
second
)
{
step_scope
.
NewVar
(
var_name
);
}
...
...
@@ -136,7 +136,7 @@ RecurrentOp::RecurrentOp(const std::string& type,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
rnn
::
InitArgument
(
kArgName
,
&
arg_
,
*
this
);
alg_
.
Init
(
&
arg_
,
&
stepnet_
);
alg_
.
Init
(
&
arg_
,
stepnet_
.
get
()
);
}
class
RecurrentAlgorithmProtoAndCheckerMaker
...
...
@@ -178,7 +178,7 @@ void RecurrentGradientAlgorithm::Run(
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
step_id
,
1
,
false
/*infer_shape_mode*/
);
}
(
*
stepnet_
)
->
Run
(
*
step_scopes
[
step_id
],
dev_ctx
);
stepnet_
->
Run
(
*
step_scopes
[
step_id
],
dev_ctx
);
}
LinkBootMemoryGradients
(
step_scopes
[
0
],
false
);
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
...
...
@@ -215,7 +215,7 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
step_id
,
1
,
true
/*infer_shape_mode*/
);
}
(
*
stepnet_
)
->
InferShape
(
*
step_scopes
[
step_id
]);
stepnet_
->
InferShape
(
*
step_scopes
[
step_id
]);
}
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
...
...
@@ -228,7 +228,7 @@ RecurrentGradientOp::RecurrentGradientOp(
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
rnn
::
InitArgument
(
kArgName
,
&
arg_
,
*
this
);
alg_
.
Init
(
&
arg_
,
&
stepnet_
);
alg_
.
Init
(
&
arg_
,
stepnet_
.
get
()
);
}
}
// namespace operators
...
...
paddle/operators/recurrent_op.h
浏览文件 @
8c653ba7
...
...
@@ -34,7 +34,7 @@ class RecurrentAlgorithm {
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
;
void
Init
(
rnn
::
Argument
*
arg
,
std
::
shared_ptr
<
NetOp
>
*
stepnet
)
{
void
Init
(
rnn
::
Argument
*
arg
,
framework
::
OperatorBase
*
stepnet
)
{
PADDLE_ENFORCE_NOT_NULL
(
stepnet
,
"stepnet should be set before."
);
arg_
=
arg
;
stepnet_
=
stepnet
;
...
...
@@ -63,7 +63,7 @@ class RecurrentAlgorithm {
void
InitMemories
(
framework
::
Scope
*
step_scopes
,
bool
infer_shape_mode
)
const
;
private:
std
::
shared_ptr
<
NetOp
>
*
stepnet_
;
framework
::
OperatorBase
*
stepnet_
;
rnn
::
Argument
*
arg_
;
mutable
size_t
seq_len_
;
};
...
...
@@ -80,7 +80,7 @@ class RecurrentGradientAlgorithm {
* operator.
*/
public:
void
Init
(
rnn
::
Argument
*
arg
,
std
::
shared_ptr
<
NetOp
>
*
stepnet
)
{
void
Init
(
rnn
::
Argument
*
arg
,
framework
::
OperatorBase
*
stepnet
)
{
PADDLE_ENFORCE_NOT_NULL
(
stepnet
,
"stepnet should be set before."
);
arg_
=
std
::
move
(
arg
);
stepnet_
=
stepnet
;
...
...
@@ -107,7 +107,7 @@ class RecurrentGradientAlgorithm {
private:
rnn
::
Argument
*
arg_
;
mutable
size_t
seq_len_
;
std
::
shared_ptr
<
NetOp
>
*
stepnet_
;
framework
::
OperatorBase
*
stepnet_
;
};
class
RecurrentOp
:
public
framework
::
OperatorBase
{
...
...
@@ -133,15 +133,17 @@ class RecurrentOp : public framework::OperatorBase {
alg_
.
Run
(
scope
,
dev_ctx
);
}
void
set_stepnet
(
std
::
shared_ptr
<
NetOp
>
net
)
{
stepnet_
=
net
;
}
const
NetOp
&
stepnet
()
const
{
return
*
stepnet_
;
}
void
set_stepnet
(
std
::
unique_ptr
<
OperatorBase
>
net
)
{
stepnet_
=
std
::
move
(
net
);
}
const
OperatorBase
&
stepnet
()
const
{
return
*
stepnet_
;
}
static
const
rnn
::
ArgumentName
kArgName
;
private:
RecurrentAlgorithm
alg_
;
rnn
::
Argument
arg_
;
std
::
shared_ptr
<
NetOp
>
stepnet_
;
std
::
unique_ptr
<
OperatorBase
>
stepnet_
;
};
class
RecurrentGradientOp
:
public
framework
::
OperatorBase
{
...
...
@@ -171,12 +173,14 @@ class RecurrentGradientOp : public framework::OperatorBase {
static
const
rnn
::
ArgumentName
kArgName
;
void
set_stepnet
(
const
std
::
shared_ptr
<
NetOp
>&
net
)
{
stepnet_
=
net
;
}
const
NetOp
&
stepnet
()
const
{
return
*
stepnet_
;
}
void
set_stepnet
(
std
::
unique_ptr
<
OperatorBase
>
net
)
{
stepnet_
=
std
::
move
(
net
);
}
const
OperatorBase
&
stepnet
()
const
{
return
*
stepnet_
;
}
private:
RecurrentGradientAlgorithm
alg_
;
std
::
shared_ptr
<
NetOp
>
stepnet_
;
std
::
unique_ptr
<
OperatorBase
>
stepnet_
;
rnn
::
Argument
arg_
;
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录