diff --git a/src/common/type_define.h b/src/common/type_define.h index 895a7f9786ed3ad10cf336f8323d55b50e9bb6a7..43eb6b2066712c59067c1a8fd580e7d1f75c5099 100644 --- a/src/common/type_define.h +++ b/src/common/type_define.h @@ -19,7 +19,9 @@ SOFTWARE. #include #include +#include #include "framework/attribute.h" +#include "framework/scope.h" namespace paddle_mobile { @@ -37,7 +39,8 @@ template using OpCreator = std::function *( const std::string & /*type*/, const VariableNameMap & /*inputs*/, const VariableNameMap & /*outputs*/, - const framework::AttributeMap & /*attrs*/)>; + const framework::AttributeMap & /*attrs*/, + std::shared_ptr /*scope*/)>; using GradOpMakerFN = std::function>( diff --git a/src/framework/op_info.h b/src/framework/op_info.h index fe55594d20d8e8a4b6c13d67078ed1a862170664..32f453ab97d6bfb30cffea4d537c9394696e3212 100644 --- a/src/framework/op_info.h +++ b/src/framework/op_info.h @@ -18,8 +18,10 @@ SOFTWARE. #pragma once +#include +#include "common/log.h" #include "common/type_define.h" -#include "framework.pb.h" +#include "framework/framework.pb.h" namespace paddle_mobile { namespace framework { @@ -45,11 +47,12 @@ template class OpInfoMap { public: static OpInfoMap &Instance() { + LOG(paddle_mobile::kLOG_DEBUG1) << " TODO: fix bug"; if (g_op_info_map == nullptr) { g_op_info_map = new OpInfoMap(); } return *g_op_info_map; - }; + } bool Has(const std::string &op_type) const { return map_.find(op_type) != map_.end(); diff --git a/src/framework/op_registry.h b/src/framework/op_registry.h index 8a2aacdde48950d9284bf533bc9d92b30dc66792..4ba60ad4d481d92a4613bd59149bfc5f5c8266ae 100644 --- a/src/framework/op_registry.h +++ b/src/framework/op_registry.h @@ -17,3 +17,121 @@ SOFTWARE. ==============================================================================*/ #pragma once + +#include +#include +#include "common/log.h" +#include "common/type_define.h" +#include "framework/op_info.h" +#include "framework/operator.h" + +namespace paddle_mobile { +namespace framework { + +class Registrar { + public: + void Touch() {} +}; + +template +class OperatorRegistrarRecursive; + +template +struct OperatorRegistrar : public Registrar { + explicit OperatorRegistrar(const std::string& op_type) { + if (OpInfoMap::Instance().Has(op_type)) { + LOG(paddle_mobile::kLOG_DEBUG1) + << op_type << " is registered more than once."; + return; + } + if (sizeof...(ARGS) == 0) { + LOG(paddle_mobile::kLOG_DEBUG1) + << "OperatorRegistrar should be invoked at least by OpClass"; + return; + } + OpInfo info; + OperatorRegistrarRecursive(op_type, &info); + OpInfoMap::Instance().Insert(op_type, info); + } +}; + +template +struct OpInfoFiller { + void operator()(const std::string& op_type, OpInfo* info) const { + info->creator_ = [](const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, + const AttributeMap& attrs, + std::shared_ptr scope) { + return new T(type, inputs, outputs, attrs, scope); + }; + } +}; + +template +class OperatorRegistrarRecursive { + public: + using T = typename std::tuple_element>::type; + OperatorRegistrarRecursive(const std::string& op_type, OpInfo* info) { + OpInfoFiller fill; + fill(op_type, info); + constexpr auto size = sizeof...(ARGS); + OperatorRegistrarRecursive reg( + op_type, info); + (void)(reg); + } +}; + +template +class OperatorRegistrarRecursive { + public: + OperatorRegistrarRecursive(const std::string& op_type, OpInfo* info) {} +}; + +template +class OpRegistry { + public: + static std::shared_ptr> CreateOp( + const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap attrs, + std::shared_ptr scope) { + LOG(paddle_mobile::kLOG_DEBUG1) << " type: " + << type; + LOG(paddle_mobile::kLOG_DEBUG1) << " input size: " + << inputs.size(); + LOG(paddle_mobile::kLOG_DEBUG1) << " output size: " + << outputs.size(); + LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: " + << attrs.size(); + LOG(paddle_mobile::kLOG_DEBUG1) << " OpInfoMap size: " + << OpInfoMap::Instance().map().size(); + LOG(paddle_mobile::kLOG_DEBUG1) << " has type: " + << type + << " " + << OpInfoMap::Instance().Has(type); + auto& info = OpInfoMap::Instance().Get(type); + auto op = info.Creator()(type, inputs, outputs, attrs, scope); + return std::shared_ptr>(op); + } +}; + +#define REGISTER_OPERATOR(op_type, op_class) \ + template \ + class _OpClass_##op_type##_ : public op_class { \ + public: \ + DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ + }; \ + static paddle_mobile::framework::OperatorRegistrar< \ + paddle_mobile::CPU, _OpClass_##op_type##_> \ + __op_registrar_##op_type##__(#op_type); \ + int TouchOpRegistrar_##op_type() { \ + __op_registrar_##op_type##__.Touch(); \ + return 0; \ + } + +#define USE_OP(op_type) \ + extern int TouchOpRegistrar_##op_type(); \ + static int use_op_itself_##op_type##_ __attribute__((unused)) = \ + TouchOpRegistrar_##op_type() + +} // namespace framework +} // namespace paddle_mobile diff --git a/src/framework/operator.h b/src/framework/operator.h index 6250abf9138c7d71f0a29f7530097a531a4e6072..71de21153500fcd22847977f1cbeca804baf69b6 100644 --- a/src/framework/operator.h +++ b/src/framework/operator.h @@ -19,18 +19,20 @@ SOFTWARE. #pragma once #include - -#include "attribute.h" -#include "block_desc.h" +#include +#include +#include +#include "framework/attribute.h" +#include "framework/block_desc.h" #include "common/type_define.h" #include "common/types.h" #include "common/variant.h" -#include "op_info.h" -#include "op_kernel_type.h" -#include "paddle_mobile_object.h" -#include "scope.h" -#include "tensor.h" -#include "variable.h" +#include "framework/op_info.h" +#include "framework/op_kernel_type.h" +#include "framework/paddle_mobile_object.h" +#include "framework/scope.h" +#include "framework/tensor.h" +#include "framework/variable.h" namespace paddle_mobile { namespace framework { @@ -97,5 +99,12 @@ class OpKernelBase : PaddleMobileObject { virtual ~OpKernelBase() = default; }; +#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \ + cls(const std::string &type, const ::paddle_mobile::VariableNameMap &inputs, \ + const ::paddle_mobile::VariableNameMap &outputs, \ + const ::paddle_mobile::framework::AttributeMap &attrs, \ + std::shared_ptr<::paddle_mobile::framework::Scope> scope) \ + : parent_cls(type, inputs, outputs, attrs, scope) {} + } // namespace framework } // namespace paddle_mobile diff --git a/src/operators/conv_op.cpp b/src/operators/conv_op.cpp index 248ed4798b0b7bdb8236698d37d55207278b4cde..237d3ad62e18d630f5b0a5b802b671611302422d 100644 --- a/src/operators/conv_op.cpp +++ b/src/operators/conv_op.cpp @@ -16,9 +16,11 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ==============================================================================*/ -#include "conv_op.h" +#include +#include "operators/conv_op.h" #include "framework/data_type.h" #include "framework/op_proto_maker.h" +#include "framework/op_registry.h" namespace paddle_mobile { namespace operators { @@ -73,3 +75,7 @@ template class ConvOp; } // namespace operators } // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +USE_OP(conv2d); +REGISTER_OPERATOR(conv2d, ops::ConvOp);