diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index a76a95644dae2755a9599a57259a1f9b2ed604b7..433edbfda742d3be9915eade7b0a455398a501dc 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -19,10 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
 cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
-cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
+cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
+cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
-cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)
+cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
 
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
diff --git a/paddle/framework/grad_op_creator.cc b/paddle/framework/grad_op_builder.cc
similarity index 88%
rename from paddle/framework/grad_op_creator.cc
rename to paddle/framework/grad_op_builder.cc
index 106c2eae9dade9ef1829fc2f1b793faf483947d4..6235be75f27dadb65de663ff1b3caf26a649f6cb 100644
--- a/paddle/framework/grad_op_creator.cc
+++ b/paddle/framework/grad_op_builder.cc
@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/grad_op_builder.h"
 #include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace framework {
 
-OperatorBase* GradOpCreator::Create() {
+OperatorBase* GradOpBuilder::Build() {
   BuildOpInOutArgList();
-  OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
+  std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_);
+  OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
+  grad_op->type_ = grad_op_type;
   CompleteGradOp(grad_op);
   return grad_op;
 }
 
-OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
+OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
                                     const VarIndexMap& var_map,
                                     const std::vector<int>& format,
                                     InOutType type) {
@@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
                         end_idx);
 }
 
-void GradOpCreator::BuildOpInOutArgList() {
+void GradOpBuilder::BuildOpInOutArgList() {
   const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
   const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
   const std::vector<int>& in_format =
@@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() {
   }
 }
 
-void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
+void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
                                      std::vector<std::string>& in_out,
                                      std::vector<int>& format,
                                      VarIndexMap* varmap, int& idx,
@@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
   format.push_back(in_out.size());
 }
 
-void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
-  grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
+void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
   grad_op->attrs_ = op_->attrs_;
   grad_op->attrs_.erase("input_format");
   grad_op->attrs_.erase("output_format");
diff --git a/paddle/framework/grad_op_creator.h b/paddle/framework/grad_op_builder.h
similarity index 92%
rename from paddle/framework/grad_op_creator.h
rename to paddle/framework/grad_op_builder.h
index 21b160a73f3f6402a0571e2f13be06b26b5c30bc..2ecf39479b4f4a51f89cd500caf851897df0e599 100644
--- a/paddle/framework/grad_op_creator.h
+++ b/paddle/framework/grad_op_builder.h
@@ -25,12 +25,12 @@ struct OpInOutArg {
   size_t end_idx_;
 };
 
-class GradOpCreator {
+class GradOpBuilder {
   using VarIndexMap = std::unordered_map<std::string, int>;
 
  public:
-  GradOpCreator(const OperatorBase* op) : op_(op) {}
-  OperatorBase* Create();
+  GradOpBuilder(const OperatorBase* op) : op_(op) {}
+  OperatorBase* Build();
 
  private:
   OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
diff --git a/paddle/framework/grad_op_creator_test.cc b/paddle/framework/grad_op_builder_test.cc
similarity index 89%
rename from paddle/framework/grad_op_creator_test.cc
rename to paddle/framework/grad_op_builder_test.cc
index 27ac65813120a2a682535a02bcecb882c4a7640d..288a7841cd7c9212d8fa230e38d49dfc26e76256 100644
--- a/paddle/framework/grad_op_creator_test.cc
+++ b/paddle/framework/grad_op_builder_test.cc
@@ -1,4 +1,4 @@
-#include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/grad_op_builder.h"
 #include <gtest/gtest.h>
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
@@ -8,7 +8,7 @@ USE_OP(add_two);
 namespace paddle {
 namespace framework {
 
-TEST(GradOpCreator, AddTwo) {
+TEST(GradOpBuilder, AddTwo) {
   std::shared_ptr<OperatorBase> add_op(
       OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
   std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 41c78309327342ff47982fc105eadf777c7e59c7..f16deae028d76dc40d6bc589648b461c430c3c98 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -20,7 +20,7 @@ limitations under the License. */
 #include <unordered_map>
 #include <unordered_set>
 #include "paddle/framework/attr_checker.h"
-#include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/grad_op_builder.h"
 #include "paddle/framework/op_desc.pb.h"
 #include "paddle/framework/scope.h"
 
@@ -222,7 +222,7 @@ class OpRegistry {
  public:
   template <typename OpType, typename ProtoMakerType>
   static void RegisterOp(const std::string& op_type) {
-    creators()[op_type] = [] { return new OpType; };
+    op_creators()[op_type] = [] { return new OpType; };
     OpAttrChecker& op_checker = op_checkers()[op_type];
     OpProto& op_proto = protos()[op_type];
     auto maker = ProtoMakerType(&op_proto, &op_checker);
@@ -245,17 +245,19 @@ class OpRegistry {
     }
   }
 
-  template <typename OpType>
-  static void RegisterGradOp(const std::string& op_type) {
-    grad_creators()[op_type] = [] { return new OpType; };
+  template <typename GradOpType>
+  static void RegisterGradOp(const std::string& op_type,
+                             const std::string& grad_op_type) {
+    op_creators()[grad_op_type] = [] { return new GradOpType; };
+    grad_ops()[op_type] = grad_op_type;
   }
 
   static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VarNameList& inputs,
                                                 const VarNameList& outputs,
                                                 const AttributeMap& attrs) {
-    auto op_create_it = creators().find(type);
-    PADDLE_ENFORCE(op_create_it != creators().end(),
+    auto op_create_it = op_creators().find(type);
+    PADDLE_ENFORCE(op_create_it != op_creators().end(),
                    "Operator %s cannot be found.", type);
     auto op = op_create_it->second();
 
@@ -300,8 +302,8 @@ class OpRegistry {
 
   static std::shared_ptr<OperatorBase> CreateGradOp(
       std::shared_ptr<OperatorBase> op) {
-    GradOpCreator creator(op.get());
-    std::shared_ptr<OperatorBase> grad_op(creator.Create());
+    GradOpBuilder builder(op.get());
+    std::shared_ptr<OperatorBase> grad_op(builder.Build());
     grad_op->Init();
     return grad_op;
   }
@@ -311,9 +313,9 @@ class OpRegistry {
     return protos_;
   };
 
-  static std::unordered_map<std::string, OpCreator>& grad_creators() {
-    static std::unordered_map<std::string, OpCreator> grad_creators_;
-    return grad_creators_;
+  static std::unordered_map<std::string, std::string>& grad_ops() {
+    static std::unordered_map<std::string, std::string> grad_ops_;
+    return grad_ops_;
   }
 
   static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
@@ -322,12 +324,12 @@ class OpRegistry {
     return maps_;
   }
 
- private:
-  static std::unordered_map<std::string, OpCreator>& creators() {
-    static std::unordered_map<std::string, OpCreator> creators_;
-    return creators_;
+  static std::unordered_map<std::string, OpCreator>& op_creators() {
+    static std::unordered_map<std::string, OpCreator> op_creators_;
+    return op_creators_;
   }
 
+ private:
   static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
     static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
     return op_checkers_;
@@ -353,11 +355,11 @@ class OpRegisterHelper {
   }
 };
 
-template <typename OpType>
+template <typename GradOpType>
 class GradOpRegisterHelper {
  public:
-  GradOpRegisterHelper(const char* op_type) {
-    OpRegistry::RegisterGradOp<OpType>(op_type);
+  GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
+    OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
   }
 };
 
@@ -383,13 +385,16 @@ class GradOpRegisterHelper {
 /**
  * Macro to Register Gradient Operator.
  */
-#define REGISTER_GRADIENT_OP(__op_type, __op_class)                       \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
-      __reg_gradient_op__##__op_type,                                     \
-      "REGISTER_GRADIENT_OP must be in global namespace");                \
-  static ::paddle::framework::GradOpRegisterHelper<__op_class>            \
-      __op_gradient_register_##__op_type##__(#__op_type);                 \
-  int __op_gradient_register_##__op_type##_handle__() { return 0; }
+#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class)  \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
+      __reg_gradient_op__##__op_type##__grad_op_type,                     \
+      "REGISTER_GRADIENT_OP must be in global namespace");                \
+  static ::paddle::framework::GradOpRegisterHelper<__grad_op_class>       \
+      __op_gradient_register_##__op_type##__grad_op_type##__(#__op_type,  \
+                                                             #__grad_op_type); \
+  int __op_gradient_register_##__op_type##__grad_op_type##_handle__() {   \
+    return 0;                                                             \
+  }
 
 /**
  * Macro to Register OperatorKernel.
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index ff60f9b314c86ad92218caea15ca5d9f6d996b4e..8d415fbd2e72af556e21f89c37d31b9fad130e3d 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -65,6 +65,6 @@ protected:
 } // namespace paddle
 
 REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
-REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
+REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
 REGISTER_OP_CPU_KERNEL(
     add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/add_op_test.cc b/paddle/operators/add_op_test.cc
index 7fc1049893e171a17af92da7e813b2463874c9de..3d52f5498323dbb7ca0ff25d038947f0ddb2017e 100644
--- a/paddle/operators/add_op_test.cc
+++ b/paddle/operators/add_op_test.cc
@@ -22,7 +22,7 @@ TEST(AddOp, GetOpProto) {
   auto& protos = paddle::framework::OpRegistry::protos();
   auto it = protos.find("add_two");
   ASSERT_NE(it, protos.end());
-  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
-  auto it1 = grad_creators.find("add_two");
-  ASSERT_NE(it1, grad_creators.end());
+  auto& op_creators = paddle::framework::OpRegistry::op_creators();
+  auto it1 = op_creators.find("add_two_grad");
+  ASSERT_NE(it1, op_creators.end());
 }
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 89e0375a7a043730685c4c0883ac672bdd688159..cd74c8b976d18ffecd50077cc81e1fce56bea155 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -67,7 +67,7 @@ protected:
 } // namespace paddle
 
 REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
-REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
+REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad);
 REGISTER_OP_CPU_KERNEL(
     mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index 7dc58bbb10007545cd281ae7da359e4c2b32fae0..bf63af28b003daad0ab8c223e71a561437ee663a 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -56,7 +56,7 @@ protected:
 
 REGISTER_OP(sigmoid, paddle::operators::SigmoidOp,
             paddle::operators::SigmoidOpMaker);
-REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
+REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
     sigmoid,
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 1d10a415d0208e1edb881eacad951a07fcbb8b5c..82f72fa19f690bebdff01629e75d17eecd6ada74 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -59,6 +59,6 @@ protected:
 namespace ops = paddle::operators;
 
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
-REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
+REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax,
                        ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
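
Note for reviewers: the behavioral change behind the rename is that a gradient operator now carries its own registered type name (e.g. "add_two_grad") instead of the implicit op_->type_ + "@GRAD" string, and OpRegistry resolves it in two steps: grad_ops() maps a forward op type to its gradient op type, and op_creators() maps any op type to its creator. A minimal usage sketch of the new API follows; MyOp, MyOpGrad, MyOpMaker, and "my_op" are hypothetical names for illustration only (the real registrations are the add_two/mul/sigmoid/softmax ones above):

    // Hypothetical operator classes standing in for AddOp/AddOpGrad etc.
    REGISTER_OP(my_op, paddle::operators::MyOp, paddle::operators::MyOpMaker);
    // New three-argument form: forward op type, gradient op type, gradient class.
    // Stores op_creators()["my_op_grad"] = creator of MyOpGrad and records
    // grad_ops()["my_op"] = "my_op_grad".
    REGISTER_GRADIENT_OP(my_op, my_op_grad, paddle::operators::MyOpGrad);

    // Building the gradient op now resolves the type through the two maps:
    auto op = paddle::framework::OpRegistry::CreateOp("my_op", {"x"}, {"out"}, {});
    auto grad_op = paddle::framework::OpRegistry::CreateGradOp(op);
    // grad_op->type_ == "my_op_grad", set explicitly by GradOpBuilder::Build()
    // rather than derived from the forward type with a "@GRAD" suffix.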