Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e2fd2bd0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e2fd2bd0
编写于
8月 01, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Follow comments and merge develop
上级
80baf861
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
60 addition
and
74 deletion
+60
-74
paddle/framework/backward.cc
paddle/framework/backward.cc
+40
-52
paddle/framework/backward.h
paddle/framework/backward.h
+2
-6
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+0
-1
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+10
-8
paddle/operators/fill_zeros_like_op.h
paddle/operators/fill_zeros_like_op.h
+2
-2
paddle/operators/recurrent_network_op.cc
paddle/operators/recurrent_network_op.cc
+6
-5
未找到文件。
paddle/framework/backward.cc
浏览文件 @
e2fd2bd0
...
...
@@ -31,88 +31,74 @@ static bool AllInSet(const std::vector<std::string>& names,
return
true
;
}
static
std
::
vector
<
size_t
>
InSetIdx
(
const
std
::
vector
<
std
::
string
>&
names
,
const
std
::
string
&
suffix
,
const
std
::
unordered_set
<
std
::
string
>&
set
)
{
std
::
vector
<
size_t
>
ret_val
;
ret_val
.
reserve
(
names
.
size
());
for
(
size_t
i
=
0
;
i
<
names
.
size
();
++
i
)
{
if
(
set
.
find
(
names
[
i
]
+
suffix
)
!=
set
.
end
())
{
ret_val
.
push_back
(
i
);
}
}
return
ret_val
;
}
static
std
::
shared_ptr
<
OperatorBase
>
EmptyOp
()
{
static
std
::
shared_ptr
<
OperatorBase
>
NOP
()
{
auto
net_op
=
std
::
make_shared
<
NetOp
>
();
net_op
->
type_
=
"@
EMPTY_
OP@"
;
net_op
->
type_
=
"@
N
OP@"
;
net_op
->
CompleteAddOp
();
return
net_op
;
}
/**
* @brief Backward an operator, implementation
* @param forwardOp the forward operator
* @param no_grad_names variable names not calculate for gradient. Like X@GRAD
* is not needed.
* @param uniq_id a unique index used inside BackwardImpl, it will be shared
* through recursive invoke.
* @return The backward operator. For simple situation, it is a simple operator.
* For complex situation, it is a NetOp.
*
* See Backward.h for details
*/
static
std
::
shared_ptr
<
OperatorBase
>
BackwardImpl
(
// Get backward operator from a forward operator, recursively implementation.
//
// no_grad_names the gradient variable names without gradient calculating.
//
// uniq_id is a unique index used inside recursively calling BackwardRecursive.
// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through
// recursive calling.
//
// returns The backward operator. For simple situation, it is a simple
// operator. For complex situation, it is a NetOp.
//
// See Backward.h for details
static
std
::
shared_ptr
<
OperatorBase
>
BackwardRecursive
(
const
OperatorBase
&
forwardOp
,
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
);
std
::
shared_ptr
<
OperatorBase
>
BackwardRecursive
(
const
OperatorBase
&
forwardOp
,
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
)
{
/**
* If all input gradients of forwarding operator do not need to calculate,
* just return an EmptyOp. Not return null ptr because EmptyOp does not take
* too much time for calculation, but it is useful for simplifying logic.
*/
// If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic.
if
(
AllInSet
(
forwardOp
.
inputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
no_grad_names
))
{
return
EmptyOp
();
return
NOP
();
}
/**
* All output gradients of forwarding operator do not need to calculate. Then
* all input gradients cannot be computed at all, and we put them into
* `no_grad_names` set. Return an EmptyOp.
*/
// All output gradients of forwarding operator do not need to calculate. Then
// all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
if
(
AllInSet
(
forwardOp
.
outputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
no_grad_names
))
{
for
(
auto
&
name
:
forwardOp
.
inputs_
)
{
//
/
Mark all input is not need
// Mark all input is not need
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
}
return
EmptyOp
();
return
NOP
();
}
//
!
Returned gradient network
// Returned gradient network
auto
net
=
std
::
make_shared
<
NetOp
>
();
if
(
forwardOp
.
IsNetOp
())
{
//
/
Because forwardOp is a net op, it can static_cast.
// Because forwardOp is a net op, it can static_cast.
auto
&
forwardNet
=
static_cast
<
const
NetOp
&>
(
forwardOp
);
//
!
Map from output gradient variable name to operator's indices in backward
//
!
net. That operator generates that variable.
// Map from output gradient variable name to operator's indices in backward
// net. That operator generates that variable.
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
size_t
>>
dup_output_ops
;
size_t
local_op_id
=
0
;
//
/
reversely travel forwardNet
// reversely travel forwardNet
for
(
auto
it
=
forwardNet
.
ops_
.
rbegin
();
it
!=
forwardNet
.
ops_
.
rend
();
++
it
,
++
local_op_id
)
{
auto
fwd
=
*
it
;
auto
bwd
=
Backward
Impl
(
*
fwd
,
no_grad_names
,
uniq_id
);
auto
bwd
=
Backward
Recursive
(
*
fwd
,
no_grad_names
,
uniq_id
);
net
->
AddOp
(
bwd
);
for
(
auto
&
out
:
bwd
->
outputs_
)
{
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
}
}
//
/
Get unique ID for this method.
// Get unique ID for this method.
auto
uid
=
uniq_id
++
;
// TODO(dzh): more comment
using
Pos
=
std
::
pair
<
size_t
,
std
::
shared_ptr
<
OperatorBase
>>
;
...
...
@@ -145,13 +131,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
}
}
else
{
//! TODO(fjy)
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
for
(
std
::
string
&
grad_input
:
grad_op
->
inputs_
)
{
if
(
no_grad_names
.
count
(
grad_input
))
{
std
::
string
prefix
=
grad_input
.
substr
(
0
,
grad_input
.
size
()
-
OperatorBase
::
GRAD_VAR_SUFFIX
().
size
());
grad_input
=
prefix
+
OperatorBase
::
ZERO_VAR_SUFFIX
();
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{
prefix
},
{
grad_input
},
{}));
}
...
...
@@ -173,8 +161,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return
net
;
}
//
!
See header for comments
extern
std
::
shared_ptr
<
OperatorBase
>
Backward
(
// See header for comments
std
::
shared_ptr
<
OperatorBase
>
Backward
(
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
std
::
unordered_set
<
std
::
string
>
no_grad_names
;
...
...
@@ -184,7 +172,7 @@ extern std::shared_ptr<OperatorBase> Backward(
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
}
size_t
uid
=
0
;
return
Backward
Impl
(
forwardOp
,
no_grad_names
,
uid
);
return
Backward
Recursive
(
forwardOp
,
no_grad_names
,
uid
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/backward.h
浏览文件 @
e2fd2bd0
...
...
@@ -18,12 +18,8 @@
namespace
paddle
{
namespace
framework
{
/**
* @brief
* @param forwardOp
* @param no_grad_vars ignored input name of forward
* @return
*/
// Create the backward operator from a forward operator.
// TODO(yuyang18): Add more API reference comment.
extern
std
::
shared_ptr
<
OperatorBase
>
Backward
(
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
...
...
paddle/framework/backward_test.cc
浏览文件 @
e2fd2bd0
...
...
@@ -169,7 +169,6 @@ TEST(Backward, simple_op_grad) {
ASSERT_EQ
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
(),
gop
->
Output
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
()));
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
TEST
(
Backward
,
simple_op_not_need_grad
)
{
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
e2fd2bd0
...
...
@@ -21,15 +21,17 @@ namespace operators {
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{
PADDLE_ENFORCE
(
inputs
.
size
()
==
1
,
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
"Input size of FillZerosLikeOp must be one."
);
PADDLE_ENFORCE
(
outputs
.
size
()
==
1
,
"Output size of AddOp must be one."
);
PADDLE_ENFORCE
(
inputs
[
0
]
!=
nullptr
&&
outputs
[
0
]
!=
nullptr
,
"Outputs of FillZerosLikeOp must all be set."
);
outputs
[
0
]
->
Resize
(
inputs
[
0
]
->
dims
());
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1UL
,
"Output size of AddOp must be one."
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
0
)
!=
nullptr
,
"Input of FillZerosLikeOp must be set."
);
PADDLE_ENFORCE
(
ctx
.
OutputVar
(
0
)
!=
nullptr
,
"Output of FillZerosLikeOp must be set."
);
ctx
.
Output
<
framework
::
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
framework
::
Tensor
>
(
0
)
->
dims
());
}
};
...
...
paddle/operators/fill_zeros_like_op.h
浏览文件 @
e2fd2bd0
...
...
@@ -23,8 +23,8 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
class
FillZerosLikeKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
Kernel
Context
&
context
)
const
override
{
auto
*
output
=
context
.
Output
(
0
)
->
GetMutable
<
framework
::
Tensor
>
(
);
void
Compute
(
const
framework
::
Execution
Context
&
context
)
const
override
{
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
).
setZero
();
}
...
...
paddle/operators/recurrent_network_op.cc
浏览文件 @
e2fd2bd0
...
...
@@ -312,13 +312,14 @@ public:
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
const
auto
&
name
=
RecurrentOp
::
kArgName
;
// inputs and outputs stored in proto
AddInputs
(
name
.
inlinks
,
"the input that need to be segmented for each step."
);
AddInputs
(
name
.
boot_memories
,
"variables to initialize memories."
);
AddInput
(
name
.
inlinks
,
"the input that need to be segmented for each step."
)
.
SetMultiple
();
AddInput
(
name
.
boot_memories
,
"variables to initialize memories."
)
.
SetMultiple
();
AddInput
(
name
.
step_net
,
"network shared by all steps."
);
AddOutput
s
(
name
.
outlinks
,
"the output that need to concated for all steps."
);
AddOutput
(
name
.
outlinks
,
"the output that need to concated for all steps."
)
.
SetMultiple
(
);
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
// Attributes stored in AttributeMap
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录