diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 40390d41502ddf073a6b6220d3a644e8257e1fea..9193a1593efadad03fceeeac9dfce98ebcacbfa5 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -33,7 +33,7 @@ static inline std::unique_ptr CreateGradOp( op_desc.SetType(op.Type()); op_desc.SetAttrMap(op.Attrs()); auto& info = OpInfoMap::Instance().Get(op.Type()); - auto grad_descs = info.grad_op_maker_(op_desc); + auto grad_descs = info.GradOpMaker()(op_desc); std::vector> grad_ops; grad_ops.reserve(grad_descs.size()); std::transform(grad_descs.begin(), grad_descs.end(), @@ -49,6 +49,7 @@ static inline std::unique_ptr CreateGradOp( for (auto& grad_op : grad_ops) { net_op->AppendOp(std::move(grad_op)); } + net_op->CompleteAddOp(); return std::unique_ptr(net_op); } } diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 6f87e055b492dcde1584a6a92d95817b3c17f33e..968f587b46a3ed9528819fdd31bd3ef8059c67a3 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -30,7 +30,6 @@ namespace framework { struct OpInfo { OpCreator creator_; - std::string grad_op_type_; GradOpMakerFN grad_op_maker_; OpProto* proto_{nullptr}; OpAttrChecker* checker_{nullptr}; @@ -51,6 +50,12 @@ struct OpInfo { "Operator Creator has not been registered"); return creator_; } + + const GradOpMakerFN& GradOpMaker() const { + PADDLE_ENFORCE_NOT_NULL(grad_op_maker_, + "Operator GradOpMaker has not been registered."); + return grad_op_maker_; + } }; class OpInfoMap { diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index da112fa488c914d23e2c3ad2f91f0bcf1d8b772e..a4f0144ce845bda2b4792d03b637fd98b5c7ce14 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -137,23 +137,21 @@ class OpKernelRegistrar : public Registrar { __test_global_namespace_##uniq_name##__>::value, \ msg) -#define VA_ARGS(...) , ##__VA_ARGS__ - -#define REGISTER_OPERATOR(op_type, op_class, ...) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_op__##op_type, \ - "REGISTER_OPERATOR must be called in global namespace"); \ - class _OpClass_##op_type##_ : public op_class { \ - public: \ - DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ - DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ - }; \ - static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_ VA_ARGS( \ - __VA_ARGS__)> \ - __op_registrar_##op_type##__(#op_type); \ - int TouchOpRegistrar_##op_type() { \ - __op_registrar_##op_type##__.Touch(); \ - return 0; \ +#define REGISTER_OPERATOR(op_type, op_class, ...) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_op__##op_type, \ + "REGISTER_OPERATOR must be called in global namespace"); \ + class _OpClass_##op_type##_ : public op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ + DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ + }; \ + static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_, \ + ##__VA_ARGS__> \ + __op_registrar_##op_type##__(#op_type); \ + int TouchOpRegistrar_##op_type() { \ + __op_registrar_##op_type##__.Touch(); \ + return 0; \ } /** @@ -170,7 +168,7 @@ class OpKernelRegistrar : public Registrar { virtual std::string GradOpType() const { return #grad_op_type; } \ }; \ REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \ - op_maker_class) + op_maker_class); #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ REGISTER_OPERATOR(op_type, op_class, op_maker_class) diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index aced8636b96d139c69895d7eba11892f190e2173..7057dcbd6e375adef57d17a13afdfade67e938b6 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -72,19 +72,26 @@ class MinusGradMaker : public framework::GradOpDescMakerBase { std::vector> operator()() const override { std::vector> ops; - ops.resize(2); - - ops[0].reset(new framework::OpDescBind()); - ops[0]->SetType("scale"); - ops[0]->SetInput("X", OutputGrad("Out")); - ops[0]->SetOutput("Out", InputGrad("X")); - ops[0]->SetAttr("scale", 1.0f); - - ops[1].reset(new framework::OpDescBind()); - ops[1]->SetType("scale"); - ops[1]->SetInput("X", OutputGrad("Out")); - ops[1]->SetOutput("Out", InputGrad("Y")); - ops[1]->SetAttr("scale", -1.0f); + auto x_g = InputGrad("X"); + if (!x_g.empty()) { + auto *x_g_op = new framework::OpDescBind(); + x_g_op->SetType("scale"); + x_g_op->SetInput("X", OutputGrad("Out")); + x_g_op->SetOutput("Out", x_g); + x_g_op->SetAttr("scale", 1.0f); + ops.emplace_back(x_g_op); + } + + auto y_g = InputGrad("Y"); + if (!y_g.empty()) { + auto *y_g_op = new framework::OpDescBind(); + y_g_op->SetType("scale"); + y_g_op->SetInput("X", OutputGrad("Out")); + y_g_op->SetOutput("Out", y_g); + y_g_op->SetAttr("scale", -1.0f); + ops.emplace_back(y_g_op); + } + return ops; } }; diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 944591773929bf97ad4bfc8961f58c17e424b71d..15aa05f26610be14e4c37be35137a259e00eb947 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -121,6 +121,7 @@ class PadOpGradMaker : public framework::SingleGradOpDescMaker { bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); bind->SetOutput(framework::GradVarName("X"), InputGrad("X")); bind->SetAttrMap(Attrs()); + bind->SetType("pad_grad"); return std::unique_ptr(bind); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index bc9868874da259a640b3bb0f38cb4ef08fc0fb83..70fe429f59254515a9e5fd648f98ebe958db35c1 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -14,6 +14,12 @@ #include "paddle/operators/softmax_with_cross_entropy_op.h" #include +#include + +#define DBG_LINE() \ + do { \ + std::cerr << "Run at " << __LINE__ << std::endl; \ + } while (false) namespace paddle { namespace operators { @@ -187,8 +193,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, - ops::SoftmaxWithCrossEntropyOpMaker, - ops::SoftmaxWithCrossEntropyOpMaker); + ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker); REGISTER_OPERATOR(softmax_with_cross_entropy_grad, ops::SoftmaxWithCrossEntropyOpGrad); REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,