Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
SummerGao.
Paddle
提交
46c551b2
P
Paddle
项目概览
SummerGao.
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
46c551b2
编写于
10月 02, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Complete Register Gradient in compile time
上级
479e4a50
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
152 addition
and
94 deletion
+152
-94
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+21
-11
paddle/framework/details/op_registry.h
paddle/framework/details/op_registry.h
+0
-1
paddle/framework/framework.proto
paddle/framework/framework.proto
+0
-1
paddle/framework/op_info.h
paddle/framework/op_info.h
+3
-0
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+0
-1
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+19
-2
paddle/operators/minus_op.cc
paddle/operators/minus_op.cc
+22
-24
paddle/operators/pad_op.cc
paddle/operators/pad_op.cc
+19
-3
paddle/operators/scale_op.cc
paddle/operators/scale_op.cc
+15
-18
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+32
-13
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+21
-20
未找到文件。
paddle/framework/backward_test.cc
浏览文件 @
46c551b2
...
...
@@ -21,24 +21,34 @@
namespace
paddle
{
namespace
framework
{
using
OperatorBase
=
framework
::
OperatorBase
;
using
OpProtoAndCheckerMaker
=
framework
::
OpProtoAndCheckerMaker
;
using
OpProto
=
framework
::
OpProto
;
using
OpAttrChecker
=
framework
::
OpAttrChecker
;
using
Scope
=
framework
::
Scope
;
using
DeviceContext
=
platform
::
DeviceContext
;
class
RowWiseAddOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input X of Add"
)
.
NotInGradient
()
;
AddInput
(
"b"
,
"Bias of Add"
)
.
NotInGradient
()
;
AddOutput
(
"Out"
,
"Out of Add"
)
.
NotInGradient
()
;
AddInput
(
"X"
,
"Input X of Add"
);
AddInput
(
"b"
,
"Bias of Add"
);
AddOutput
(
"Out"
,
"Out of Add"
);
AddComment
(
"Add Op"
);
}
};
class
RowWiseAddGradMaker
:
public
SingleGradOpDescMaker
{
public:
using
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
OpDescBind
Apply
()
const
override
{
OpDescBind
grad_op
;
grad_op
.
SetInput
(
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
.
SetOutput
(
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
.
SetOutput
(
GradVarName
(
"b"
),
InputGrad
(
"b"
));
grad_op
.
SetType
(
"rowwise_add_grad"
);
return
grad_op
;
}
};
class
MulOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
...
...
@@ -148,8 +158,9 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace
f
=
paddle
::
framework
;
namespace
ops
=
paddle
::
operators
;
using
EnforceNotMet
=
paddle
::
platform
::
EnforceNotMet
;
REGISTER_OP
(
rowwise_add
,
f
::
NOP
,
f
::
RowWiseAddOpMaker
,
rowwise_add_grad
,
f
::
NOP
);
REGISTER_OPERATOR
(
rowwise_add
,
f
::
NOP
,
f
::
RowWiseAddOpMaker
,
f
::
RowWiseAddGradMaker
);
REGISTER_OPERATOR
(
rowwise_add_grad
,
f
::
NOP
);
REGISTER_OP
(
mul
,
f
::
NOP
,
f
::
MulOpMaker
,
mul_grad
,
f
::
NOP
);
REGISTER_OP
(
sigmoid
,
f
::
NOP
,
f
::
SigmoidOpMaker
,
sigmoid_grad
,
f
::
NOP
);
REGISTER_OP_WITHOUT_GRADIENT
(
nograd
,
f
::
NOP
,
f
::
NoGradOpMaker
);
...
...
@@ -378,7 +389,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
+
1UL
/* external output number*/
+
1UL
/* number of gradient of external output*/
+
2U
/* internal variable number*/
);
std
::
cerr
<<
grad_fc
.
DebugString
()
<<
std
::
endl
;
EXPECT_EQ
(
grad_fc
.
Outputs
(
all
).
size
(),
2UL
/* input number of mul*/
...
...
paddle/framework/details/op_registry.h
浏览文件 @
46c551b2
...
...
@@ -85,7 +85,6 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
info
->
proto_
=
new
OpProto
;
info
->
checker_
=
new
OpAttrChecker
();
auto
maker
=
T
(
info
->
proto_
,
info
->
checker_
);
std
::
cerr
<<
"Assign Maker "
<<
op_type
<<
std
::
endl
;
maker
.
Validate
();
info
->
proto_
->
set_type
(
op_type
);
PADDLE_ENFORCE
(
...
...
paddle/framework/framework.proto
浏览文件 @
46c551b2
...
...
@@ -66,7 +66,6 @@ message OpProto {
optional
bool
duplicable
=
3
[
default
=
false
];
optional
bool
intermediate
=
4
[
default
=
false
];
optional
bool
not_in_gradient
=
5
[
default
=
false
];
}
// AttrProto describes the C++ type Attribute.
...
...
paddle/framework/op_info.h
浏览文件 @
46c551b2
...
...
@@ -17,11 +17,14 @@
#include <map>
#include <string>
#include <unordered_map>
#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/type_defs.h"
#include "paddle/platform/macros.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
framework
{
...
...
paddle/framework/op_registry.h
浏览文件 @
46c551b2
...
...
@@ -46,7 +46,6 @@ class Registrar {
template
<
typename
...
ARGS
>
struct
OperatorRegistrar
:
public
Registrar
{
explicit
OperatorRegistrar
(
const
char
*
op_type
)
:
op_type
(
op_type
)
{
std
::
cerr
<<
"Reg operator "
<<
op_type
<<
std
::
endl
;
PADDLE_ENFORCE
(
!
OpInfoMap
::
Instance
().
Has
(
op_type
),
"'%s' is registered more than once."
,
op_type
);
static_assert
(
sizeof
...(
ARGS
)
!=
0
,
...
...
paddle/operators/mean_op.cc
浏览文件 @
46c551b2
...
...
@@ -36,7 +36,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of mean op"
);
AddOutput
(
"Out"
,
"The output of mean op"
)
.
NotInGradient
()
;
AddOutput
(
"Out"
,
"The output of mean op"
);
AddComment
(
R"DOC( Mean Operator
)DOC"
);
}
...
...
@@ -52,11 +52,28 @@ class MeanGradOp : public framework::OperatorWithKernel {
}
};
class
MeanGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
framework
::
OpDescBind
Apply
()
const
override
{
framework
::
OpDescBind
grad_op
;
grad_op
.
SetType
(
"mean_grad"
);
grad_op
.
SetInput
(
"X"
,
Input
(
"X"
));
grad_op
.
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
.
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
return
grad_op
;
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
mean
,
ops
::
MeanOp
,
ops
::
MeanOpMaker
,
mean_grad
,
ops
::
MeanGradOp
);
REGISTER_OPERATOR
(
mean
,
ops
::
MeanOp
,
ops
::
MeanOpMaker
,
ops
::
MeanGradMaker
);
REGISTER_OPERATOR
(
mean_grad
,
ops
::
MeanGradOp
);
REGISTER_OP_CPU_KERNEL
(
mean
,
ops
::
MeanKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
mean_grad
,
...
...
paddle/operators/minus_op.cc
浏览文件 @
46c551b2
...
...
@@ -49,9 +49,9 @@ class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MinusOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The left tensor of minus operator."
)
.
NotInGradient
()
;
AddInput
(
"Y"
,
"The right tensor of minus operator."
)
.
NotInGradient
()
;
AddOutput
(
"Out"
,
"The output tensor of minus operator."
)
.
NotInGradient
()
;
AddInput
(
"X"
,
"The left tensor of minus operator."
);
AddInput
(
"Y"
,
"The right tensor of minus operator."
);
AddOutput
(
"Out"
,
"The output tensor of minus operator."
);
AddComment
(
R"DOC(Minus Operator
...
...
@@ -64,26 +64,25 @@ or not. But the output only shares the LoD with input `X`.
)DOC"
);
}
};
template
<
typename
AttrType
>
class
MinusGrad
Op
:
public
NetOp
{
class
MinusGrad
Maker
:
public
framework
::
GradOpDescMakerBase
{
public:
MinusGradOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
auto
out_grad
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
x_grad
=
Output
(
framework
::
GradVarName
(
"X"
));
auto
y_grad
=
Output
(
framework
::
GradVarName
(
"Y"
));
// x_grad = out_grad
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"identity"
,
{{
"X"
,
{
out_grad
}}},
{{
"Y"
,
{
x_grad
}}},
{}));
framework
::
AttributeMap
scale_attr
;
scale_attr
[
"scale"
]
=
static_cast
<
AttrType
>
(
-
1
);
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"scale"
,
{{
"X"
,
{
out_grad
}}},
{{
"Out"
,
{
y_grad
}}},
scale_attr
));
CompleteAddOp
(
false
);
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
framework
::
OpDescBind
>
operator
()()
const
override
{
std
::
vector
<
framework
::
OpDescBind
>
ops
;
ops
.
resize
(
2
);
ops
[
0
].
SetType
(
"scale"
);
ops
[
0
].
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
ops
[
0
].
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
ops
[
0
].
SetAttr
(
"scale"
,
1.0
f
);
ops
[
1
].
SetType
(
"scale"
);
ops
[
1
].
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
ops
[
1
].
SetOutput
(
"Out"
,
InputGrad
(
"Y"
));
ops
[
1
].
SetAttr
(
"scale"
,
-
1.0
f
);
return
ops
;
}
};
...
...
@@ -91,7 +90,6 @@ class MinusGradOp : public NetOp {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
minus
,
ops
::
MinusOp
,
ops
::
MinusOpMaker
,
minus_grad
,
ops
::
MinusGradOp
<
float
>
);
REGISTER_OPERATOR
(
minus
,
ops
::
MinusOp
,
ops
::
MinusOpMaker
,
ops
::
MinusGradMaker
);
REGISTER_OP_CPU_KERNEL
(
minus
,
ops
::
MinusKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/pad_op.cc
浏览文件 @
46c551b2
...
...
@@ -56,8 +56,7 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker {
"The input should be a k-D tensor(k > 0 and k < 7)"
);
AddOutput
(
"Out"
,
"The output of pad op."
"A tensor with the same shape as X."
)
.
NotInGradient
();
"A tensor with the same shape as X."
);
AddComment
(
R"DOC(
Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example:
...
...
@@ -111,11 +110,28 @@ class PadOpGrad : public framework::OperatorWithKernel {
}
};
class
PadOpGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
protected:
framework
::
OpDescBind
Apply
()
const
override
{
framework
::
OpDescBind
bind
;
bind
.
SetInput
(
"X"
,
Input
(
"X"
));
bind
.
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
bind
.
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
bind
.
SetAttrMap
(
Attrs
());
return
bind
;
}
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
pad
,
ops
::
PadOp
,
ops
::
PadOpMaker
,
pad_grad
,
ops
::
PadOpGrad
);
REGISTER_OPERATOR
(
pad
,
ops
::
PadOp
,
ops
::
PadOpMaker
,
ops
::
PadOpGradMaker
);
REGISTER_OPERATOR
(
pad_grad
,
ops
::
PadOpGrad
);
REGISTER_OP_CPU_KERNEL
(
pad
,
ops
::
PadKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
pad_grad
,
ops
::
PadGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/scale_op.cc
浏览文件 @
46c551b2
...
...
@@ -41,8 +41,8 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input tensor of scale operator."
)
.
NotInGradient
()
;
AddOutput
(
"Out"
,
"The output tensor of scale operator."
)
.
NotInGradient
()
;
AddInput
(
"X"
,
"The input tensor of scale operator."
);
AddOutput
(
"Out"
,
"The output tensor of scale operator."
);
AddComment
(
R"DOC(Scale operator
The equation is: Out = scale*X
...
...
@@ -52,21 +52,18 @@ The equation is: Out = scale*X
}
};
// The operator to calculate gradients of a scale operator is just the scale
// operator itself.
// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
template
<
typename
AttrType
>
class
ScaleGradOp
:
public
NetOp
{
class
ScaleGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
ScaleGradOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"scale"
,
{{
"X"
,
{
Input
(
framework
::
GradVarName
(
"Out"
))}}},
{{
"Out"
,
{
Output
(
framework
::
GradVarName
(
"X"
))}}},
{{
"scale"
,
Attr
<
AttrType
>
(
"scale"
)}}));
CompleteAddOp
(
false
);
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
framework
::
OpDescBind
Apply
()
const
override
{
framework
::
OpDescBind
grad_op
;
grad_op
.
SetType
(
"scale"
);
grad_op
.
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
.
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
.
SetAttr
(
"scale"
,
GetAttr
(
"scale"
));
return
grad_op
;
}
};
...
...
@@ -75,7 +72,7 @@ class ScaleGradOp : public NetOp {
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
scale
,
ops
::
ScaleOp
,
ops
::
ScaleOpMaker
<
float
>
,
scale_grad
,
ops
::
ScaleGradOp
<
float
>
);
REGISTER_OP
ERATOR
(
scale
,
ops
::
ScaleOp
,
ops
::
ScaleOpMaker
<
float
>
,
ops
::
ScaleGradMaker
);
REGISTER_OP_CPU_KERNEL
(
scale
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
46c551b2
...
...
@@ -27,15 +27,14 @@ class SoftmaxWithCrossEntropyOpMaker
AddInput
(
"Logits"
,
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number."
)
.
NotInGradient
();
AddInput
(
"Label"
,
"(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
"tensor. "
"If softLable is set to 0, Label is a Tensor<int> with shape [N x 1]. "
"If softLable is set to 1, Label is a Tensor<float/double> "
"with shape [N x K]."
);
"and K is the class number."
);
AddInput
(
"Label"
,
"(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
"tensor. "
"If softLable is set to 0, Label is a Tensor<int> with shape [N x "
"1]. "
"If softLable is set to 1, Label is a Tensor<float/double> "
"with shape [N x K]."
);
AddOutput
(
"Softmax"
,
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
...
...
@@ -163,15 +162,35 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
};
class
SoftmaxGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
framework
::
OpDescBind
Apply
()
const
override
{
framework
::
OpDescBind
grad_op
;
grad_op
.
SetType
(
"softmax_with_cross_entropy_grad"
);
grad_op
.
SetInput
(
"Label"
,
Input
(
"Label"
));
grad_op
.
SetInput
(
"Softmax"
,
Output
(
"Softmax"
));
grad_op
.
SetInput
(
"Loss"
,
Output
(
"Loss"
));
grad_op
.
SetInput
(
framework
::
GradVarName
(
"Softmax"
),
OutputGrad
(
"Softmax"
));
grad_op
.
SetInput
(
framework
::
GradVarName
(
"Loss"
),
OutputGrad
(
"Loss"
));
grad_op
.
SetOutput
(
framework
::
GradVarName
(
"Logits"
),
InputGrad
(
"Logits"
));
grad_op
.
SetAttrMap
(
Attrs
());
return
grad_op
;
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyOp
,
ops
::
SoftmaxWithCrossEntropyOpMaker
,
softmax_with_cross_entropy_grad
,
ops
::
SoftmaxWithCrossEntropyOpGrad
);
REGISTER_OPERATOR
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyOp
,
ops
::
SoftmaxWithCrossEntropyOpMaker
,
ops
::
SoftmaxWithCrossEntropyOpMaker
);
REGISTER_OPERATOR
(
softmax_with_cross_entropy_grad
,
ops
::
SoftmaxWithCrossEntropyOpGrad
);
REGISTER_OP_CPU_KERNEL
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
softmax_with_cross_entropy_grad
,
...
...
paddle/operators/sum_op.cc
浏览文件 @
46c551b2
...
...
@@ -45,10 +45,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"the input tensors of sum operator."
)
.
AsDuplicable
()
.
NotInGradient
();
AddOutput
(
"Out"
,
"the output tensor of sum operator."
).
NotInGradient
();
AddInput
(
"X"
,
"the input tensors of sum operator."
).
AsDuplicable
();
AddOutput
(
"Out"
,
"the output tensor of sum operator."
);
AddComment
(
R"DOC(
Sum the input tensors.
...
...
@@ -58,23 +56,25 @@ or not. But the output only shares the LoD with the first input.
}
};
class
SumGrad
Op
:
public
NetOp
{
class
SumGrad
Maker
:
public
framework
::
GradOpDescMakerBase
{
public:
SumGradOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
auto
&
x_grad_names
=
Outputs
(
framework
::
GradVarName
(
"X"
));
auto
out_grad_name
=
this
->
Input
(
framework
::
GradVarName
(
"Out"
));
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
framework
::
AttributeMap
grad_attrs
;
grad_attrs
[
"scale"
]
=
1.0
f
;
for
(
auto
&
x_grad_name
:
x_grad_names
)
{
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"scale"
,
{{
"X"
,
{
out_grad_name
}}},
{{
"Out"
,
{
x_grad_name
}}},
grad_attrs
));
}
CompleteAddOp
(
false
);
std
::
vector
<
framework
::
OpDescBind
>
operator
()()
const
override
{
auto
x_grads
=
InputGrad
(
"X"
);
std
::
vector
<
framework
::
OpDescBind
>
grad_ops
;
grad_ops
.
reserve
(
x_grads
.
size
());
auto
og
=
OutputGrad
(
"Out"
);
std
::
transform
(
x_grads
.
begin
(),
x_grads
.
end
(),
std
::
back_inserter
(
grad_ops
),
[
&
og
](
const
std
::
string
&
x_grad
)
{
framework
::
OpDescBind
grad_op
;
grad_op
.
SetType
(
"scale"
);
grad_op
.
SetInput
(
"X"
,
og
);
grad_op
.
SetOutput
(
"Out"
,
{
x_grad
});
grad_op
.
SetAttr
(
"scale"
,
1.0
f
);
return
grad_op
;
});
return
grad_ops
;
}
};
...
...
@@ -82,5 +82,6 @@ class SumGradOp : public NetOp {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
sum
,
ops
::
SumOp
,
ops
::
SumOpMaker
,
sum_grad
,
ops
::
SumGradOp
);
REGISTER_OPERATOR
(
sum
,
ops
::
SumOp
,
ops
::
SumOpMaker
,
ops
::
SumGradMaker
);
REGISTER_OP_CPU_KERNEL
(
sum
,
ops
::
SumKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录