Commit 46c551b2 authored by Yu Yang

Complete Register Gradient in compile time

Parent 479e4a50
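This change moves gradient-operator construction from run time to the compile (program-description) phase: grad ops that used to be assembled as a NetOp inside a constructor (AppendOp/CreateOp/CompleteAddOp) are now produced by a GradOpDescMakerBase or SingleGradOpDescMaker that emits the gradient OpDescBind directly, REGISTER_OP pairs become REGISTER_OPERATOR calls, and the per-variable .NotInGradient() hints plus the not_in_gradient proto field become unnecessary. Below is a minimal sketch of the new pattern, assuming this revision's framework headers; the operator my_op and the classes MyOp, MyOpMaker, MyGradOp are hypothetical placeholders, not ops touched by this commit.

// Sketch only: op/class names are hypothetical; the maker and the two
// REGISTER_OPERATOR calls mirror the pattern used throughout this diff.
class MyOpGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  framework::OpDescBind Apply() const override {
    framework::OpDescBind grad_op;
    grad_op.SetType("my_op_grad");
    grad_op.SetInput("X", Input("X"));  // forward input the grad kernel needs
    grad_op.SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op.SetOutput(framework::GradVarName("X"), InputGrad("X"));
    grad_op.SetAttrMap(Attrs());  // reuse the forward op's attributes
    return grad_op;
  }
};

// Old style: REGISTER_OP(my_op, MyOp, MyOpMaker, my_op_grad, MyGradOp);
REGISTER_OPERATOR(my_op, MyOp, MyOpMaker, MyOpGradMaker);
REGISTER_OPERATOR(my_op_grad, MyGradOp);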
@@ -21,24 +21,34 @@
namespace paddle {
namespace framework {
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").NotInGradient();
AddInput("b", "Bias of Add").NotInGradient();
AddOutput("Out", "Out of Add").NotInGradient();
AddInput("X", "Input X of Add");
AddInput("b", "Bias of Add");
AddOutput("Out", "Out of Add");
AddComment("Add Op");
}
};
+class RowWiseAddGradMaker : public SingleGradOpDescMaker {
+public:
+using SingleGradOpDescMaker::SingleGradOpDescMaker;
+protected:
+OpDescBind Apply() const override {
+OpDescBind grad_op;
+grad_op.SetInput(GradVarName("Out"), OutputGrad("Out"));
+grad_op.SetOutput(GradVarName("X"), InputGrad("X"));
+grad_op.SetOutput(GradVarName("b"), InputGrad("b"));
+grad_op.SetType("rowwise_add_grad");
+return grad_op;
+}
+};
class MulOpMaker : public OpProtoAndCheckerMaker {
public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -148,8 +158,9 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework;
namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet;
-REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
-f::NOP);
+REGISTER_OPERATOR(rowwise_add, f::NOP, f::RowWiseAddOpMaker,
+f::RowWiseAddGradMaker);
+REGISTER_OPERATOR(rowwise_add_grad, f::NOP);
REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
@@ -378,7 +389,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
+ 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/
+ 2U /* internal variable number*/);
-std::cerr << grad_fc.DebugString() << std::endl;
EXPECT_EQ(grad_fc.Outputs(all).size(),
2UL /* input number of mul*/
......
@@ -85,7 +85,6 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
info->proto_ = new OpProto;
info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_);
std::cerr << "Assign Maker " << op_type << std::endl;
maker.Validate();
info->proto_->set_type(op_type);
PADDLE_ENFORCE(
......
@@ -66,7 +66,6 @@ message OpProto {
optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ];
-optional bool not_in_gradient = 5 [ default = false ];
}
// AttrProto describes the C++ type Attribute.
......
@@ -17,11 +17,14 @@
#include <map>
#include <string>
#include <unordered_map>
#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/type_defs.h"
#include "paddle/platform/macros.h"
#include "glog/logging.h"
namespace paddle {
namespace framework {
......
@@ -46,7 +46,6 @@ class Registrar {
template <typename... ARGS>
struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
std::cerr << "Reg operator " << op_type << std::endl;
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type);
static_assert(sizeof...(ARGS) != 0,
......
@@ -36,7 +36,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").NotInGradient();
AddOutput("Out", "The output of mean op");
AddComment(R"DOC( Mean Operator
)DOC");
}
@@ -52,11 +52,28 @@ class MeanGradOp : public framework::OperatorWithKernel {
}
};
+class MeanGradMaker : public framework::SingleGradOpDescMaker {
+public:
+using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+protected:
+framework::OpDescBind Apply() const override {
+framework::OpDescBind grad_op;
+grad_op.SetType("mean_grad");
+grad_op.SetInput("X", Input("X"));
+grad_op.SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+grad_op.SetOutput(framework::GradVarName("X"), InputGrad("X"));
+return grad_op;
+}
+};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
+REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
+REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean_grad,
......
@@ -49,9 +49,9 @@ class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left tensor of minus operator.").NotInGradient();
AddInput("Y", "The right tensor of minus operator.").NotInGradient();
AddOutput("Out", "The output tensor of minus operator.").NotInGradient();
AddInput("X", "The left tensor of minus operator.");
AddInput("Y", "The right tensor of minus operator.");
AddOutput("Out", "The output tensor of minus operator.");
AddComment(R"DOC(Minus Operator
@@ -64,26 +64,25 @@ or not. But the output only shares the LoD with input `X`.
)DOC");
}
};
-template <typename AttrType>
-class MinusGradOp : public NetOp {
+class MinusGradMaker : public framework::GradOpDescMakerBase {
public:
-MinusGradOp(const std::string &type, const framework::VariableNameMap &inputs,
-const framework::VariableNameMap &outputs,
-const framework::AttributeMap &attrs)
-: NetOp(type, inputs, outputs, attrs) {
-auto out_grad = Input(framework::GradVarName("Out"));
-auto x_grad = Output(framework::GradVarName("X"));
-auto y_grad = Output(framework::GradVarName("Y"));
-// x_grad = out_grad
-AppendOp(framework::OpRegistry::CreateOp("identity", {{"X", {out_grad}}},
-{{"Y", {x_grad}}}, {}));
-framework::AttributeMap scale_attr;
-scale_attr["scale"] = static_cast<AttrType>(-1);
-AppendOp(framework::OpRegistry::CreateOp("scale", {{"X", {out_grad}}},
-{{"Out", {y_grad}}}, scale_attr));
-CompleteAddOp(false);
+using framework::GradOpDescMakerBase::GradOpDescMakerBase;
+std::vector<framework::OpDescBind> operator()() const override {
+std::vector<framework::OpDescBind> ops;
+ops.resize(2);
+ops[0].SetType("scale");
+ops[0].SetInput("X", OutputGrad("Out"));
+ops[0].SetOutput("Out", InputGrad("X"));
+ops[0].SetAttr("scale", 1.0f);
+ops[1].SetType("scale");
+ops[1].SetInput("X", OutputGrad("Out"));
+ops[1].SetOutput("Out", InputGrad("Y"));
+ops[1].SetAttr("scale", -1.0f);
+return ops;
}
};
@@ -91,7 +90,6 @@ class MinusGradOp : public NetOp {
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad,
-ops::MinusGradOp<float>);
+REGISTER_OPERATOR(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradMaker);
REGISTER_OP_CPU_KERNEL(minus,
ops::MinusKernel<paddle::platform::CPUPlace, float>);
@@ -56,8 +56,7 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker {
"The input should be a k-D tensor(k > 0 and k < 7)");
AddOutput("Out",
"The output of pad op."
"A tensor with the same shape as X.")
.NotInGradient();
"A tensor with the same shape as X.");
AddComment(R"DOC(
Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example:
@@ -111,11 +110,28 @@ class PadOpGrad : public framework::OperatorWithKernel {
}
};
+class PadOpGradMaker : public framework::SingleGradOpDescMaker {
+protected:
+framework::OpDescBind Apply() const override {
+framework::OpDescBind bind;
+bind.SetInput("X", Input("X"));
+bind.SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+bind.SetOutput(framework::GradVarName("X"), InputGrad("X"));
+bind.SetAttrMap(Attrs());
+return bind;
+}
+public:
+using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP(pad, ops::PadOp, ops::PadOpMaker, pad_grad, ops::PadOpGrad);
+REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker, ops::PadOpGradMaker);
+REGISTER_OPERATOR(pad_grad, ops::PadOpGrad);
REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pad_grad,
ops::PadGradKernel<paddle::platform::CPUPlace, float>);
@@ -41,8 +41,8 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient();
AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
AddInput("X", "The input tensor of scale operator.");
AddOutput("Out", "The output tensor of scale operator.");
AddComment(R"DOC(Scale operator
The equation is: Out = scale*X
@@ -52,21 +52,18 @@ The equation is: Out = scale*X
}
};
// The operator to calculate gradients of a scale operator is just the scale
// operator itself.
// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
-template <typename AttrType>
-class ScaleGradOp : public NetOp {
+class ScaleGradMaker : public framework::SingleGradOpDescMaker {
public:
-ScaleGradOp(const std::string &type, const framework::VariableNameMap &inputs,
-const framework::VariableNameMap &outputs,
-const framework::AttributeMap &attrs)
-: NetOp(type, inputs, outputs, attrs) {
-AppendOp(framework::OpRegistry::CreateOp(
-"scale", {{"X", {Input(framework::GradVarName("Out"))}}},
-{{"Out", {Output(framework::GradVarName("X"))}}},
-{{"scale", Attr<AttrType>("scale")}}));
-CompleteAddOp(false);
+using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+protected:
+framework::OpDescBind Apply() const override {
+framework::OpDescBind grad_op;
+grad_op.SetType("scale");
+grad_op.SetInput("X", OutputGrad("Out"));
+grad_op.SetOutput("Out", InputGrad("X"));
+grad_op.SetAttr("scale", GetAttr("scale"));
+return grad_op;
}
};
@@ -75,7 +72,7 @@ class ScaleGradOp : public NetOp {
namespace ops = paddle::operators;
-REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, scale_grad,
-ops::ScaleGradOp<float>);
+REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
+ops::ScaleGradMaker);
REGISTER_OP_CPU_KERNEL(scale,
ops::ScaleKernel<paddle::platform::CPUPlace, float>);
@@ -27,15 +27,14 @@ class SoftmaxWithCrossEntropyOpMaker
AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number.")
.NotInGradient();
AddInput(
"Label",
"(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
"tensor. "
"If softLable is set to 0, Label is a Tensor<int> with shape [N x 1]. "
"If softLable is set to 1, Label is a Tensor<float/double> "
"with shape [N x K].");
"and K is the class number.");
AddInput("Label",
"(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
"tensor. "
"If softLable is set to 0, Label is a Tensor<int> with shape [N x "
"1]. "
"If softLable is set to 1, Label is a Tensor<float/double> "
"with shape [N x K].");
AddOutput(
"Softmax",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
@@ -163,15 +162,35 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
};
+class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
+public:
+using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+protected:
+framework::OpDescBind Apply() const override {
+framework::OpDescBind grad_op;
+grad_op.SetType("softmax_with_cross_entropy_grad");
+grad_op.SetInput("Label", Input("Label"));
+grad_op.SetInput("Softmax", Output("Softmax"));
+grad_op.SetInput("Loss", Output("Loss"));
+grad_op.SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax"));
+grad_op.SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
+grad_op.SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
+grad_op.SetAttrMap(Attrs());
+return grad_op;
+}
+};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
-ops::SoftmaxWithCrossEntropyOpMaker,
-softmax_with_cross_entropy_grad,
-ops::SoftmaxWithCrossEntropyOpGrad);
+REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
+ops::SoftmaxWithCrossEntropyOpMaker,
+ops::SoftmaxGradMaker);
+REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
+ops::SoftmaxWithCrossEntropyOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
......
@@ -45,10 +45,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.")
.AsDuplicable()
.NotInGradient();
AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
AddInput("X", "the input tensors of sum operator.").AsDuplicable();
AddOutput("Out", "the output tensor of sum operator.");
AddComment(R"DOC(
Sum the input tensors.
@@ -58,23 +56,25 @@ or not. But the output only shares the LoD with the first input.
}
};
-class SumGradOp : public NetOp {
+class SumGradMaker : public framework::GradOpDescMakerBase {
public:
-SumGradOp(const std::string& type, const framework::VariableNameMap& inputs,
-const framework::VariableNameMap& outputs,
-const framework::AttributeMap& attrs)
-: NetOp(type, inputs, outputs, attrs) {
-auto& x_grad_names = Outputs(framework::GradVarName("X"));
-auto out_grad_name = this->Input(framework::GradVarName("Out"));
+using framework::GradOpDescMakerBase::GradOpDescMakerBase;
-framework::AttributeMap grad_attrs;
-grad_attrs["scale"] = 1.0f;
-for (auto& x_grad_name : x_grad_names) {
-AppendOp(framework::OpRegistry::CreateOp(
-"scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}},
-grad_attrs));
-}
-CompleteAddOp(false);
+std::vector<framework::OpDescBind> operator()() const override {
+auto x_grads = InputGrad("X");
+std::vector<framework::OpDescBind> grad_ops;
+grad_ops.reserve(x_grads.size());
+auto og = OutputGrad("Out");
+std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
+[&og](const std::string& x_grad) {
+framework::OpDescBind grad_op;
+grad_op.SetType("scale");
+grad_op.SetInput("X", og);
+grad_op.SetOutput("Out", {x_grad});
+grad_op.SetAttr("scale", 1.0f);
+return grad_op;
+});
+return grad_ops;
}
};
@@ -82,5 +82,6 @@ class SumGradOp : public NetOp {
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
+REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);