未验证 提交 9fd67ffe 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Fix single dtype register errror (#39506)

* fix single dtype reg errror

* fix windows failed
上级 b81358d1
......@@ -219,18 +219,17 @@ struct KernelRegistrar {
*
* Note: `2TA` means `2 template argument`
*/
#define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \
_PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#define PT_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \
PT_EXPAND(_PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, __VA_ARGS__))
#ifndef _WIN32
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \
kernel_name, backend, layout, meta_kernel_fn, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
......@@ -239,7 +238,6 @@ struct KernelRegistrar {
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
......@@ -257,34 +255,30 @@ struct KernelRegistrar {
* And msvc can work without template instantiation
*/
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
kernel_name, backend, layout, meta_kernel_fn, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
__VA_ARGS__)); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
__VA_ARGS__)
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \
_PT_KERNEL_INSTANTIATION( \
PT_NARGS(__VA_ARGS__), meta_kernel_fn, backend, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, backend, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
......@@ -343,38 +337,35 @@ struct KernelRegistrar {
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
#define PT_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, ...) \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(__VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// clang-format off
/* The =pre-commit always treats this macro into the wrong format,
and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_KERNEL_REGISTRAR_INIT(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
...) \
PT_EXPAND(PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// clang-format on
......@@ -384,8 +375,7 @@ struct KernelRegistrar {
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
cpp_dtype) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册