// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include #include "lite/api/paddle_lite_factory_helper.h" #include "lite/core/kernel.h" #include "lite/core/op_lite.h" #include "lite/core/target_wrapper.h" #include "lite/utils/all.h" #include "lite/utils/macros.h" using LiteType = paddle::lite::Type; class OpKernelInfoCollector { public: static OpKernelInfoCollector &Global() { static auto *x = new OpKernelInfoCollector; return *x; } void AddOp2path(const std::string &op_name, const std::string &op_path) { size_t index = op_path.find_last_of('/'); if (index != std::string::npos) { op2path_.insert(std::pair( op_name, op_path.substr(index + 1))); } } void AddKernel2path(const std::string &kernel_name, const std::string &kernel_path) { size_t index = kernel_path.find_last_of('/'); if (index != std::string::npos) { kernel2path_.insert(std::pair( kernel_name, kernel_path.substr(index + 1))); } } void SetKernel2path( const std::map &kernel2path_map) { kernel2path_ = kernel2path_map; } const std::map &GetOp2PathDict() { return op2path_; } const std::map &GetKernel2PathDict() { return kernel2path_; } private: std::map op2path_; std::map kernel2path_; }; namespace paddle { namespace lite { const std::map &GetOp2PathDict(); using KernelFunc = std::function; using KernelFuncCreator = std::function()>; class LiteOpRegistry final : public Factory> { public: static LiteOpRegistry &Global() { static auto *x = new LiteOpRegistry; return *x; } private: LiteOpRegistry() = default; }; template class OpLiteRegistor : public Registor { public: explicit OpLiteRegistor(const std::string &op_type) : Registor([&] { LiteOpRegistry::Global().Register( op_type, [op_type]() -> std::unique_ptr { return std::unique_ptr(new OpClass(op_type)); }); }) {} }; template using KernelRegistryForTarget = Factory, std::unique_ptr>; class KernelRegistry final { public: using any_kernel_registor_t = variant *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget * // >; KernelRegistry(); static KernelRegistry &Global(); template void Register( const std::string &name, typename KernelRegistryForTarget::creator_t &&creator) { using kernel_registor_t = KernelRegistryForTarget; auto &varient = registries_[std::make_tuple(Target, Precision, Layout)]; auto *reg = varient.template get(); CHECK(reg) << "Can not be empty of " << name; reg->Register(name, std::move(creator)); #ifdef LITE_ON_MODEL_OPTIMIZE_TOOL kernel_info_map_[name].push_back( std::make_tuple(Target, Precision, Layout)); #endif // LITE_ON_MODEL_OPTIMIZE_TOOL } template std::list> Create(const std::string &op_type) { using kernel_registor_t = KernelRegistryForTarget; std::list> kernel_list; std::tuple temp_tuple( Target, Precision, Layout); if (registries_[temp_tuple].valid()) { kernel_list = registries_[temp_tuple].template get()->Creates( op_type); } return kernel_list; } std::list> Create(const std::string &op_type, TargetType target, PrecisionType precision, DataLayoutType layout); std::string DebugString() const { #ifndef LITE_ON_MODEL_OPTIMIZE_TOOL return "No more debug info"; #else // LITE_ON_MODEL_OPTIMIZE_TOOL STL::stringstream ss; ss << "\n"; ss << "Count of kernel kinds: "; int count = 0; for (auto &item : kernel_info_map_) { count += item.second.size(); } ss << count << "\n"; ss << "Count of registered kernels: " << kernel_info_map_.size() << "\n"; for (auto &item : kernel_info_map_) { ss << "op: " << item.first << "\n"; for (auto &kernel : item.second) { ss << " - (" << TargetToStr(std::get<0>(kernel)) << ","; ss << PrecisionToStr(std::get<1>(kernel)) << ","; ss << DataLayoutToStr(std::get<2>(kernel)); ss << ")"; ss << "\n"; } } return ss.str(); #endif // LITE_ON_MODEL_OPTIMIZE_TOOL } private: mutable std::map, any_kernel_registor_t> registries_; #ifndef LITE_ON_TINY_PUBLISH mutable std::map< std::string, std::vector>> kernel_info_map_; #endif }; template class KernelRegistor : public lite::Registor { public: KernelRegistor(const std::string &op_type, const std::string &alias) : Registor([=] { KernelRegistry::Global().Register( op_type, [=]() -> std::unique_ptr { std::unique_ptr x(new KernelType); x->set_op_type(op_type); x->set_alias(alias); return x; }); }) {} }; } // namespace lite } // namespace paddle // Operator registry #define LITE_OP_REGISTER_INSTANCE(op_type__) op_type__##__registry__instance__ #define REGISTER_LITE_OP(op_type__, OpClass) \ static paddle::lite::OpLiteRegistor LITE_OP_REGISTER_INSTANCE( \ op_type__)(#op_type__); \ int touch_op_##op_type__() { \ OpKernelInfoCollector::Global().AddOp2path(#op_type__, __FILE__); \ return LITE_OP_REGISTER_INSTANCE(op_type__).Touch(); \ } // Kernel registry #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \ op_type__##__##target__##__##precision__##__registor__ #define LITE_KERNEL_REGISTER_INSTANCE( \ op_type__, target__, precision__, layout__, alias__) \ op_type__##__##target__##__##precision__##__##layout__##registor__instance__##alias__ // NOLINT #define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \ LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__) #define REGISTER_LITE_KERNEL( \ op_type__, target__, precision__, layout__, KernelClass, alias__) \ static paddle::lite::KernelRegistor \ LITE_KERNEL_REGISTER_INSTANCE( \ op_type__, target__, precision__, layout__, alias__)(#op_type__, \ #alias__); \ static KernelClass LITE_KERNEL_INSTANCE( \ op_type__, target__, precision__, layout__, alias__); \ int touch_##op_type__##target__##precision__##layout__##alias__() { \ OpKernelInfoCollector::Global().AddKernel2path( \ #op_type__ "," #target__ "," #precision__ "," #layout__ "," #alias__, \ __FILE__); \ LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \ .Touch(); \ return 0; \ } \ static bool LITE_KERNEL_PARAM_INSTANCE( \ op_type__, target__, precision__, layout__, alias__) UNUSED = \ paddle::lite::ParamTypeRegistry::NewInstance( \ #op_type__ "/" #alias__) #define LITE_KERNEL_INSTANCE( \ op_type__, target__, precision__, layout__, alias__) \ op_type__##target__##precision__##layout__##alias__ #define LITE_KERNEL_PARAM_INSTANCE( \ op_type__, target__, precision__, layout__, alias__) \ op_type__##target__##precision__##layout__##alias__##param_register