Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e2fd2bd0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e2fd2bd0
编写于
8月 01, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Follow comments and merge develop
上级
80baf861
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
60 addition
and
74 deletion
+60
-74
paddle/framework/backward.cc
paddle/framework/backward.cc
+40
-52
paddle/framework/backward.h
paddle/framework/backward.h
+2
-6
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+0
-1
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+10
-8
paddle/operators/fill_zeros_like_op.h
paddle/operators/fill_zeros_like_op.h
+2
-2
paddle/operators/recurrent_network_op.cc
paddle/operators/recurrent_network_op.cc
+6
-5
未找到文件。
paddle/framework/backward.cc
浏览文件 @
e2fd2bd0
...
@@ -31,88 +31,74 @@ static bool AllInSet(const std::vector<std::string>& names,
...
@@ -31,88 +31,74 @@ static bool AllInSet(const std::vector<std::string>& names,
return
true
;
return
true
;
}
}
static
std
::
vector
<
size_t
>
InSetIdx
(
static
std
::
shared_ptr
<
OperatorBase
>
NOP
()
{
const
std
::
vector
<
std
::
string
>&
names
,
const
std
::
string
&
suffix
,
const
std
::
unordered_set
<
std
::
string
>&
set
)
{
std
::
vector
<
size_t
>
ret_val
;
ret_val
.
reserve
(
names
.
size
());
for
(
size_t
i
=
0
;
i
<
names
.
size
();
++
i
)
{
if
(
set
.
find
(
names
[
i
]
+
suffix
)
!=
set
.
end
())
{
ret_val
.
push_back
(
i
);
}
}
return
ret_val
;
}
static
std
::
shared_ptr
<
OperatorBase
>
EmptyOp
()
{
auto
net_op
=
std
::
make_shared
<
NetOp
>
();
auto
net_op
=
std
::
make_shared
<
NetOp
>
();
net_op
->
type_
=
"@
EMPTY_
OP@"
;
net_op
->
type_
=
"@
N
OP@"
;
net_op
->
CompleteAddOp
();
net_op
->
CompleteAddOp
();
return
net_op
;
return
net_op
;
}
}
/**
// Get backward operator from a forward operator, recursively implementation.
* @brief Backward an operator, implementation
//
* @param forwardOp the forward operator
// no_grad_names the gradient variable names without gradient calculating.
* @param no_grad_names variable names not calculate for gradient. Like X@GRAD
//
* is not needed.
// uniq_id is a unique index used inside recursively calling BackwardRecursive.
* @param uniq_id a unique index used inside BackwardImpl, it will be shared
// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through
* through recursive invoke.
// recursive calling.
* @return The backward operator. For simple situation, it is a simple operator.
//
* For complex situation, it is a NetOp.
// returns The backward operator. For simple situation, it is a simple
*
// operator. For complex situation, it is a NetOp.
* See Backward.h for details
//
*/
// See Backward.h for details
static
std
::
shared_ptr
<
OperatorBase
>
BackwardImpl
(
static
std
::
shared_ptr
<
OperatorBase
>
BackwardRecursive
(
const
OperatorBase
&
forwardOp
,
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
);
std
::
shared_ptr
<
OperatorBase
>
BackwardRecursive
(
const
OperatorBase
&
forwardOp
,
const
OperatorBase
&
forwardOp
,
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
)
{
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
)
{
/**
// If all input gradients of forwarding operator do not need to calculate,
* If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take
* just return an EmptyOp. Not return null ptr because EmptyOp does not take
// too much time for calculation, but it is useful for simplifying logic.
* too much time for calculation, but it is useful for simplifying logic.
*/
if
(
AllInSet
(
forwardOp
.
inputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
if
(
AllInSet
(
forwardOp
.
inputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
no_grad_names
))
{
no_grad_names
))
{
return
EmptyOp
();
return
NOP
();
}
}
/**
// All output gradients of forwarding operator do not need to calculate. Then
* All output gradients of forwarding operator do not need to calculate. Then
// all input gradients cannot be computed at all, and we put them into
* all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
* `no_grad_names` set. Return an EmptyOp.
*/
if
(
AllInSet
(
forwardOp
.
outputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
if
(
AllInSet
(
forwardOp
.
outputs_
,
OperatorBase
::
GRAD_VAR_SUFFIX
(),
no_grad_names
))
{
no_grad_names
))
{
for
(
auto
&
name
:
forwardOp
.
inputs_
)
{
for
(
auto
&
name
:
forwardOp
.
inputs_
)
{
//
/
Mark all input is not need
// Mark all input is not need
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
}
}
return
EmptyOp
();
return
NOP
();
}
}
//
!
Returned gradient network
// Returned gradient network
auto
net
=
std
::
make_shared
<
NetOp
>
();
auto
net
=
std
::
make_shared
<
NetOp
>
();
if
(
forwardOp
.
IsNetOp
())
{
if
(
forwardOp
.
IsNetOp
())
{
//
/
Because forwardOp is a net op, it can static_cast.
// Because forwardOp is a net op, it can static_cast.
auto
&
forwardNet
=
static_cast
<
const
NetOp
&>
(
forwardOp
);
auto
&
forwardNet
=
static_cast
<
const
NetOp
&>
(
forwardOp
);
//
!
Map from output gradient variable name to operator's indices in backward
// Map from output gradient variable name to operator's indices in backward
//
!
net. That operator generates that variable.
// net. That operator generates that variable.
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
size_t
>>
dup_output_ops
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
size_t
>>
dup_output_ops
;
size_t
local_op_id
=
0
;
size_t
local_op_id
=
0
;
//
/
reversely travel forwardNet
// reversely travel forwardNet
for
(
auto
it
=
forwardNet
.
ops_
.
rbegin
();
it
!=
forwardNet
.
ops_
.
rend
();
for
(
auto
it
=
forwardNet
.
ops_
.
rbegin
();
it
!=
forwardNet
.
ops_
.
rend
();
++
it
,
++
local_op_id
)
{
++
it
,
++
local_op_id
)
{
auto
fwd
=
*
it
;
auto
fwd
=
*
it
;
auto
bwd
=
Backward
Impl
(
*
fwd
,
no_grad_names
,
uniq_id
);
auto
bwd
=
Backward
Recursive
(
*
fwd
,
no_grad_names
,
uniq_id
);
net
->
AddOp
(
bwd
);
net
->
AddOp
(
bwd
);
for
(
auto
&
out
:
bwd
->
outputs_
)
{
for
(
auto
&
out
:
bwd
->
outputs_
)
{
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
}
}
}
}
//
/
Get unique ID for this method.
// Get unique ID for this method.
auto
uid
=
uniq_id
++
;
auto
uid
=
uniq_id
++
;
// TODO(dzh): more comment
// TODO(dzh): more comment
using
Pos
=
std
::
pair
<
size_t
,
std
::
shared_ptr
<
OperatorBase
>>
;
using
Pos
=
std
::
pair
<
size_t
,
std
::
shared_ptr
<
OperatorBase
>>
;
...
@@ -145,13 +131,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
...
@@ -145,13 +131,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
}
}
}
else
{
}
else
{
//! TODO(fjy)
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
for
(
std
::
string
&
grad_input
:
grad_op
->
inputs_
)
{
for
(
std
::
string
&
grad_input
:
grad_op
->
inputs_
)
{
if
(
no_grad_names
.
count
(
grad_input
))
{
if
(
no_grad_names
.
count
(
grad_input
))
{
std
::
string
prefix
=
grad_input
.
substr
(
std
::
string
prefix
=
grad_input
.
substr
(
0
,
grad_input
.
size
()
-
OperatorBase
::
GRAD_VAR_SUFFIX
().
size
());
0
,
grad_input
.
size
()
-
OperatorBase
::
GRAD_VAR_SUFFIX
().
size
());
grad_input
=
prefix
+
OperatorBase
::
ZERO_VAR_SUFFIX
();
grad_input
=
prefix
+
OperatorBase
::
ZERO_VAR_SUFFIX
();
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{
prefix
},
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{
prefix
},
{
grad_input
},
{}));
{
grad_input
},
{}));
}
}
...
@@ -173,8 +161,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
...
@@ -173,8 +161,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return
net
;
return
net
;
}
}
//
!
See header for comments
// See header for comments
extern
std
::
shared_ptr
<
OperatorBase
>
Backward
(
std
::
shared_ptr
<
OperatorBase
>
Backward
(
const
OperatorBase
&
forwardOp
,
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
std
::
unordered_set
<
std
::
string
>
no_grad_names
;
std
::
unordered_set
<
std
::
string
>
no_grad_names
;
...
@@ -184,7 +172,7 @@ extern std::shared_ptr<OperatorBase> Backward(
...
@@ -184,7 +172,7 @@ extern std::shared_ptr<OperatorBase> Backward(
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
no_grad_names
.
insert
(
name
+
OperatorBase
::
GRAD_VAR_SUFFIX
());
}
}
size_t
uid
=
0
;
size_t
uid
=
0
;
return
Backward
Impl
(
forwardOp
,
no_grad_names
,
uid
);
return
Backward
Recursive
(
forwardOp
,
no_grad_names
,
uid
);
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/backward.h
浏览文件 @
e2fd2bd0
...
@@ -18,12 +18,8 @@
...
@@ -18,12 +18,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
/**
// Create the backward operator from a forward operator.
* @brief
// TODO(yuyang18): Add more API reference comment.
* @param forwardOp
* @param no_grad_vars ignored input name of forward
* @return
*/
extern
std
::
shared_ptr
<
OperatorBase
>
Backward
(
extern
std
::
shared_ptr
<
OperatorBase
>
Backward
(
const
OperatorBase
&
forwardOp
,
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
...
...
paddle/framework/backward_test.cc
浏览文件 @
e2fd2bd0
...
@@ -169,7 +169,6 @@ TEST(Backward, simple_op_grad) {
...
@@ -169,7 +169,6 @@ TEST(Backward, simple_op_grad) {
ASSERT_EQ
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
(),
ASSERT_EQ
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
(),
gop
->
Output
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
()));
gop
->
Output
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
()));
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
}
TEST
(
Backward
,
simple_op_not_need_grad
)
{
TEST
(
Backward
,
simple_op_not_need_grad
)
{
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
e2fd2bd0
...
@@ -21,15 +21,17 @@ namespace operators {
...
@@ -21,15 +21,17 @@ namespace operators {
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{
PADDLE_ENFORCE
(
inputs
.
size
()
==
1
,
"Input size of FillZerosLikeOp must be one."
);
"Input size of FillZerosLikeOp must be one."
);
PADDLE_ENFORCE
(
outputs
.
size
()
==
1
,
"Output size of AddOp must be one."
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1UL
,
PADDLE_ENFORCE
(
inputs
[
0
]
!=
nullptr
&&
outputs
[
0
]
!=
nullptr
,
"Output size of AddOp must be one."
);
"Outputs of FillZerosLikeOp must all be set."
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
0
)
!=
nullptr
,
outputs
[
0
]
->
Resize
(
inputs
[
0
]
->
dims
());
"Input of FillZerosLikeOp must be set."
);
PADDLE_ENFORCE
(
ctx
.
OutputVar
(
0
)
!=
nullptr
,
"Output of FillZerosLikeOp must be set."
);
ctx
.
Output
<
framework
::
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
framework
::
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/fill_zeros_like_op.h
浏览文件 @
e2fd2bd0
...
@@ -23,8 +23,8 @@ namespace operators {
...
@@ -23,8 +23,8 @@ namespace operators {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
FillZerosLikeKernel
:
public
framework
::
OpKernel
{
class
FillZerosLikeKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
Kernel
Context
&
context
)
const
override
{
void
Compute
(
const
framework
::
Execution
Context
&
context
)
const
override
{
auto
*
output
=
context
.
Output
(
0
)
->
GetMutable
<
framework
::
Tensor
>
(
);
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
).
setZero
();
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
).
setZero
();
}
}
...
...
paddle/operators/recurrent_network_op.cc
浏览文件 @
e2fd2bd0
...
@@ -312,13 +312,14 @@ public:
...
@@ -312,13 +312,14 @@ public:
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
const
auto
&
name
=
RecurrentOp
::
kArgName
;
const
auto
&
name
=
RecurrentOp
::
kArgName
;
// inputs and outputs stored in proto
// inputs and outputs stored in proto
AddInputs
(
name
.
inlinks
,
AddInput
(
name
.
inlinks
,
"the input that need to be segmented for each step."
)
"the input that need to be segmented for each step."
);
.
SetMultiple
();
AddInputs
(
name
.
boot_memories
,
"variables to initialize memories."
);
AddInput
(
name
.
boot_memories
,
"variables to initialize memories."
)
.
SetMultiple
();
AddInput
(
name
.
step_net
,
"network shared by all steps."
);
AddInput
(
name
.
step_net
,
"network shared by all steps."
);
AddOutput
s
(
name
.
outlinks
,
AddOutput
(
name
.
outlinks
,
"the output that need to concated for all steps."
)
"the output that need to concated for all steps."
);
.
SetMultiple
(
);
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
// Attributes stored in AttributeMap
// Attributes stored in AttributeMap
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录