Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
eb612a82
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eb612a82
编写于
12月 22, 2017
作者:
Y
Yu Yang
提交者:
GitHub
12月 22, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6181 from emailweixu/enforce_drop_empty_ig
Enforce drop_empty_grad=false When the input of an op is duplicable.
上级
abde3130
0bfa1f7c
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
66 addition
and
27 deletion
+66
-27
paddle/framework/grad_op_desc_maker.h
paddle/framework/grad_op_desc_maker.h
+18
-0
paddle/framework/op_desc.h
paddle/framework/op_desc.h
+2
-0
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+30
-13
paddle/operators/concat_op.cc
paddle/operators/concat_op.cc
+2
-2
paddle/operators/conditional_block_op.cc
paddle/operators/conditional_block_op.cc
+3
-2
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+1
-1
paddle/operators/sequence_concat_op.cc
paddle/operators/sequence_concat_op.cc
+7
-6
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+3
-3
未找到文件。
paddle/framework/grad_op_desc_maker.h
浏览文件 @
eb612a82
...
...
@@ -22,6 +22,14 @@
namespace
paddle
{
namespace
framework
{
/*
This functor class is responsible for creating the gradient ops for the given
operator fwd_op. After it is called (through operator()), the pairs of
(gradient variable, corresponding input variable of fwd_op) will be added to
grad_to_var. If an input variable of fwd_op is contained in no_grad_set, its
gradient varialbe will be ignored or kEmptyVarName depending on the template
argument DropEmptyIG in the derived classes.
*/
class
GradOpDescMakerBase
{
public:
explicit
GradOpDescMakerBase
(
...
...
@@ -56,6 +64,16 @@ class GradOpDescMakerBase {
if
(
!
drop_empty_grad
)
{
return
ret_val
;
}
PADDLE_ENFORCE_LE
(
var_names
.
size
(),
1UL
,
"BUG from operator developer:"
" for input argument with a list of variables, "
" drop_empty_grad is not allowed because it makes"
" the correspondence bewteen a variable and its gradient"
" ambiguous. Use REGISTER_OP_EX to register the op"
" or call InputGrad(?,false) in GradOpDescMaker."
" Op type %s"
,
fwd_op_
.
Type
());
std
::
vector
<
std
::
string
>
dropped_ret_val
;
dropped_ret_val
.
reserve
(
ret_val
.
size
());
std
::
copy_if
(
ret_val
.
begin
(),
ret_val
.
end
(),
...
...
paddle/framework/op_desc.h
浏览文件 @
eb612a82
...
...
@@ -127,7 +127,9 @@ class OpDesc {
}
proto
::
OpDesc
desc_
;
// input arg name => output variable names
VariableNameMap
inputs_
;
// output arg name => output variable names
VariableNameMap
outputs_
;
AttributeMap
attrs_
;
...
...
paddle/framework/op_registry.h
浏览文件 @
eb612a82
...
...
@@ -126,6 +126,14 @@ class OpKernelRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \
msg)
/*
The variadic arguments should be class types derived from one of the
following classes:
OpProtoAndCheckerMaker
GradOpDescMakerBase
VarTypeInference
InferShapeBase
*/
#define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \
...
...
@@ -144,15 +152,24 @@ class OpKernelRegistrar : public Registrar {
}
/**
* Macro to register Operator.
* Macro to register Operator. When the input is duplicable, you should
* use REGISTER_OP_EX with deop_empty_grad=false instead.
*/
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class, true)
// When an argument is duplicable, we need to use this version.
// Perhaps we can omit DropEmptyIG template parameter and
// only have one version of REGISTER_OP.
#define REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class, drop_empty_grad) \
REGISTER_OPERATOR(grad_op_type, grad_op_class); \
class _GradOpDescMaker_##grad_op_type##_ \
: public ::paddle::framework::DefaultGradOpDescMaker<
true> {
\
: public ::paddle::framework::DefaultGradOpDescMaker<
drop_empty_grad> {
\
using ::paddle::framework::DefaultGradOpDescMaker< \
true>::DefaultGradOpDescMaker;
\
drop_empty_grad>::DefaultGradOpDescMaker;
\
\
protected: \
virtual std::string GradOpType() const { return #grad_op_type; } \
...
...
paddle/operators/concat_op.cc
浏览文件 @
eb612a82
...
...
@@ -98,8 +98,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
concat
,
ops
::
ConcatOp
,
ops
::
ConcatOpMaker
,
concat_grad
,
ops
::
ConcatOpGrad
)
REGISTER_OP
_EX
(
concat
,
ops
::
ConcatOp
,
ops
::
ConcatOpMaker
,
concat_grad
,
ops
::
ConcatOpGrad
,
false
)
REGISTER_OP_CPU_KERNEL
(
concat
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
)
REGISTER_OP_CPU_KERNEL
(
concat_grad
,
...
...
paddle/operators/conditional_block_op.cc
浏览文件 @
eb612a82
...
...
@@ -178,8 +178,9 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
grad_op
->
SetInput
(
"Out"
,
Output
(
"Out"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"Scope"
,
Output
(
"Scope"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"Params"
),
InputGrad
(
"Params"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
,
false
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"Params"
),
InputGrad
(
"Params"
,
false
));
grad_op
->
SetBlockAttr
(
"sub_block"
,
*
this
->
grad_block_
[
0
]);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
eb612a82
...
...
@@ -570,7 +570,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
for
(
auto
&
input_param
:
this
->
InputNames
())
{
grad
->
SetInput
(
input_param
,
this
->
Input
(
input_param
));
grad
->
SetOutput
(
framework
::
GradVarName
(
input_param
),
this
->
InputGrad
(
input_param
));
this
->
InputGrad
(
input_param
,
false
));
}
for
(
auto
&
output_param
:
this
->
OutputNames
())
{
...
...
paddle/operators/sequence_concat_op.cc
浏览文件 @
eb612a82
...
...
@@ -124,8 +124,9 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
sequence_concat
,
ops
::
SequenceConcatOp
,
ops
::
SequenceConcatOpMaker
,
sequence_concat_grad
,
ops
::
SequenceConcatGradOp
);
REGISTER_OP_EX
(
sequence_concat
,
ops
::
SequenceConcatOp
,
ops
::
SequenceConcatOpMaker
,
sequence_concat_grad
,
ops
::
SequenceConcatGradOp
,
false
);
REGISTER_OP_CPU_KERNEL
(
sequence_concat
,
ops
::
SequenceConcatOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
...
...
paddle/operators/sum_op.cc
浏览文件 @
eb612a82
...
...
@@ -170,7 +170,7 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDesc
>>
operator
()()
const
override
{
auto
x_grads
=
InputGrad
(
"X"
);
auto
x_grads
=
InputGrad
(
"X"
,
false
);
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDesc
>>
grad_ops
;
grad_ops
.
reserve
(
x_grads
.
size
());
auto
og
=
OutputGrad
(
"Out"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录