/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include "paddle/fluid/operators/jitkernels/kernel_base.h" #include "paddle/fluid/operators/jitkernels/kernel_pool.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/variant.h" // for UNUSED namespace paddle { namespace operators { namespace jitkernels { // make_unique is supported since c++14 template inline std::unique_ptr make_unique(Args&&... args) { static_assert(!std::is_array::value, "T must not be array"); return std::unique_ptr(new T(std::forward(args)...)); } template struct JitKernelRegistrarFunctor; template struct JitKernelRegistrarFunctor { void operator()(KernelType kt) const {} }; template struct JitKernelRegistrarFunctor { using KERNEL_IMPL_TYPE = typename std::tuple_element>::type; void operator()(KernelType kt) const { KernelKey kkey(kt, PlaceType()); Pool().Instance().Insert(kkey, std::move(make_unique())); constexpr auto size = std::tuple_size>::value; JitKernelRegistrarFunctor func; func(kt); } }; template class JitKernelRegistrar { public: explicit JitKernelRegistrar(KernelType kt) { JitKernelRegistrarFunctor func; func(kt); } void Touch() {} }; #define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \ struct __test_global_namespace_##uniq_name##__ {}; \ static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ __test_global_namespace_##uniq_name##__>::value, \ msg) // Refer always on CPUPlace #define REGISTER_JITKERNEL_REFER(kernel_type, ...) \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ __reg_jitkernel_##kernel_type##_refer_CPUPlace, \ "REGISTER_KERNEL_REFER must be called in global namespace"); \ static ::paddle::operators::jitkernels::JitKernelRegistrar< \ ::paddle::operators::jitkernels::ReferKernelPool, \ ::paddle::platform::CPUPlace, __VA_ARGS__> \ __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \ ::paddle::operators::jitkernels::KernelType::kernel_type); \ int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \ __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \ return 0; \ } // kernel_type: should be in paddle::operators::jitkernels::KernelType // place_type: should be one of CPUPlace and GPUPlace in paddle::platform #define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ __reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \ "REGISTER_KERNEL_MORE must be called in global namespace"); \ extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \ UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ static ::paddle::operators::jitkernels::JitKernelRegistrar< \ ::paddle::operators::jitkernels::KernelPool, \ ::paddle::platform::place_type, __VA_ARGS__> \ __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \ ::paddle::operators::jitkernels::KernelType::kernel_type); \ int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \ __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \ .Touch(); \ return 0; \ } #define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \ REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__) #define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \ REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__) // REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode); #define USE_JITKERNEL_REFER(kernel_type) \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ __reg_jitkernel_##kernel_type##_refer_CPUPlace_, \ "USE_JITKERNEL_REFER must be called in global namespace"); \ extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \ TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() #define USE_KERNEL_MORE(kernel_type, impl_type, place_type) \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_, \ "USE_JITKERNEL_MORE must be called in global namespace"); \ extern int \ TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \ static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \ UNUSED = \ TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() #define USE_JITKERNEL_MORE(kernel_type, impl_type) \ USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace) } // namespace jitkernels } // namespace operators } // namespace paddle