diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index c0ad888b706f6bf0933611b03e30ca186339cc12..9140854a96ccac583adce4c89f7b134360360c6d 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -22,7 +22,7 @@ cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) -cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) +cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..d2516ccc1eec21682276c2fddf049984453404d4 --- /dev/null +++ b/paddle/framework/details/op_registry.h @@ -0,0 +1,105 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_info.h" +#include "paddle/framework/op_proto_maker.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace framework { +namespace details { + +enum OpInfoFillType { + kOperator = 0, + kOpProtoAndCheckerMaker = 1, + kGradOpDescMaker = 2 +}; + +template +struct OpInfoFillTypeID { + static constexpr OpInfoFillType ID() { + return std::is_base_of::value + ? kOperator + : (std::is_base_of::value + ? kOpProtoAndCheckerMaker + : (std::is_base_of::value + ? kGradOpDescMaker + : static_cast(-1))); + } +}; + +template ::ID()> +struct OpInfoFiller; + +template +class OperatorRegistrarRecursive; + +template +class OperatorRegistrarRecursive { + public: + using T = typename std::tuple_element>::type; + OperatorRegistrarRecursive(const char* op_type, OpInfo* info) { + OpInfoFiller fill; + fill(op_type, info); + constexpr auto size = sizeof...(ARGS); + OperatorRegistrarRecursive reg(op_type, + info); + (void)(reg); + } +}; + +template +class OperatorRegistrarRecursive { + public: + OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {} +}; + +template +struct OpInfoFiller { + void operator()(const char* op_type, OpInfo* info) const { + info->creator_ = [](const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, + const AttributeMap& attrs) { + return new T(type, inputs, outputs, attrs); + }; + } +}; + +template +struct OpInfoFiller { + void operator()(const char* op_type, OpInfo* info) const { + info->proto_ = new OpProto; + info->checker_ = new OpAttrChecker(); + auto maker = T(info->proto_, info->checker_); + maker.Validate(); + info->proto_->set_type(op_type); + PADDLE_ENFORCE( + info->proto_->IsInitialized(), + "Fail to initialize %s's OpProto, because %s is not initialized", + op_type, info->proto_->InitializationErrorString()); + } +}; + +template +struct OpInfoFiller { + void operator()(const char* op_type, OpInfo* info) const { + info->grad_op_maker_ = new T(); + } +}; +} // namespace details + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index b98d8f23a14cf6fbe787953ad16b5c9ab99222ad..6d1ee4dece14affa78ccf4ff1d9c7f09992f127f 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -17,8 +17,8 @@ #include #include #include - #include "paddle/framework/attribute.h" +#include "paddle/framework/op_desc.h" namespace paddle { namespace framework { @@ -29,11 +29,18 @@ using OpCreator = std::function; +class GradOpDescMakerBase { + public: + virtual ~GradOpDescMakerBase() = default; + virtual std::vector operator()(const OpDescBind&) const = 0; +}; + struct OpInfo { OpCreator creator_; std::string grad_op_type_; - OpProto* proto_; - OpAttrChecker* checker_; + GradOpDescMakerBase* grad_op_maker_{nullptr}; + OpProto* proto_{nullptr}; + OpAttrChecker* checker_{nullptr}; bool HasOpProtoAndChecker() const { return proto_ != nullptr && checker_ != nullptr; diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 4db38badaea8ae22d9ad47951f4941f3bdeb401a..4ee2c7d27561c3855059c42e1604f353c65bfa41 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -21,49 +21,42 @@ limitations under the License. */ #include #include #include "paddle/framework/attribute.h" +#include "paddle/framework/details/op_registry.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/grad_op_builder.h" -#include "paddle/framework/op_info.h" -#include "paddle/framework/op_proto_maker.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" namespace paddle { namespace framework { +template +struct OperatorRegistrar { + explicit OperatorRegistrar(const char* op_type) : op_type(op_type) { + PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), + "'%s' is registered more than once.", op_type); + static_assert(sizeof...(ARGS) != 0, + "OperatorRegistrar should be invoked at least by OpClass"); + details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info); + } + + ~OperatorRegistrar() { OpInfoMap::Instance().Insert(op_type, info); } + + const char* op_type; + + OpInfo info; +}; + class OpRegistry { public: template static void RegisterOp(const std::string& op_type, const std::string& grad_op_type) { - PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), - "'%s' is registered more than once.", op_type); - OpInfo op_info; - op_info.creator_ = []( - const std::string& type, const VariableNameMap& inputs, - const VariableNameMap& outputs, const AttributeMap& attrs) { - return new OpType(type, inputs, outputs, attrs); - }; - op_info.grad_op_type_ = grad_op_type; - if (std::type_index(typeid(ProtoMakerType)) != - std::type_index(typeid(NOPMaker))) { - op_info.proto_ = new OpProto; - op_info.checker_ = new OpAttrChecker; - auto maker = ProtoMakerType(op_info.proto_, op_info.checker_); - maker.Validate(); - op_info.proto_->set_type(op_type); - PADDLE_ENFORCE( - op_info.proto_->IsInitialized(), - "Fail to initialize %s's OpProto, because %s is not initialized", - op_type, op_info.proto_->InitializationErrorString()); - } else { - op_info.proto_ = nullptr; - op_info.checker_ = nullptr; - } - OpInfoMap::Instance().Insert(op_type, op_info); + OperatorRegistrar reg(op_type.c_str()); + reg.info.grad_op_type_ = grad_op_type; // register gradient op if (!grad_op_type.empty()) { - RegisterOp(grad_op_type, ""); + OperatorRegistrar grad_reg(grad_op_type.c_str()); } } diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index b6fc0409d5cb22b13352df41b8e911c79bc4825a..f89f40b444f893d898595db50e245824a19284f6 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -173,3 +173,14 @@ TEST(OpRegistry, CustomChecker) { int test_attr = op->Attr("test_attr"); ASSERT_EQ(test_attr, 4); } + +class CosineOpComplete : public paddle::framework::CosineOp { + public: + DEFINE_OP_CONSTRUCTOR(CosineOpComplete, paddle::framework::CosineOp); + DEFINE_OP_CLONE_METHOD(CosineOpComplete); +}; + +TEST(OperatorRegistrar, Test) { + using namespace paddle::framework; + OperatorRegistrar reg("cos"); +} \ No newline at end of file