提交 46c551b2 编写于 作者: Y Yu Yang

Complete Register Gradient in compile time

上级 479e4a50
...@@ -21,24 +21,34 @@ ...@@ -21,24 +21,34 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext; using DeviceContext = platform::DeviceContext;
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").NotInGradient(); AddInput("X", "Input X of Add");
AddInput("b", "Bias of Add").NotInGradient(); AddInput("b", "Bias of Add");
AddOutput("Out", "Out of Add").NotInGradient(); AddOutput("Out", "Out of Add");
AddComment("Add Op"); AddComment("Add Op");
} }
}; };
class RowWiseAddGradMaker : public SingleGradOpDescMaker {
public:
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
OpDescBind Apply() const override {
OpDescBind grad_op;
grad_op.SetInput(GradVarName("Out"), OutputGrad("Out"));
grad_op.SetOutput(GradVarName("X"), InputGrad("X"));
grad_op.SetOutput(GradVarName("b"), InputGrad("b"));
grad_op.SetType("rowwise_add_grad");
return grad_op;
}
};
class MulOpMaker : public OpProtoAndCheckerMaker { class MulOpMaker : public OpProtoAndCheckerMaker {
public: public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
...@@ -148,8 +158,9 @@ class AddOpMaker : public OpProtoAndCheckerMaker { ...@@ -148,8 +158,9 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework; namespace f = paddle::framework;
namespace ops = paddle::operators; namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet; using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad, REGISTER_OPERATOR(rowwise_add, f::NOP, f::RowWiseAddOpMaker,
f::NOP); f::RowWiseAddGradMaker);
REGISTER_OPERATOR(rowwise_add_grad, f::NOP);
REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP); REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP); REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
...@@ -378,7 +389,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -378,7 +389,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
+ 1UL /* external output number*/ + 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/ + 1UL /* number of gradient of external output*/
+ 2U /* internal variable number*/); + 2U /* internal variable number*/);
std::cerr << grad_fc.DebugString() << std::endl;
EXPECT_EQ(grad_fc.Outputs(all).size(), EXPECT_EQ(grad_fc.Outputs(all).size(),
2UL /* input number of mul*/ 2UL /* input number of mul*/
......
...@@ -85,7 +85,6 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> { ...@@ -85,7 +85,6 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
info->proto_ = new OpProto; info->proto_ = new OpProto;
info->checker_ = new OpAttrChecker(); info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_); auto maker = T(info->proto_, info->checker_);
std::cerr << "Assign Maker " << op_type << std::endl;
maker.Validate(); maker.Validate();
info->proto_->set_type(op_type); info->proto_->set_type(op_type);
PADDLE_ENFORCE( PADDLE_ENFORCE(
......
...@@ -66,7 +66,6 @@ message OpProto { ...@@ -66,7 +66,6 @@ message OpProto {
optional bool duplicable = 3 [ default = false ]; optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ]; optional bool intermediate = 4 [ default = false ];
optional bool not_in_gradient = 5 [ default = false ];
} }
// AttrProto describes the C++ type Attribute. // AttrProto describes the C++ type Attribute.
......
...@@ -17,11 +17,14 @@ ...@@ -17,11 +17,14 @@
#include <map> #include <map>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include "paddle/framework/type_defs.h" #include "paddle/framework/type_defs.h"
#include "paddle/platform/macros.h" #include "paddle/platform/macros.h"
#include "glog/logging.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -46,7 +46,6 @@ class Registrar { ...@@ -46,7 +46,6 @@ class Registrar {
template <typename... ARGS> template <typename... ARGS>
struct OperatorRegistrar : public Registrar { struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const char* op_type) : op_type(op_type) { explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
std::cerr << "Reg operator " << op_type << std::endl;
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type); "'%s' is registered more than once.", op_type);
static_assert(sizeof...(ARGS) != 0, static_assert(sizeof...(ARGS) != 0,
......
...@@ -36,7 +36,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -36,7 +36,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").NotInGradient(); AddOutput("Out", "The output of mean op");
AddComment(R"DOC( Mean Operator AddComment(R"DOC( Mean Operator
)DOC"); )DOC");
} }
...@@ -52,11 +52,28 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -52,11 +52,28 @@ class MeanGradOp : public framework::OperatorWithKernel {
} }
}; };
class MeanGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
framework::OpDescBind Apply() const override {
framework::OpDescBind grad_op;
grad_op.SetType("mean_grad");
grad_op.SetInput("X", Input("X"));
grad_op.SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op.SetOutput(framework::GradVarName("X"), InputGrad("X"));
return grad_op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean, REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>); ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean_grad, REGISTER_OP_CPU_KERNEL(mean_grad,
......
...@@ -49,9 +49,9 @@ class MinusOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,9 +49,9 @@ class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left tensor of minus operator.").NotInGradient(); AddInput("X", "The left tensor of minus operator.");
AddInput("Y", "The right tensor of minus operator.").NotInGradient(); AddInput("Y", "The right tensor of minus operator.");
AddOutput("Out", "The output tensor of minus operator.").NotInGradient(); AddOutput("Out", "The output tensor of minus operator.");
AddComment(R"DOC(Minus Operator AddComment(R"DOC(Minus Operator
...@@ -64,26 +64,25 @@ or not. But the output only shares the LoD with input `X`. ...@@ -64,26 +64,25 @@ or not. But the output only shares the LoD with input `X`.
)DOC"); )DOC");
} }
}; };
template <typename AttrType>
class MinusGradOp : public NetOp { class MinusGradMaker : public framework::GradOpDescMakerBase {
public: public:
MinusGradOp(const std::string &type, const framework::VariableNameMap &inputs, using framework::GradOpDescMakerBase::GradOpDescMakerBase;
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) std::vector<framework::OpDescBind> operator()() const override {
: NetOp(type, inputs, outputs, attrs) { std::vector<framework::OpDescBind> ops;
auto out_grad = Input(framework::GradVarName("Out")); ops.resize(2);
auto x_grad = Output(framework::GradVarName("X"));
auto y_grad = Output(framework::GradVarName("Y")); ops[0].SetType("scale");
ops[0].SetInput("X", OutputGrad("Out"));
// x_grad = out_grad ops[0].SetOutput("Out", InputGrad("X"));
AppendOp(framework::OpRegistry::CreateOp("identity", {{"X", {out_grad}}}, ops[0].SetAttr("scale", 1.0f);
{{"Y", {x_grad}}}, {}));
ops[1].SetType("scale");
framework::AttributeMap scale_attr; ops[1].SetInput("X", OutputGrad("Out"));
scale_attr["scale"] = static_cast<AttrType>(-1); ops[1].SetOutput("Out", InputGrad("Y"));
AppendOp(framework::OpRegistry::CreateOp("scale", {{"X", {out_grad}}}, ops[1].SetAttr("scale", -1.0f);
{{"Out", {y_grad}}}, scale_attr)); return ops;
CompleteAddOp(false);
} }
}; };
...@@ -91,7 +90,6 @@ class MinusGradOp : public NetOp { ...@@ -91,7 +90,6 @@ class MinusGradOp : public NetOp {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad, REGISTER_OPERATOR(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradMaker);
ops::MinusGradOp<float>);
REGISTER_OP_CPU_KERNEL(minus, REGISTER_OP_CPU_KERNEL(minus,
ops::MinusKernel<paddle::platform::CPUPlace, float>); ops::MinusKernel<paddle::platform::CPUPlace, float>);
...@@ -56,8 +56,7 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -56,8 +56,7 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker {
"The input should be a k-D tensor(k > 0 and k < 7)"); "The input should be a k-D tensor(k > 0 and k < 7)");
AddOutput("Out", AddOutput("Out",
"The output of pad op." "The output of pad op."
"A tensor with the same shape as X.") "A tensor with the same shape as X.");
.NotInGradient();
AddComment(R"DOC( AddComment(R"DOC(
Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example: Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example:
...@@ -111,11 +110,28 @@ class PadOpGrad : public framework::OperatorWithKernel { ...@@ -111,11 +110,28 @@ class PadOpGrad : public framework::OperatorWithKernel {
} }
}; };
class PadOpGradMaker : public framework::SingleGradOpDescMaker {
protected:
framework::OpDescBind Apply() const override {
framework::OpDescBind bind;
bind.SetInput("X", Input("X"));
bind.SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind.SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind.SetAttrMap(Attrs());
return bind;
}
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(pad, ops::PadOp, ops::PadOpMaker, pad_grad, ops::PadOpGrad);
REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker, ops::PadOpGradMaker);
REGISTER_OPERATOR(pad_grad, ops::PadOpGrad);
REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pad_grad, REGISTER_OP_CPU_KERNEL(pad_grad,
ops::PadGradKernel<paddle::platform::CPUPlace, float>); ops::PadGradKernel<paddle::platform::CPUPlace, float>);
...@@ -41,8 +41,8 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -41,8 +41,8 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient(); AddInput("X", "The input tensor of scale operator.");
AddOutput("Out", "The output tensor of scale operator.").NotInGradient(); AddOutput("Out", "The output tensor of scale operator.");
AddComment(R"DOC(Scale operator AddComment(R"DOC(Scale operator
The equation is: Out = scale*X The equation is: Out = scale*X
...@@ -52,21 +52,18 @@ The equation is: Out = scale*X ...@@ -52,21 +52,18 @@ The equation is: Out = scale*X
} }
}; };
// The operator to calculate gradients of a scale operator is just the scale class ScaleGradMaker : public framework::SingleGradOpDescMaker {
// operator itself.
// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
template <typename AttrType>
class ScaleGradOp : public NetOp {
public: public:
ScaleGradOp(const std::string &type, const framework::VariableNameMap &inputs, using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) protected:
: NetOp(type, inputs, outputs, attrs) { framework::OpDescBind Apply() const override {
AppendOp(framework::OpRegistry::CreateOp( framework::OpDescBind grad_op;
"scale", {{"X", {Input(framework::GradVarName("Out"))}}}, grad_op.SetType("scale");
{{"Out", {Output(framework::GradVarName("X"))}}}, grad_op.SetInput("X", OutputGrad("Out"));
{{"scale", Attr<AttrType>("scale")}})); grad_op.SetOutput("Out", InputGrad("X"));
CompleteAddOp(false); grad_op.SetAttr("scale", GetAttr("scale"));
return grad_op;
} }
}; };
...@@ -75,7 +72,7 @@ class ScaleGradOp : public NetOp { ...@@ -75,7 +72,7 @@ class ScaleGradOp : public NetOp {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, scale_grad, REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
ops::ScaleGradOp<float>); ops::ScaleGradMaker);
REGISTER_OP_CPU_KERNEL(scale, REGISTER_OP_CPU_KERNEL(scale,
ops::ScaleKernel<paddle::platform::CPUPlace, float>); ops::ScaleKernel<paddle::platform::CPUPlace, float>);
...@@ -27,15 +27,14 @@ class SoftmaxWithCrossEntropyOpMaker ...@@ -27,15 +27,14 @@ class SoftmaxWithCrossEntropyOpMaker
AddInput("Logits", AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities " "(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, " "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number.") "and K is the class number.");
.NotInGradient(); AddInput("Label",
AddInput( "(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
"Label", "tensor. "
"(Tensor, default: Tensor<int>), The ground truth which is a 2-D " "If softLable is set to 0, Label is a Tensor<int> with shape [N x "
"tensor. " "1]. "
"If softLable is set to 0, Label is a Tensor<int> with shape [N x 1]. " "If softLable is set to 1, Label is a Tensor<float/double> "
"If softLable is set to 1, Label is a Tensor<float/double> " "with shape [N x K].");
"with shape [N x K].");
AddOutput( AddOutput(
"Softmax", "Softmax",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. " "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
...@@ -163,15 +162,35 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -163,15 +162,35 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
} }
}; };
class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
framework::OpDescBind Apply() const override {
framework::OpDescBind grad_op;
grad_op.SetType("softmax_with_cross_entropy_grad");
grad_op.SetInput("Label", Input("Label"));
grad_op.SetInput("Softmax", Output("Softmax"));
grad_op.SetInput("Loss", Output("Loss"));
grad_op.SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax"));
grad_op.SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
grad_op.SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
grad_op.SetAttrMap(Attrs());
return grad_op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxWithCrossEntropyOpMaker,
softmax_with_cross_entropy_grad, ops::SoftmaxWithCrossEntropyOpMaker);
ops::SoftmaxWithCrossEntropyOpGrad); REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>); ops::SoftmaxWithCrossEntropyKernel<float>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
......
...@@ -45,10 +45,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -45,10 +45,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.") AddInput("X", "the input tensors of sum operator.").AsDuplicable();
.AsDuplicable() AddOutput("Out", "the output tensor of sum operator.");
.NotInGradient();
AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
AddComment(R"DOC( AddComment(R"DOC(
Sum the input tensors. Sum the input tensors.
...@@ -58,23 +56,25 @@ or not. But the output only shares the LoD with the first input. ...@@ -58,23 +56,25 @@ or not. But the output only shares the LoD with the first input.
} }
}; };
class SumGradOp : public NetOp { class SumGradMaker : public framework::GradOpDescMakerBase {
public: public:
SumGradOp(const std::string& type, const framework::VariableNameMap& inputs, using framework::GradOpDescMakerBase::GradOpDescMakerBase;
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: NetOp(type, inputs, outputs, attrs) {
auto& x_grad_names = Outputs(framework::GradVarName("X"));
auto out_grad_name = this->Input(framework::GradVarName("Out"));
framework::AttributeMap grad_attrs; std::vector<framework::OpDescBind> operator()() const override {
grad_attrs["scale"] = 1.0f; auto x_grads = InputGrad("X");
for (auto& x_grad_name : x_grad_names) { std::vector<framework::OpDescBind> grad_ops;
AppendOp(framework::OpRegistry::CreateOp( grad_ops.reserve(x_grads.size());
"scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}}, auto og = OutputGrad("Out");
grad_attrs)); std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
} [&og](const std::string& x_grad) {
CompleteAddOp(false); framework::OpDescBind grad_op;
grad_op.SetType("scale");
grad_op.SetInput("X", og);
grad_op.SetOutput("Out", {x_grad});
grad_op.SetAttr("scale", 1.0f);
return grad_op;
});
return grad_ops;
} }
}; };
...@@ -82,5 +82,6 @@ class SumGradOp : public NetOp { ...@@ -82,5 +82,6 @@ class SumGradOp : public NetOp {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册