未验证 提交 9fd67ffe 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Fix single dtype register errror (#39506)

* fix single dtype reg errror

* fix windows failed
上级 b81358d1
...@@ -219,18 +219,17 @@ struct KernelRegistrar { ...@@ -219,18 +219,17 @@ struct KernelRegistrar {
* *
* Note: `2TA` means `2 template argument` * Note: `2TA` means `2 template argument`
*/ */
#define PT_REGISTER_KERNEL( \ #define PT_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ "PT_REGISTER_KERNEL must be called in global namespace."); \
"PT_REGISTER_KERNEL must be called in global namespace."); \ PT_EXPAND(_PT_REGISTER_2TA_KERNEL( \
_PT_REGISTER_2TA_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, __VA_ARGS__))
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32 #ifndef _WIN32
#define _PT_REGISTER_2TA_KERNEL( \ #define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, meta_kernel_fn, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \ PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \ ::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \ PT_KERNEL_REGISTRAR_INIT( \
...@@ -239,7 +238,6 @@ struct KernelRegistrar { ...@@ -239,7 +238,6 @@ struct KernelRegistrar {
layout, \ layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \ __VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel) ::pten::Kernel* kernel)
...@@ -257,34 +255,30 @@ struct KernelRegistrar { ...@@ -257,34 +255,30 @@ struct KernelRegistrar {
* And msvc can work without template instantiation * And msvc can work without template instantiation
*/ */
#define _PT_REGISTER_2TA_KERNEL( \ #define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, meta_kernel_fn, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \ ::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \ PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \ kernel_name, \
backend, \ backend, \
layout, \ layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ __VA_ARGS__)); \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel) ::pten::Kernel* kernel)
#endif #endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \ #define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \ _PT_KERNEL_INSTANTIATION( \
meta_kernel_fn, \ PT_NARGS(__VA_ARGS__), meta_kernel_fn, backend, __VA_ARGS__)
backend, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__) (meta_kernel_fn, backend, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \ template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context> meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \ template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
...@@ -343,38 +337,35 @@ struct KernelRegistrar { ...@@ -343,38 +337,35 @@ struct KernelRegistrar {
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \ meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT( \ #define PT_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, args_def_fn, meta_kernel_fn, ...) \
_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(__VA_ARGS__), \
kernel_name, \ kernel_name, \
backend, \ backend, \
layout, \ layout, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ __VA_ARGS__))
__VA_ARGS__)
// clang-format off // clang-format off
/* The =pre-commit always treats this macro into the wrong format, /* The =pre-commit always treats this macro into the wrong format,
and multi-line macros cannot be skipped with NOLINT.*/ and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT(N, \ #define _PT_KERNEL_REGISTRAR_INIT(N, \
kernel_name, \ kernel_name, \
backend, \ backend, \
layout, \ layout, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ ...) \
...) \ PT_EXPAND(PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ kernel_name, \
kernel_name, \ backend, \
backend, \ layout, \
layout, \ PT_ID, \
PT_ID, \ args_def_fn, \
args_def_fn, \ meta_kernel_fn, \
meta_kernel_fn, \ __VA_ARGS__))
cpp_dtype, \
__VA_ARGS__)
// clang-format on // clang-format on
...@@ -384,8 +375,7 @@ struct KernelRegistrar { ...@@ -384,8 +375,7 @@ struct KernelRegistrar {
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype) \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册