diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index 25523204015622f4d5f0b65b3e67d17fc5f43c5a..927c36e9e8f432f2d7f068c6f76142a2509faced 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -219,18 +219,17 @@ struct KernelRegistrar { * * Note: `2TA` means `2 template argument` */ -#define PT_REGISTER_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ - "PT_REGISTER_KERNEL must be called in global namespace."); \ - _PT_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) +#define PT_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PT_REGISTER_KERNEL must be called in global namespace."); \ + PT_EXPAND(_PT_REGISTER_2TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, __VA_ARGS__)) #ifndef _WIN32 #define _PT_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \ + kernel_name, backend, layout, meta_kernel_fn, ...) \ + PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ ::pten::Kernel*); \ PT_KERNEL_REGISTRAR_INIT( \ @@ -239,7 +238,6 @@ struct KernelRegistrar { layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ meta_kernel_fn, \ - cpp_dtype, \ __VA_ARGS__); \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ ::pten::Kernel* kernel) @@ -257,34 +255,30 @@ struct KernelRegistrar { * And msvc can work without template instantiation */ #define _PT_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + kernel_name, backend, layout, meta_kernel_fn, ...) \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ ::pten::Kernel*); \ - PT_KERNEL_REGISTRAR_INIT( \ + PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \ kernel_name, \ backend, \ layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__); \ + __VA_ARGS__)); \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ ::pten::Kernel* kernel) #endif -#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \ - _PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \ - meta_kernel_fn, \ - backend, \ - cpp_dtype, \ - __VA_ARGS__) +#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \ + _PT_KERNEL_INSTANTIATION( \ + PT_NARGS(__VA_ARGS__), meta_kernel_fn, backend, __VA_ARGS__) -#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \ - PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ - (meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__) +#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, ...) \ + PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ + (meta_kernel_fn, backend, __VA_ARGS__) -#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ +#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype) \ + template decltype(meta_kernel_fn) \ meta_kernel_fn #define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \ template decltype(meta_kernel_fn) \ @@ -343,38 +337,35 @@ struct KernelRegistrar { meta_kernel_fn; \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__)) -#define PT_KERNEL_REGISTRAR_INIT( \ - kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ - _PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \ - kernel_name, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__) +#define PT_KERNEL_REGISTRAR_INIT( \ + kernel_name, backend, layout, args_def_fn, meta_kernel_fn, ...) \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(__VA_ARGS__), \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) // clang-format off /* The =pre-commit always treats this macro into the wrong format, and multi-line macros cannot be skipped with NOLINT.*/ -#define _PT_KERNEL_REGISTRAR_INIT(N, \ - kernel_name, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ - kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__) +#define _PT_KERNEL_REGISTRAR_INIT(N, \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + ...) \ + PT_EXPAND(PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ + kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) // clang-format on @@ -384,8 +375,7 @@ struct KernelRegistrar { registrar_id, \ args_def_fn, \ meta_kernel_fn, \ - cpp_dtype, \ - ...) \ + cpp_dtype) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ #kernel_name, \