// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/utils/all.h" using LiteType = paddle::lite::Type; namespace paddle { namespace lite { using KernelFunc = std::function; using KernelFuncCreator = std::function()>; class LiteOpRegistry final : public Factory> { public: static LiteOpRegistry &Global() { static auto *x = new LiteOpRegistry; return *x; } private: LiteOpRegistry() = default; }; template class OpLiteRegistor : public Registor { public: explicit OpLiteRegistor(const std::string &op_type) : Registor([&] { LiteOpRegistry::Global().Register( op_type, [op_type]() -> std::unique_ptr { return std::unique_ptr(new OpClass(op_type)); }); }) {} }; template using KernelRegistryForTarget = Factory, std::unique_ptr>; class KernelRegistry final { public: using any_kernel_registor_t = variant *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget * // >; KernelRegistry(); static KernelRegistry &Global(); template void Register(const std::string &name, typename KernelRegistryForTarget::creator_t &&creator) { VLOG(3) << "register for " << TargetToStr(Target) << ":" << PrecisionToStr(Precision) << "//" << GetKernelOffset(); using kernel_registor_t = KernelRegistryForTarget; auto &varient = registries_[GetKernelOffset()]; varient.template get()->Register(name, std::move(creator)); } template std::list> Create(const std::string &op_type) { using kernel_registor_t = KernelRegistryForTarget; return registries_[GetKernelOffset()] .template get() ->Creates(op_type); } std::list> Create(const std::string &op_type, TargetType target, PrecisionType precision, DataLayoutType layout); // Get a kernel registry offset in all the registries. template static int GetKernelOffset() { CHECK_LT(static_cast(Target), static_cast(TARGET(NUM))); CHECK_LT(static_cast(Precision), static_cast(PRECISION(NUM))); CHECK_LT(static_cast(Layout), static_cast(DATALAYOUT(NUM))); return static_cast(Target) * static_cast(PRECISION(NUM)) * static_cast(DATALAYOUT(NUM)) + // static_cast(Precision) * static_cast(DATALAYOUT(NUM)) + // static_cast(Layout); } std::string DebugString() const { std::stringstream ss; ss << "KernelCreator:" << std::endl; ss << registries_[GetKernelOffset()] .get *>() ->DebugString(); ss << std::endl; return ss.str(); } private: mutable std::array(TARGET(NUM)) * static_cast(PRECISION(NUM)) * static_cast(DATALAYOUT(NUM))> registries_; }; template class KernelRegistor : public lite::Registor { public: KernelRegistor(const std::string &op_type, const std::string &alias) : Registor([=] { VLOG(3) << "Register kernel " << op_type << " for " << TargetToStr(target) << " " << PrecisionToStr(precision) << " " << DataLayoutToStr(layout) << " alias " << alias; KernelRegistry::Global().Register( op_type, [=]() -> std::unique_ptr { std::unique_ptr x(new KernelType); x->set_op_type(op_type); x->set_alias(alias); return x; }); }) {} }; } // namespace lite } // namespace paddle // Operator registry #define LITE_OP_REGISTER_INSTANCE(op_type__) op_type__##__registry__instance__ #define LITE_OP_REGISTER_FAKE(op_type__) op_type__##__registry__ #define REGISTER_LITE_OP(op_type__, OpClass) \ static paddle::lite::OpLiteRegistor LITE_OP_REGISTER_INSTANCE( \ op_type__)(#op_type__); \ int touch_op_##op_type__() { \ return LITE_OP_REGISTER_INSTANCE(op_type__).Touch(); \ } #define USE_LITE_OP(op_type__) \ extern int touch_op_##op_type__(); \ int LITE_OP_REGISTER_FAKE(op_type__) __attribute__((unused)) = \ touch_op_##op_type__(); // Kernel registry #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \ op_type__##__##target__##__##precision__##__registor__ #define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \ layout__, alias__) \ op_type__##__##target__##__##precision__##__registor__instance__##alias__ #define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \ LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__) #define REGISTER_LITE_KERNEL(op_type__, target__, precision__, layout__, \ KernelClass, alias__) \ static paddle::lite::KernelRegistor \ LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \ layout__, alias__)(#op_type__, #alias__); \ static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \ layout__, alias__); \ int touch_##op_type__##target__##precision__##layout__##alias__() { \ LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \ .Touch(); \ return 0; \ } \ static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \ layout__, alias__) \ __attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \ TARGET(target__), PRECISION(precision__), DATALAYOUT(layout__)>( \ #op_type__ "/" #alias__) #define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \ extern int touch_##op_type__##target__##precision__##layout__##alias__(); \ int op_type__##target__##precision__##layout__##alias__ \ __attribute__((unused)) = \ touch_##op_type__##target__##precision__##layout__##alias__(); #define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, \ alias__) \ op_type__##target__##precision__##layout__##alias__ #define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, layout__, \ alias__) \ op_type__##target__##precision__##layout__##alias__##param_register