Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e2fd2bd0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e2fd2bd0
编写于
8月 01, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Follow comments and merge develop
上级
80baf861
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
60 addition
and
74 deletion
+60
-74
paddle/framework/backward.cc
paddle/framework/backward.cc
+40
-52
paddle/framework/backward.h
paddle/framework/backward.h
+2
-6
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+0
-1
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+10
-8
paddle/operators/fill_zeros_like_op.h
paddle/operators/fill_zeros_like_op.h
+2
-2
paddle/operators/recurrent_network_op.cc
paddle/operators/recurrent_network_op.cc
+6
-5
未找到文件。
paddle/framework/backward.cc
浏览文件 @
e2fd2bd0
...
@@ -31,88 +31,74 @@ static bool AllInSet(const std::vector<std::string>& names,
...
@@ -31,88 +31,74 @@ static bool AllInSet(const std::vector<std::string>& names,
return
true
;
return
true
;
}
}
static
std
::
vector
<
size_t
>
InSetIdx
(
static
std
::
shared_ptr
<
OperatorBase
>
NOP
()
{
const
std
::
vector
<
std
::
string
>&
names
,
const
std
::
string
&
suffix
,
const
std
::
unordered_set
<
std
::
string
>&
set
)
{
std
::
vector
<
size_t
>
ret_val
;
ret_val
.
reserve
(
names
.
size
());
for
(
size_t
i
=
0
;
i
<
names
.
size
();
++
i
)
{
if
(
set
.
find
(
names
[
i
]
+
suffix
)
!=
set
.
end
())
{
ret_val
.
push_back
(
i
);
}
}
return
ret_val
;
}
static
std
::
shared_ptr
<
OperatorBase
>
EmptyOp
()
{
auto
net_op
=
std
::
make_shared
<
NetOp
>
();
auto
net_op
=
std
::
make_shared
<
NetOp
>
();
net_op
->
type_
=
"@
EMPTY_
OP@"
;
net_op
->
type_
=
"@
N
OP@"
;
net_op
->
CompleteAddOp
();
net_op
->
CompleteAddOp
();
return
net_op
;
return
net_op
;
}
}
/**
// Get backward operator from a forward operator, recursively implementation.
* @brief Backward an operator, implementation
//
* @param forwardOp the forward operator
// no_grad_names the gradient variable names without gradient calculating.
* @param no_grad_names variable names not calculate for gradient. Like X@GRAD
//
* is not needed.
// uniq_id is a unique index used inside recursively calling BackwardRecursive.
* @param uniq_id a unique index used inside BackwardImpl, it will be shared
// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through
* through recursive invoke.
// recursive calling.
* @return The backward operator. For simple situation, it is a simple operator.
//
* For complex situation, it is a NetOp.
// returns The backward operator. For simple situation, it is a simple
*
// operator. For complex situation, it is a NetOp.
* See Backward.h for details
//
*/
// See Backward.h for details
static
std
::
shared_ptr
<
OperatorBase
>
BackwardImpl
(
static
std
::
shared_ptr
<
OperatorBase
>
BackwardRecursive
(
const
OperatorBase
&
forwardOp
,
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
);
std
::
shared_ptr
<
OperatorBase
>
BackwardRecursive
(
const
OperatorBase
&
forwardOp
,
const
OperatorBase
&
forwardOp
,
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
)
{
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
)
{
/**
// If all input gradients of forwarding operator do not need to calculate,
* If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take
* just return an EmptyOp. Not return null ptr because EmptyOp does not take
// too much time for calculation, but it is useful for simplifying logic.
* too much time for calculation, but it is useful for simplifying logic.
*/
if
(
AllInSet
(
forwardOp
.
inputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
if
(
AllInSet
(
forwardOp
.
inputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
no_grad_names
))
{
no_grad_names
))
{
return
EmptyOp
();
return
NOP
();
}
}
/**
// All output gradients of forwarding operator do not need to calculate. Then
* All output gradients of forwarding operator do not need to calculate. Then
// all input gradients cannot be computed at all, and we put them into
* all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
* `no_grad_names` set. Return an EmptyOp.
*/
if
(
AllInSet
(
forwardOp
.
outputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
if
(
AllInSet
(
forwardOp
.
outputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
no_grad_names
))
{
no_grad_names
))
{
for
(
auto
&
name
:
forwardOp
.
inputs_
)
{
for
(
auto
&
name
:
forwardOp
.
inputs_
)
{
//
/
Mark all input is not need
// Mark all input is not need
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
}
}
return
EmptyOp
();
return
NOP
();
}
}
//
!
Returned gradient network
// Returned gradient network
auto
net
=
std
::
make_shared
<
NetOp
>
();
auto
net
=
std
::
make_shared
<
NetOp
>
();
if
(
forwardOp
.
IsNetOp
())
{
if
(
forwardOp
.
IsNetOp
())
{
//
/
Because forwardOp is a net op, it can static_cast.
// Because forwardOp is a net op, it can static_cast.
auto
&
forwardNet
=
static_cast
<
const
NetOp
&>
(
forwardOp
);
auto
&
forwardNet
=
static_cast
<
const
NetOp
&>
(
forwardOp
);
//
!
Map from output gradient variable name to operator's indices in backward
// Map from output gradient variable name to operator's indices in backward
//
!
net. That operator generates that variable.
// net. That operator generates that variable.
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
size_t
>>
dup_output_ops
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
size_t
>>
dup_output_ops
;
size_t
local_op_id
=
0
;
size_t
local_op_id
=
0
;
//
/
reversely travel forwardNet
// reversely travel forwardNet
for
(
auto
it
=
forwardNet
.
ops_
.
rbegin
();
it
!=
forwardNet
.
ops_
.
rend
();
for
(
auto
it
=
forwardNet
.
ops_
.
rbegin
();
it
!=
forwardNet
.
ops_
.
rend
();
++
it
,
++
local_op_id
)
{
++
it
,
++
local_op_id
)
{
auto
fwd
=
*
it
;
auto
fwd
=
*
it
;
auto
bwd
=
Backward
Impl
(
*
fwd
,
no_grad_names
,
uniq_id
);
auto
bwd
=
Backward
Recursive
(
*
fwd
,
no_grad_names
,
uniq_id
);
net
->
AddOp
(
bwd
);
net
->
AddOp
(
bwd
);
for
(
auto
&
out
:
bwd
->
outputs_
)
{
for
(
auto
&
out
:
bwd
->
outputs_
)
{
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
}
}
}
}
//
/
Get unique ID for this method.
// Get unique ID for this method.
auto
uid
=
uniq_id
++
;
auto
uid
=
uniq_id
++
;
// TODO(dzh): more comment
// TODO(dzh): more comment
using
Pos
=
std
::
pair
<
size_t
,
std
::
shared_ptr
<
OperatorBase
>>
;
using
Pos
=
std
::
pair
<
size_t
,
std
::
shared_ptr
<
OperatorBase
>>
;
...
@@ -145,13 +131,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
...
@@ -145,13 +131,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
}
}
}
else
{
}
else
{
//! TODO(fjy)
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
for
(
std
::
string
&
grad_input
:
grad_op
->
inputs_
)
{
for
(
std
::
string
&
grad_input
:
grad_op
->
inputs_
)
{
if
(
no_grad_names
.
count
(
grad_input
))
{
if
(
no_grad_names
.
count
(
grad_input
))
{
std
::
string
prefix
=
grad_input
.
substr
(
std
::
string
prefix
=
grad_input
.
substr
(
0
,
grad_input
.
size
()
-
OperatorBase
::
GRAD_VAR_SUFFIX
().
size
());
0
,
grad_input
.
size
()
-
OperatorBase
::
GRAD_VAR_SUFFIX
().
size
());
grad_input
=
prefix
+
OperatorBase
::
ZERO_VAR_SUFFIX
();
grad_input
=
prefix
+
OperatorBase
::
ZERO_VAR_SUFFIX
();
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{
prefix
},
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{
prefix
},
{
grad_input
},
{}));
{
grad_input
},
{}));
}
}
...
@@ -173,8 +161,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
...
@@ -173,8 +161,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return
net
;
return
net
;
}
}
//
!
See header for comments
// See header for comments
extern
std
::
shared_ptr
<
OperatorBase
>
Backward
(
std
::
shared_ptr
<
OperatorBase
>
Backward
(
const
OperatorBase
&
forwardOp
,
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
std
::
unordered_set
<
std
::
string
>
no_grad_names
;
std
::
unordered_set
<
std
::
string
>
no_grad_names
;
...
@@ -184,7 +172,7 @@ extern std::shared_ptr<OperatorBase> Backward(
...
@@ -184,7 +172,7 @@ extern std::shared_ptr<OperatorBase> Backward(
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
}
}
size_t
uid
=
0
;
size_t
uid
=
0
;
return
Backward
Impl
(
forwardOp
,
no_grad_names
,
uid
);
return
Backward
Recursive
(
forwardOp
,
no_grad_names
,
uid
);
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/backward.h
浏览文件 @
e2fd2bd0
...
@@ -18,12 +18,8 @@
...
@@ -18,12 +18,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
/**
// Create the backward operator from a forward operator.
* @brief
// TODO(yuyang18): Add more API reference comment.
* @param forwardOp
* @param no_grad_vars ignored input name of forward
* @return
*/
extern
std
::
shared_ptr
<
OperatorBase
>
Backward
(
extern
std
::
shared_ptr
<
OperatorBase
>
Backward
(
const
OperatorBase
&
forwardOp
,
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
...
...
paddle/framework/backward_test.cc
浏览文件 @
e2fd2bd0
...
@@ -169,7 +169,6 @@ TEST(Backward, simple_op_grad) {
...
@@ -169,7 +169,6 @@ TEST(Backward, simple_op_grad) {
ASSERT_EQ
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
(),
ASSERT_EQ
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
(),
gop
->
Output
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
()));
gop
->
Output
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
()));
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
}
TEST
(
Backward
,
simple_op_not_need_grad
)
{
TEST
(
Backward
,
simple_op_not_need_grad
)
{
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
e2fd2bd0
...
@@ -21,15 +21,17 @@ namespace operators {
...
@@ -21,15 +21,17 @@ namespace operators {
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{
PADDLE_ENFORCE
(
inputs
.
size
()
==
1
,
"Input size of FillZerosLikeOp must be one."
);
"Input size of FillZerosLikeOp must be one."
);
PADDLE_ENFORCE
(
outputs
.
size
()
==
1
,
"Output size of AddOp must be one."
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1UL
,
PADDLE_ENFORCE
(
inputs
[
0
]
!=
nullptr
&&
outputs
[
0
]
!=
nullptr
,
"Output size of AddOp must be one."
);
"Outputs of FillZerosLikeOp must all be set."
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
0
)
!=
nullptr
,
outputs
[
0
]
->
Resize
(
inputs
[
0
]
->
dims
());
"Input of FillZerosLikeOp must be set."
);
PADDLE_ENFORCE
(
ctx
.
OutputVar
(
0
)
!=
nullptr
,
"Output of FillZerosLikeOp must be set."
);
ctx
.
Output
<
framework
::
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
framework
::
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/fill_zeros_like_op.h
浏览文件 @
e2fd2bd0
...
@@ -23,8 +23,8 @@ namespace operators {
...
@@ -23,8 +23,8 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
FillZerosLikeKernel
:
public
framework
::
OpKernel
{
class
FillZerosLikeKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
Kernel
Context
&
context
)
const
override
{
void
Compute
(
const
framework
::
Execution
Context
&
context
)
const
override
{
auto
*
output
=
context
.
Output
(
0
)
->
GetMutable
<
framework
::
Tensor
>
(
);
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
).
setZero
();
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
).
setZero
();
}
}
...
...
paddle/operators/recurrent_network_op.cc
浏览文件 @
e2fd2bd0
...
@@ -312,13 +312,14 @@ public:
...
@@ -312,13 +312,14 @@ public:
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
const
auto
&
name
=
RecurrentOp
::
kArgName
;
const
auto
&
name
=
RecurrentOp
::
kArgName
;
// inputs and outputs stored in proto
// inputs and outputs stored in proto
AddInputs
(
name
.
inlinks
,
AddInput
(
name
.
inlinks
,
"the input that need to be segmented for each step."
)
"the input that need to be segmented for each step."
);
.
SetMultiple
();
AddInputs
(
name
.
boot_memories
,
"variables to initialize memories."
);
AddInput
(
name
.
boot_memories
,
"variables to initialize memories."
)
.
SetMultiple
();
AddInput
(
name
.
step_net
,
"network shared by all steps."
);
AddInput
(
name
.
step_net
,
"network shared by all steps."
);
AddOutput
s
(
name
.
outlinks
,
AddOutput
(
name
.
outlinks
,
"the output that need to concated for all steps."
)
"the output that need to concated for all steps."
);
.
SetMultiple
(
);
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
// Attributes stored in AttributeMap
// Attributes stored in AttributeMap
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录