Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
46c551b2
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
46c551b2
编写于
10月 02, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Complete Register Gradient in compile time
上级
479e4a50
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
152 addition
and
94 deletion
+152
-94
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+21
-11
paddle/framework/details/op_registry.h
paddle/framework/details/op_registry.h
+0
-1
paddle/framework/framework.proto
paddle/framework/framework.proto
+0
-1
paddle/framework/op_info.h
paddle/framework/op_info.h
+3
-0
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+0
-1
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+19
-2
paddle/operators/minus_op.cc
paddle/operators/minus_op.cc
+22
-24
paddle/operators/pad_op.cc
paddle/operators/pad_op.cc
+19
-3
paddle/operators/scale_op.cc
paddle/operators/scale_op.cc
+15
-18
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+32
-13
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+21
-20
未找到文件。
paddle/framework/backward_test.cc
浏览文件 @
46c551b2
...
...
@@ -21,24 +21,34 @@
namespace
paddle
{
namespace
framework
{
using
OperatorBase
=
framework
::
OperatorBase
;
using
OpProtoAndCheckerMaker
=
framework
::
OpProtoAndCheckerMaker
;
using
OpProto
=
framework
::
OpProto
;
using
OpAttrChecker
=
framework
::
OpAttrChecker
;
using
Scope
=
framework
::
Scope
;
using
DeviceContext
=
platform
::
DeviceContext
;
class
RowWiseAddOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input X of Add"
)
.
NotInGradient
()
;
AddInput
(
"b"
,
"Bias of Add"
)
.
NotInGradient
()
;
AddOutput
(
"Out"
,
"Out of Add"
)
.
NotInGradient
()
;
AddInput
(
"X"
,
"Input X of Add"
);
AddInput
(
"b"
,
"Bias of Add"
);
AddOutput
(
"Out"
,
"Out of Add"
);
AddComment
(
"Add Op"
);
}
};
class
RowWiseAddGradMaker
:
public
SingleGradOpDescMaker
{
public:
using
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
OpDescBind
Apply
()
const
override
{
OpDescBind
grad_op
;
grad_op
.
SetInput
(
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
.
SetOutput
(
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
.
SetOutput
(
GradVarName
(
"b"
),
InputGrad
(
"b"
));
grad_op
.
SetType
(
"rowwise_add_grad"
);
return
grad_op
;
}
};
class
MulOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
...
...
@@ -148,8 +158,9 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace
f
=
paddle
::
framework
;
namespace
ops
=
paddle
::
operators
;
using
EnforceNotMet
=
paddle
::
platform
::
EnforceNotMet
;
REGISTER_OP
(
rowwise_add
,
f
::
NOP
,
f
::
RowWiseAddOpMaker
,
rowwise_add_grad
,
f
::
NOP
);
REGISTER_OPERATOR
(
rowwise_add
,
f
::
NOP
,
f
::
RowWiseAddOpMaker
,
f
::
RowWiseAddGradMaker
);
REGISTER_OPERATOR
(
rowwise_add_grad
,
f
::
NOP
);
REGISTER_OP
(
mul
,
f
::
NOP
,
f
::
MulOpMaker
,
mul_grad
,
f
::
NOP
);
REGISTER_OP
(
sigmoid
,
f
::
NOP
,
f
::
SigmoidOpMaker
,
sigmoid_grad
,
f
::
NOP
);
REGISTER_OP_WITHOUT_GRADIENT
(
nograd
,
f
::
NOP
,
f
::
NoGradOpMaker
);
...
...
@@ -378,7 +389,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
+
1UL
/* external output number*/
+
1UL
/* number of gradient of external output*/
+
2U
/* internal variable number*/
);
std
::
cerr
<<
grad_fc
.
DebugString
()
<<
std
::
endl
;
EXPECT_EQ
(
grad_fc
.
Outputs
(
all
).
size
(),
2UL
/* input number of mul*/
...
...
paddle/framework/details/op_registry.h
浏览文件 @
46c551b2
...
...
@@ -85,7 +85,6 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
info
->
proto_
=
new
OpProto
;
info
->
checker_
=
new
OpAttrChecker
();
auto
maker
=
T
(
info
->
proto_
,
info
->
checker_
);
std
::
cerr
<<
"Assign Maker "
<<
op_type
<<
std
::
endl
;
maker
.
Validate
();
info
->
proto_
->
set_type
(
op_type
);
PADDLE_ENFORCE
(
...
...
paddle/framework/framework.proto
浏览文件 @
46c551b2
...
...
@@ -66,7 +66,6 @@ message OpProto {
optional
bool
duplicable
=
3
[
default
=
false
];
optional
bool
intermediate
=
4
[
default
=
false
];
optional
bool
not_in_gradient
=
5
[
default
=
false
];
}
// AttrProto describes the C++ type Attribute.
...
...
paddle/framework/op_info.h
浏览文件 @
46c551b2
...
...
@@ -17,11 +17,14 @@
#include <map>
#include <string>
#include <unordered_map>
#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/type_defs.h"
#include "paddle/platform/macros.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
framework
{
...
...
paddle/framework/op_registry.h
浏览文件 @
46c551b2
...
...
@@ -46,7 +46,6 @@ class Registrar {
template
<
typename
...
ARGS
>
struct
OperatorRegistrar
:
public
Registrar
{
explicit
OperatorRegistrar
(
const
char
*
op_type
)
:
op_type
(
op_type
)
{
std
::
cerr
<<
"Reg operator "
<<
op_type
<<
std
::
endl
;
PADDLE_ENFORCE
(
!
OpInfoMap
::
Instance
().
Has
(
op_type
),
"'%s' is registered more than once."
,
op_type
);
static_assert
(
sizeof
...(
ARGS
)
!=
0
,
...
...
paddle/operators/mean_op.cc
浏览文件 @
46c551b2
...
...
@@ -36,7 +36,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of mean op"
);
AddOutput
(
"Out"
,
"The output of mean op"
)
.
NotInGradient
()
;
AddOutput
(
"Out"
,
"The output of mean op"
);
AddComment
(
R"DOC( Mean Operator
)DOC"
);
}
...
...
@@ -52,11 +52,28 @@ class MeanGradOp : public framework::OperatorWithKernel {
}
};
class
MeanGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
framework
::
OpDescBind
Apply
()
const
override
{
framework
::
OpDescBind
grad_op
;
grad_op
.
SetType
(
"mean_grad"
);
grad_op
.
SetInput
(
"X"
,
Input
(
"X"
));
grad_op
.
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
.
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
return
grad_op
;
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
mean
,
ops
::
MeanOp
,
ops
::
MeanOpMaker
,
mean_grad
,
ops
::
MeanGradOp
);
REGISTER_OPERATOR
(
mean
,
ops
::
MeanOp
,
ops
::
MeanOpMaker
,
ops
::
MeanGradMaker
);
REGISTER_OPERATOR
(
mean_grad
,
ops
::
MeanGradOp
);
REGISTER_OP_CPU_KERNEL
(
mean
,
ops
::
MeanKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
mean_grad
,
...
...
paddle/operators/minus_op.cc
浏览文件 @
46c551b2
...
...
@@ -49,9 +49,9 @@ class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MinusOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The left tensor of minus operator."
)
.
NotInGradient
()
;
AddInput
(
"Y"
,
"The right tensor of minus operator."
)
.
NotInGradient
()
;
AddOutput
(
"Out"
,
"The output tensor of minus operator."
)
.
NotInGradient
()
;
AddInput
(
"X"
,
"The left tensor of minus operator."
);
AddInput
(
"Y"
,
"The right tensor of minus operator."
);
AddOutput
(
"Out"
,
"The output tensor of minus operator."
);
AddComment
(
R"DOC(Minus Operator
...
...
@@ -64,26 +64,25 @@ or not. But the output only shares the LoD with input `X`.
)DOC"
);
}
};
template
<
typename
AttrType
>
class
MinusGrad
Op
:
public
NetOp
{
class
MinusGrad
Maker
:
public
framework
::
GradOpDescMakerBase
{
public:
MinusGradOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
auto
out_grad
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
x_grad
=
Output
(
framework
::
GradVarName
(
"X"
));
auto
y_grad
=
Output
(
framework
::
GradVarName
(
"Y"
));
// x_grad = out_grad
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"identity"
,
{{
"X"
,
{
out_grad
}}},
{{
"Y"
,
{
x_grad
}}},
{}));
framework
::
AttributeMap
scale_attr
;
scale_attr
[
"scale"
]
=
static_cast
<
AttrType
>
(
-
1
);
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"scale"
,
{{
"X"
,
{
out_grad
}}},
{{
"Out"
,
{
y_grad
}}},
scale_attr
));
CompleteAddOp
(
false
);
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
framework
::
OpDescBind
>
operator
()()
const
override
{
std
::
vector
<
framework
::
OpDescBind
>
ops
;
ops
.
resize
(
2
);
ops
[
0
].
SetType
(
"scale"
);
ops
[
0
].
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
ops
[
0
].
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
ops
[
0
].
SetAttr
(
"scale"
,
1.0
f
);
ops
[
1
].
SetType
(
"scale"
);
ops
[
1
].
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
ops
[
1
].
SetOutput
(
"Out"
,
InputGrad
(
"Y"
));
ops
[
1
].
SetAttr
(
"scale"
,
-
1.0
f
);
return
ops
;
}
};
...
...
@@ -91,7 +90,6 @@ class MinusGradOp : public NetOp {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
minus
,
ops
::
MinusOp
,
ops
::
MinusOpMaker
,
minus_grad
,
ops
::
MinusGradOp
<
float
>
);
REGISTER_OPERATOR
(
minus
,
ops
::
MinusOp
,
ops
::
MinusOpMaker
,
ops
::
MinusGradMaker
);
REGISTER_OP_CPU_KERNEL
(
minus
,
ops
::
MinusKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/pad_op.cc
浏览文件 @
46c551b2
...
...
@@ -56,8 +56,7 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker {
"The input should be a k-D tensor(k > 0 and k < 7)"
);
AddOutput
(
"Out"
,
"The output of pad op."
"A tensor with the same shape as X."
)
.
NotInGradient
();
"A tensor with the same shape as X."
);
AddComment
(
R"DOC(
Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example:
...
...
@@ -111,11 +110,28 @@ class PadOpGrad : public framework::OperatorWithKernel {
}
};
class
PadOpGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
protected:
framework
::
OpDescBind
Apply
()
const
override
{
framework
::
OpDescBind
bind
;
bind
.
SetInput
(
"X"
,
Input
(
"X"
));
bind
.
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
bind
.
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
bind
.
SetAttrMap
(
Attrs
());
return
bind
;
}
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
pad
,
ops
::
PadOp
,
ops
::
PadOpMaker
,
pad_grad
,
ops
::
PadOpGrad
);
REGISTER_OPERATOR
(
pad
,
ops
::
PadOp
,
ops
::
PadOpMaker
,
ops
::
PadOpGradMaker
);
REGISTER_OPERATOR
(
pad_grad
,
ops
::
PadOpGrad
);
REGISTER_OP_CPU_KERNEL
(
pad
,
ops
::
PadKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
pad_grad
,
ops
::
PadGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/scale_op.cc
浏览文件 @
46c551b2
...
...
@@ -41,8 +41,8 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input tensor of scale operator."
)
.
NotInGradient
()
;
AddOutput
(
"Out"
,
"The output tensor of scale operator."
)
.
NotInGradient
()
;
AddInput
(
"X"
,
"The input tensor of scale operator."
);
AddOutput
(
"Out"
,
"The output tensor of scale operator."
);
AddComment
(
R"DOC(Scale operator
The equation is: Out = scale*X
...
...
@@ -52,21 +52,18 @@ The equation is: Out = scale*X
}
};
// The operator to calculate gradients of a scale operator is just the scale
// operator itself.
// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
template
<
typename
AttrType
>
class
ScaleGradOp
:
public
NetOp
{
class
ScaleGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
ScaleGradOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"scale"
,
{{
"X"
,
{
Input
(
framework
::
GradVarName
(
"Out"
))}}},
{{
"Out"
,
{
Output
(
framework
::
GradVarName
(
"X"
))}}},
{{
"scale"
,
Attr
<
AttrType
>
(
"scale"
)}}));
CompleteAddOp
(
false
);
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
framework
::
OpDescBind
Apply
()
const
override
{
framework
::
OpDescBind
grad_op
;
grad_op
.
SetType
(
"scale"
);
grad_op
.
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
.
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
.
SetAttr
(
"scale"
,
GetAttr
(
"scale"
));
return
grad_op
;
}
};
...
...
@@ -75,7 +72,7 @@ class ScaleGradOp : public NetOp {
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
scale
,
ops
::
ScaleOp
,
ops
::
ScaleOpMaker
<
float
>
,
scale_grad
,
ops
::
ScaleGradOp
<
float
>
);
REGISTER_OP
ERATOR
(
scale
,
ops
::
ScaleOp
,
ops
::
ScaleOpMaker
<
float
>
,
ops
::
ScaleGradMaker
);
REGISTER_OP_CPU_KERNEL
(
scale
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
46c551b2
...
...
@@ -27,13 +27,12 @@ class SoftmaxWithCrossEntropyOpMaker
AddInput
(
"Logits"
,
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number."
)
.
NotInGradient
();
AddInput
(
"Label"
,
"and K is the class number."
);
AddInput
(
"Label"
,
"(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
"tensor. "
"If softLable is set to 0, Label is a Tensor<int> with shape [N x 1]. "
"If softLable is set to 0, Label is a Tensor<int> with shape [N x "
"1]. "
"If softLable is set to 1, Label is a Tensor<float/double> "
"with shape [N x K]."
);
AddOutput
(
...
...
@@ -163,14 +162,34 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
};
class
SoftmaxGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
framework
::
OpDescBind
Apply
()
const
override
{
framework
::
OpDescBind
grad_op
;
grad_op
.
SetType
(
"softmax_with_cross_entropy_grad"
);
grad_op
.
SetInput
(
"Label"
,
Input
(
"Label"
));
grad_op
.
SetInput
(
"Softmax"
,
Output
(
"Softmax"
));
grad_op
.
SetInput
(
"Loss"
,
Output
(
"Loss"
));
grad_op
.
SetInput
(
framework
::
GradVarName
(
"Softmax"
),
OutputGrad
(
"Softmax"
));
grad_op
.
SetInput
(
framework
::
GradVarName
(
"Loss"
),
OutputGrad
(
"Loss"
));
grad_op
.
SetOutput
(
framework
::
GradVarName
(
"Logits"
),
InputGrad
(
"Logits"
));
grad_op
.
SetAttrMap
(
Attrs
());
return
grad_op
;
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyOp
,
REGISTER_OP
ERATOR
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyOp
,
ops
::
SoftmaxWithCrossEntropyOpMaker
,
softmax_with_cross_entropy_grad
,
ops
::
SoftmaxWithCrossEntropyOpMaker
);
REGISTER_OPERATOR
(
softmax_with_cross_entropy_grad
,
ops
::
SoftmaxWithCrossEntropyOpGrad
);
REGISTER_OP_CPU_KERNEL
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyKernel
<
float
>
);
...
...
paddle/operators/sum_op.cc
浏览文件 @
46c551b2
...
...
@@ -45,10 +45,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"the input tensors of sum operator."
)
.
AsDuplicable
()
.
NotInGradient
();
AddOutput
(
"Out"
,
"the output tensor of sum operator."
).
NotInGradient
();
AddInput
(
"X"
,
"the input tensors of sum operator."
).
AsDuplicable
();
AddOutput
(
"Out"
,
"the output tensor of sum operator."
);
AddComment
(
R"DOC(
Sum the input tensors.
...
...
@@ -58,23 +56,25 @@ or not. But the output only shares the LoD with the first input.
}
};
class
SumGrad
Op
:
public
NetOp
{
class
SumGrad
Maker
:
public
framework
::
GradOpDescMakerBase
{
public:
SumGradOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
auto
&
x_grad_names
=
Outputs
(
framework
::
GradVarName
(
"X"
));
auto
out_grad_name
=
this
->
Input
(
framework
::
GradVarName
(
"Out"
));
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
framework
::
AttributeMap
grad_attrs
;
grad_attrs
[
"scale"
]
=
1.0
f
;
for
(
auto
&
x_grad_name
:
x_grad_names
)
{
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"scale"
,
{{
"X"
,
{
out_grad_name
}}},
{{
"Out"
,
{
x_grad_name
}}},
grad_attrs
));
}
CompleteAddOp
(
false
);
std
::
vector
<
framework
::
OpDescBind
>
operator
()()
const
override
{
auto
x_grads
=
InputGrad
(
"X"
);
std
::
vector
<
framework
::
OpDescBind
>
grad_ops
;
grad_ops
.
reserve
(
x_grads
.
size
());
auto
og
=
OutputGrad
(
"Out"
);
std
::
transform
(
x_grads
.
begin
(),
x_grads
.
end
(),
std
::
back_inserter
(
grad_ops
),
[
&
og
](
const
std
::
string
&
x_grad
)
{
framework
::
OpDescBind
grad_op
;
grad_op
.
SetType
(
"scale"
);
grad_op
.
SetInput
(
"X"
,
og
);
grad_op
.
SetOutput
(
"Out"
,
{
x_grad
});
grad_op
.
SetAttr
(
"scale"
,
1.0
f
);
return
grad_op
;
});
return
grad_ops
;
}
};
...
...
@@ -82,5 +82,6 @@ class SumGradOp : public NetOp {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
sum
,
ops
::
SumOp
,
ops
::
SumOpMaker
,
sum_grad
,
ops
::
SumGradOp
);
REGISTER_OPERATOR
(
sum
,
ops
::
SumOp
,
ops
::
SumOpMaker
,
ops
::
SumGradMaker
);
REGISTER_OP_CPU_KERNEL
(
sum
,
ops
::
SumKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录