From 77af58f8f73d19329c2703961d7cfc0581839308 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 24 Jul 2017 21:56:37 +0800 Subject: [PATCH] Change gradient Op registry mechanism OLD: op_type -> grad_op_creator NEW: grad_op_type -> grad_op_creator op_type -> grad_op_type --- ...{grad_op_creator.cc => grad_op_builder.cc} | 17 +++--- .../{grad_op_creator.h => grad_op_builder.h} | 6 +- paddle/framework/op_registry.h | 55 ++++++++++--------- paddle/operators/add_op.cc | 2 +- 4 files changed, 43 insertions(+), 37 deletions(-) rename paddle/framework/{grad_op_creator.cc => grad_op_builder.cc} (88%) rename paddle/framework/{grad_op_creator.h => grad_op_builder.h} (92%) diff --git a/paddle/framework/grad_op_creator.cc b/paddle/framework/grad_op_builder.cc similarity index 88% rename from paddle/framework/grad_op_creator.cc rename to paddle/framework/grad_op_builder.cc index 106c2eae9da..d9ec8a10a52 100644 --- a/paddle/framework/grad_op_creator.cc +++ b/paddle/framework/grad_op_builder.cc @@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/grad_op_creator.h" +#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace framework { -OperatorBase* GradOpCreator::Create() { +OperatorBase* GradOpBuilder::Build() { BuildOpInOutArgList(); - OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)(); + std::string grad_op_type = OpRegistry::grad_ops().at(op->type_); + OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); + grad_op->type_ = grad_op_type; CompleteGradOp(grad_op); return grad_op; } -OpInOutArg* GradOpCreator::BuildArg(const VarProto& var, +OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var, const VarIndexMap& var_map, const std::vector& format, InOutType type) { @@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var, end_idx); } -void GradOpCreator::BuildOpInOutArgList() { +void GradOpBuilder::BuildOpInOutArgList() { const OpProto& op_proto = OpRegistry::protos().at(op_->type_); const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_)); const std::vector& in_format = @@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() { } } -void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg, +void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, std::vector& in_out, std::vector& format, VarIndexMap* varmap, int& idx, @@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg, format.push_back(in_out.size()); } -void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const { - grad_op->type_ = op_->type_ + "@GRAD"; // not necessary +void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const { grad_op->attrs_ = op_->attrs_; grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("output_format"); diff --git a/paddle/framework/grad_op_creator.h b/paddle/framework/grad_op_builder.h similarity index 92% rename from paddle/framework/grad_op_creator.h rename to paddle/framework/grad_op_builder.h index 21b160a73f3..2ecf39479b4 100644 --- a/paddle/framework/grad_op_creator.h +++ b/paddle/framework/grad_op_builder.h @@ -25,12 +25,12 @@ struct OpInOutArg { size_t end_idx_; }; -class GradOpCreator { +class GradOpBuilder { using VarIndexMap = std::unordered_map; public: - GradOpCreator(const OperatorBase* op) : op_(op) {} - OperatorBase* Create(); + GradOpBuilder(const OperatorBase* op) : op_(op) {} + OperatorBase* Build(); private: OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map, diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 41c78309327..31a4151851f 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -222,7 +222,7 @@ class OpRegistry { public: template static void RegisterOp(const std::string& op_type) { - creators()[op_type] = [] { return new OpType; }; + op_creators()[op_type] = [] { return new OpType; }; OpAttrChecker& op_checker = op_checkers()[op_type]; OpProto& op_proto = protos()[op_type]; auto maker = ProtoMakerType(&op_proto, &op_checker); @@ -245,17 +245,19 @@ class OpRegistry { } } - template - static void RegisterGradOp(const std::string& op_type) { - grad_creators()[op_type] = [] { return new OpType; }; + template + static void RegisterGradOp(const std::string& op_type, + const std::string& grad_op_type) { + op_creators()[grad_op_type] = [] { return new GradOpType; }; + grad_ops()[op_type] = grad_op_type; } static std::shared_ptr CreateOp(const std::string& type, const VarNameList& inputs, const VarNameList& outputs, const AttributeMap& attrs) { - auto op_create_it = creators().find(type); - PADDLE_ENFORCE(op_create_it != creators().end(), + auto op_create_it = op_creators().find(type); + PADDLE_ENFORCE(op_create_it != op_creators().end(), "Operator %s cannot be found.", type); auto op = op_create_it->second(); @@ -300,8 +302,8 @@ class OpRegistry { static std::shared_ptr CreateGradOp( std::shared_ptr op) { - GradOpCreator creator(op.get()); - std::shared_ptr grad_op(creator.Create()); + GradOpBuilder builder(op.get()); + std::shared_ptr grad_op(builder.Build()); grad_op->Init(); return grad_op; } @@ -311,9 +313,9 @@ class OpRegistry { return protos_; }; - static std::unordered_map& grad_creators() { - static std::unordered_map grad_creators_; - return grad_creators_; + static std::unordered_map& grad_ops() { + static std::unordered_map grad_ops_; + return grad_ops_; } static std::unordered_map>& @@ -322,12 +324,12 @@ class OpRegistry { return maps_; } - private: - static std::unordered_map& creators() { - static std::unordered_map creators_; - return creators_; + static std::unordered_map& op_creators() { + static std::unordered_map op_creators_; + return op_creators_; } + private: static std::unordered_map& op_checkers() { static std::unordered_map op_checkers_; return op_checkers_; @@ -353,11 +355,11 @@ class OpRegisterHelper { } }; -template +template class GradOpRegisterHelper { public: - GradOpRegisterHelper(const char* op_type) { - OpRegistry::RegisterGradOp(op_type); + GradOpRegisterHelper(const char* op_type, const char* grad_op_type) { + OpRegistry::RegisterGradOp(op_type, grad_op_type); } }; @@ -383,13 +385,16 @@ class GradOpRegisterHelper { /** * Macro to Register Gradient Operator. */ -#define REGISTER_GRADIENT_OP(__op_type, __op_class) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_gradient_op__##__op_type, \ - "REGISTER_GRADIENT_OP must be in global namespace"); \ - static ::paddle::framework::GradOpRegisterHelper<__op_class> \ - __op_gradient_register_##__op_type##__(#__op_type); \ - int __op_gradient_register_##__op_type##_handle__() { return 0; } +#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_gradient_op__##__op_type##__grad_op_type, \ + "REGISTER_GRADIENT_OP must be in global namespace"); \ + static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \ + __op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \ + #__grad_op_type); \ + int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \ + return 0; \ + } /** * Macro to Register OperatorKernel. diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index ff60f9b314c..8d415fbd2e7 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -65,6 +65,6 @@ protected: } // namespace paddle REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); -REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad); +REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad); REGISTER_OP_CPU_KERNEL( add_two, paddle::operators::AddKernel); -- GitLab