From fc0a50aa725de28f87640d3593ee23d062d49eef Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 23 Dec 2021 23:11:46 -0600 Subject: [PATCH] add register general kernel marco (#38409) --- cmake/pten_kernel.cmake | 5 +-- paddle/pten/core/kernel_registry.h | 56 ++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index 3934a828c2..8ec81dfa5b 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -# call kernel_declare need to make sure the target of input is exists +# call kernel_declare need to make sure whether the target of input exists function(kernel_declare TARGET_LIST) foreach(kernel_path ${TARGET_LIST}) file(READ ${kernel_path} kernel_impl) # TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL # NOTE(chenweihang): now we don't recommend to use digit in kernel name - string(REGEX MATCH "PT_REGISTER_CTX_KERNEL\\([ \t\r\n]*[a-z_]*," first_registry "${kernel_impl}") + string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z_]*," first_registry "${kernel_impl}") if (NOT first_registry STREQUAL "") # parse the first kernel name string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}") + string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}") string(REPLACE "," "" kernel_name "${kernel_name}") string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}") # append kernel declare into declarations.h diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index a33b13dac2..bd4687c6e7 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -749,6 +749,8 @@ struct KernelRegistrar { * layout, so the layout also need to be a part of symbol var name. If developer * register 2 kernel with same name, backend, layout and diff dtype, he should * use another register marco PT_REGISTER_KERNEL. + * + * TODO(chenweihang): remove this marco later */ #define PT_REGISTER_NO_TEMPLATE_KERNEL( \ kernel_name, backend, layout, kernel_fn, dtype) \ @@ -772,6 +774,60 @@ struct KernelRegistrar { void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ ::pten::Kernel* kernel) +/** PT_REGISTER_GENERAL_KERNEL + * + * Basic Kernel register marco, used to register a instantiated kernel function + * with one template argument. + */ + +#define PT_REGISTER_GENERAL_KERNEL( \ + kernel_name, backend, layout, kernel_fn, dtype) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \ + _PT_REGISTER_GENERAL_KERNEL(kernel_name, backend, layout, kernel_fn, dtype) + +#ifndef _WIN32 +#define _PT_REGISTER_GENERAL_KERNEL( \ + kernel_name, backend, layout, kernel_fn, dtype) \ + template decltype(kernel_fn) kernel_fn; \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + static const ::pten::KernelRegistrar \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::pten::KernelArgsParseFunctor::Parse, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + PT_KERNEL(kernel_fn), \ + PT_VARIADIC_KERNEL(kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \ + return 0; \ + } \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) +#else +#define _PT_REGISTER_GENERAL_KERNEL( \ + kernel_name, backend, layout, kernel_fn, dtype) \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + static const ::pten::KernelRegistrar \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::pten::KernelArgsParseFunctor::Parse, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + PT_KERNEL(kernel_fn), \ + PT_VARIADIC_KERNEL(kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \ + return 0; \ + } \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) +#endif + /** PT_REGISTER_CTX_KERNEL * * Used for kernel registration with device context and data type as -- GitLab