// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { namespace lite { using KernelFunc = std::function; using KernelFuncCreator = std::function()>; class LiteOpRegistry final : public Factory { public: static LiteOpRegistry &Global() { static auto *x = new LiteOpRegistry; return *x; } private: LiteOpRegistry() = default; }; template class OpLiteRegistor : public Registor { public: OpLiteRegistor(const std::string &op_type) : Registor([&] { LiteOpRegistry::Global().Register( op_type, []() -> std::unique_ptr { return std::unique_ptr(new OpClass); }); }) {} }; template using KernelRegistryForTarget = Factory>; class KernelRegistry final { public: using any_kernel_registor_t = variant *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget *, // KernelRegistryForTarget * // >; KernelRegistry() { /* using kernel_target_t = KernelRegistryForTarget; registries_[0].set( &KernelRegistryForTarget::Global()); */ #define INIT_FOR(target__, precision__) \ registries_[KernelRegistry::GetKernelOffset()] \ .set \ *>(&KernelRegistryForTarget::Global()); // Currently, just register 2 kernel targets. INIT_FOR(kHost, kFloat); #undef INIT_FOR } static KernelRegistry &Global() { static auto *x = new KernelRegistry; return *x; } template void Register(const std::string &name, typename KernelRegistryForTarget::creator_t &&creator) { using kernel_registor_t = KernelRegistryForTarget; registries_[GetKernelOffset()] .template get() ->Register(name, std::move(creator)); } template std::unique_ptr Create(const std::string &op_type) { using kernel_registor_t = KernelRegistryForTarget; return registries_[GetKernelOffset()] .template get() ->Create(op_type); } std::unique_ptr Create(const std::string &op_type, TargetType target, PrecisionType precision) { #define CREATE_KERNEL(target__) \ switch (precision) { \ case PRECISION(kFloat): \ return Create(op_type); \ default: \ CHECK(false) << "not supported kernel place yet"; \ } switch (target) { case TARGET(kHost): { CREATE_KERNEL(kHost); } break; case TARGET(kX86): { CREATE_KERNEL(kX86); } break; case TARGET(kCUDA): { CREATE_KERNEL(kCUDA); } break; default: CHECK(false) << "not supported kernel place"; } #undef CREATE_KERNEL } // Get a kernel registry offset in all the registries. template static constexpr int GetKernelOffset() { return kNumTargets * static_cast(Target) + static_cast(Precision); } private: std::array registries_; }; template class KernelRegistor : public lite::Registor { public: KernelRegistor(const std::string op_type) : Registor([&] { KernelRegistry::Global().Register( op_type, [&]() -> std::unique_ptr { return std::unique_ptr(new KernelType); }); }) {} }; } // namespace lite } // namespace paddle // Operator registry #define LITE_OP_REGISTER_INSTANCE(op_type__) op_type__##__registry__instance__ #define LITE_OP_REGISTER_FAKE(op_type__) op_type__##__registry__ #define REGISTER_LITE_OP(op_type__, OpClass) \ static paddle::lite::OpLiteRegistor LITE_OP_REGISTER_INSTANCE( \ op_type__)(#op_type__); #define USE_LITE_OP(op_type__) \ int LITE_OP_REGISTER_FAKE(op_type__)((unused)) = \ LITE_OP_REGISTER_INSTANCE(op_type__).Touch(); // Kernel registry #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \ op_type__##target__##precision__##__registor__ #define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \ op_type__##target__##precision__##__registor__instance__ #define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \ LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__)##__fake__ #define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass) \ static paddle::lite::KernelRegistor \ LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, \ precision__)(#op_type__); #define USE_LITE_KERNEL(op_type__, target__, precision__) \ int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__)((unused)) = \ LITE_KERNEL_REGISTER(op_type__, target__, precision__).Touch();