Commit 2ea2fbea authored by F fengjiayi

Merge REGISTER_OP and REGISTER_GRADIENT_OP

Parent 6768b310
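This commit folds gradient-operator registration into REGISTER_OP itself: the macro gains the gradient op class as a fifth argument and registers both the forward op and its gradient op, so the separate REGISTER_GRADIENT_OP call (and the macro) is removed. A sketch of the usage change, taken from the test registrations updated below:

    // before this commit
    REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker, mul_grad);
    REGISTER_GRADIENT_OP(mul_grad, f::EmptyOp);

    // after this commit: one call registers both mul and mul_grad
    REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker, mul_grad, f::EmptyOp);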
@@ -150,20 +150,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
 namespace f = paddle::framework;
 namespace ops = paddle::operators;
 using EnforceNotMet = paddle::platform::EnforceNotMet;
-REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker, rowwise_add_grad);
-REGISTER_GRADIENT_OP(rowwise_add_grad, f::EmptyOp);
-REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker, mul_grad);
-REGISTER_GRADIENT_OP(mul_grad, f::EmptyOp);
-REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker, sigmoid_grad);
-REGISTER_GRADIENT_OP(sigmoid_grad, f::EmptyOp);
+REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker, rowwise_add_grad,
+            f::EmptyOp);
+REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker, mul_grad, f::EmptyOp);
+REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker, sigmoid_grad, f::EmptyOp);
 REGISTER_OP_WITHOUT_GRADIENT(nograd, f::EmptyOp, f::NoGradOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
-REGISTER_OP(add, f::EmptyOp, f::AddOpMaker, add_grad);
-REGISTER_GRADIENT_OP(add_grad, f::EmptyOp);
+REGISTER_OP(add, f::EmptyOp, f::AddOpMaker, add_grad, f::EmptyOp);
 REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker,
-            many_output_op_grad);
-REGISTER_GRADIENT_OP(many_output_op_grad, f::EmptyOp);
+            many_output_op_grad, f::EmptyOp);

 TEST(Backward, simple_op_grad) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
...
@@ -8,13 +8,6 @@ USE_OP(add_two);
 namespace paddle {
 namespace framework {
-class NOP : public OperatorBase {
- public:
-  void InferShape(const Scope &scope) const override {}
-  void Run(const Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {}
-};
 class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
  public:
   MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -61,10 +54,8 @@ TEST(GradOpBuilder, AddTwo) {
   EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
 }
-REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad);
-REGISTER_GRADIENT_OP(mult_io_grad, f::NOP);
-REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad);
-REGISTER_GRADIENT_OP(io_ignored_grad, f::NOP);
+REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
+REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
 TEST(GradOpBuilder, MutiInOut) {
   f::AttributeMap attrs{{"input_format", std::vector<int>{0, 1, 4, 5}},
...
@@ -193,7 +193,7 @@ class OpRegistry {
   using VarNameList = std::vector<std::string>;

  public:
-  template <typename OpType, typename ProtoMakerType>
+  template <typename OpType, typename ProtoMakerType, typename GradOpType>
   static void RegisterOp(const std::string& op_type,
                          const std::string& grad_op_type) {
     PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
@@ -226,6 +226,10 @@
       // ================================================ //
     }
     op_info_map().insert(std::make_pair(op_type, op_info));
+    // register gradient op
+    if (!grad_op_type.empty()) {
+      RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
+    }
   }

   static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
@@ -321,12 +325,13 @@ class Registrar {
   void Touch() {}
 };

-template <typename OpType, typename ProtoMakerType>
+template <typename OpType, typename ProtoMakerType, typename GradOpType>
 class OpRegistrar : public Registrar {
  public:
   explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
   OpRegistrar(const char* op_type, const char* grad_op_type) {
-    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type, grad_op_type);
+    OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
+                                                               grad_op_type);
   }
 };
@@ -352,10 +357,12 @@ class OpKernelRegistrar : public Registrar {
 /**
  * Macro to register Operator.
  */
-#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type) \
+#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
+                    grad_op_class) \
   STATIC_ASSERT_GLOBAL_NAMESPACE( \
       __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
-  static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
+  static ::paddle::framework::OpRegistrar<op_class, op_maker_class, \
+                                          grad_op_class> \
       __op_registrar_##op_type##__(#op_type, #grad_op_type); \
   int TouchOpRegistrar_##op_type() { \
     __op_registrar_##op_type##__.Touch(); \
@@ -363,10 +370,7 @@ class OpKernelRegistrar : public Registrar {
   }

 #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
-  REGISTER_OP(op_type, op_class, op_maker_class, )
-
-#define REGISTER_GRADIENT_OP(op_type, op_class) \
-  REGISTER_OP(op_type, op_class, ::paddle::framework::NOPMaker, )
+  REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)

 /**
  * Macro to register OperatorKernel.
...
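Note on the registration flow above: when RegisterOp is handed a non-empty grad_op_type, it registers the gradient operator by calling itself once more with NOPMaker as the proto maker and an empty grad_op_type, so the recursion stops after one level; REGISTER_OP_WITHOUT_GRADIENT passes an empty grad_op_type together with ::paddle::framework::NOP, and the empty-string check skips gradient registration entirely. A minimal standalone sketch of that control flow (plain functions and a std::set, not the real OpRegistry classes):

    #include <iostream>
    #include <set>
    #include <string>

    std::set<std::string> registered;  // stands in for op_info_map()

    void RegisterOp(const std::string& op_type, const std::string& grad_op_type) {
      registered.insert(op_type);
      // Mirrors the new branch in OpRegistry::RegisterOp: the gradient op is
      // registered with an empty grad_op_type, so it never recurses again.
      if (!grad_op_type.empty()) {
        RegisterOp(grad_op_type, "");
      }
    }

    int main() {
      RegisterOp("mul", "mul_grad");       // registers "mul" and "mul_grad"
      RegisterOp("fill_zeros_like", "");   // forward-only op, nothing extra
      for (const auto& name : registered) std::cout << name << "\n";
    }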
@@ -125,6 +125,13 @@ class OperatorBase {
   std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
 };

+class NOP : public OperatorBase {
+ public:
+  void InferShape(const Scope& scope) const override {}
+  void Run(const Scope& scope,
+           const platform::DeviceContext& dev_ctx) const override {}
+};
+
 class InferShapeContext {
  public:
   InferShapeContext(const OperatorBase& op, const Scope& scope)
...
@@ -55,8 +55,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad);
-REGISTER_GRADIENT_OP(add_two_grad, ops::AddOpGrad);
+REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad);
 REGISTER_OP_CPU_KERNEL(add_two,
                        ops::AddKernel<paddle::platform::CPUPlace, float>);
@@ -69,12 +69,11 @@ OnehotCrossEntropy Operator.
 namespace ops = paddle::operators;
 REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
-            ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad);
+            ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
+            ops::OnehotCrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(
     onehot_cross_entropy,
     ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_GRADIENT_OP(onehot_cross_entropy_grad,
-                     ops::OnehotCrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(
     onehot_cross_entropy_grad,
     ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
@@ -50,9 +50,8 @@ class MeanGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad);
+REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
 REGISTER_OP_CPU_KERNEL(mean,
                        ops::MeanKernel<paddle::platform::CPUPlace, float>);
-REGISTER_GRADIENT_OP(mean_grad, ops::MeanGradOp);
 REGISTER_OP_CPU_KERNEL(mean_grad,
                        ops::MeanGradKernel<paddle::platform::CPUPlace, float>);
@@ -65,7 +65,5 @@ class MulOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad);
-REGISTER_GRADIENT_OP(mul_grad, ops::MulOpGrad);
+REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
 REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
@@ -48,9 +48,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad);
-REGISTER_GRADIENT_OP(sigmoid_grad, ops::SigmoidOpGrad);
+REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
+            ops::SigmoidOpGrad);
 REGISTER_OP_CPU_KERNEL(sigmoid,
                        ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
...
@@ -64,9 +64,9 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;

-REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad);
+REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad,
+            ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax,
                        ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
-REGISTER_GRADIENT_OP(softmax_grad, ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(
     softmax_grad, ops::SoftmaxGradKernel<paddle::platform::CPUPlace, float>);