未验证 提交 a725c9a5 编写于 作者: Z zhangyuqin1998 提交者: GitHub

Kernel registrar (#52079)

* add kernel register macro for all backend

* fix msvc bug

* fix

---------
Co-authored-by: Nzhangyuqin1998 <2368719379@qq.com>
上级 a1832474
......@@ -78,7 +78,7 @@ function(kernel_declare TARGET_LIST)
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z]*,[ \\\t\r\n]*[A-Z_]*"
"(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
first_registry
"${kernel_impl}")
if(NOT first_registry STREQUAL "")
......@@ -96,8 +96,18 @@ function(kernel_declare TARGET_LIST)
continue()
endif()
endif()
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
is_all_backend
"${first_registry}")
# parse the registerd kernel message
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_msg "${first_registry}")
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(" "" kernel_msg
"${first_registry}")
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_msg "${kernel_msg}")
string(REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_msg
"${kernel_msg}")
string(REPLACE "," ";" kernel_msg "${kernel_msg}")
......@@ -105,9 +115,13 @@ function(kernel_declare TARGET_LIST)
string(REGEX REPLACE "//cuda_only" "" kernel_msg "${kernel_msg}")
list(GET kernel_msg 0 kernel_name)
list(GET kernel_msg 1 kernel_backend)
list(GET kernel_msg 2 kernel_layout)
if(NOT is_all_backend STREQUAL "")
list(GET kernel_msg 1 kernel_layout)
set(kernel_backend "CPU")
else()
list(GET kernel_msg 1 kernel_backend)
list(GET kernel_msg 2 kernel_layout)
endif()
# append kernel declare into declarations.h
file(
APPEND ${kernel_declare_file}
......
......@@ -1388,6 +1388,132 @@ struct KernelRegistrar {
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#endif
/** PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE
*
* Used to register a instantiated kernel function
* for all backend with one template argument.
*/
#define PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( \
kernel_name, layout, meta_kernel_fn) \
_PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(::phi::RegType::INNER, \
kernel_name, \
layout, \
meta_kernel_fn, \
BACKEND_LIST)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define _DEVICE GPU,
#elif defined(PADDLE_WITH_XPU)
#define _DEVICE XPU,
#else
#define _DEVICE
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#define _CUSTOM Custom,
#else
#define _CUSTOM
#endif
#define BACKEND_LIST _DEVICE _CUSTOM CPU
#define _PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( \
reg_type, kernel_name, layout, meta_kernel_fn, ...) \
PD_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PD_REGISTER_nt_kernel_ns_check_##kernel_name##_##layout, \
"PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE must be called in global " \
"namespace."); \
PD_EXPAND(__PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(reg_type, \
kernel_name, \
layout, \
meta_kernel_fn, \
PD_NARGS(__VA_ARGS__), \
__VA_ARGS__))
#define __PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( \
reg_type, kernel_name, layout, meta_kernel_fn, N, ...) \
static void __PD_KERNEL_args_def_FN_##kernel_name##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
PD_EXPAND( \
PD_CONCATENATE(_PD_FOR_ALL_BACKEND_DTYPE_, N)( \
reg_type, \
kernel_name, \
layout, \
meta_kernel_fn, \
__PD_KERNEL_args_def_FN_##kernel_name##_##layout, \
__VA_ARGS__) void \
__PD_KERNEL_args_def_FN_##kernel_name##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel))
#ifndef _WIN32
#define ___PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( \
reg_type, kernel_name, backend, layout, kernel_fn, args_def_fn) \
template decltype(kernel_fn) kernel_fn; \
static const ::phi::KernelRegistrar \
__reg_phi_kernel_##kernel_name##_##backend##_##layout( \
reg_type, \
#kernel_name, \
#backend, \
DATA_LAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&args_def_fn, \
PHI_KERNEL(kernel_fn), \
PHI_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#else
#define ___PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( \
reg_type, kernel_name, backend, layout, kernel_fn, args_def_fn) \
static const ::phi::KernelRegistrar \
__reg_phi_kernel_##kernel_name##_##backend##_##layout( \
reg_type, \
#kernel_name, \
#backend, \
DATA_LAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&args_def_fn, \
PHI_KERNEL(kernel_fn), \
PHI_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#endif
#define _PD_FOR_ALL_BACKEND_DTYPE_1( \
reg_type, kernel_name, layout, meta_kernel_fn, args_def_fn, backend) \
___PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( \
reg_type, \
kernel_name, \
backend, \
layout, \
meta_kernel_fn<::phi::backend##Context>, \
args_def_fn)
#define _PD_FOR_ALL_BACKEND_DTYPE_2( \
reg_type, kernel_name, layout, meta_kernel_fn, args_def_fn, backend, ...) \
___PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( \
reg_type, \
kernel_name, \
backend, \
layout, \
meta_kernel_fn<::phi::backend##Context>, \
args_def_fn) \
PD_EXPAND(_PD_FOR_ALL_BACKEND_DTYPE_1(reg_type, \
kernel_name, \
layout, \
meta_kernel_fn, \
args_def_fn, \
__VA_ARGS__))
#define _PD_FOR_ALL_BACKEND_DTYPE_3( \
reg_type, kernel_name, layout, meta_kernel_fn, args_def_fn, backend, ...) \
___PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( \
reg_type, \
kernel_name, \
backend, \
layout, \
meta_kernel_fn<::phi::backend##Context>, \
args_def_fn) \
PD_EXPAND(_PD_FOR_ALL_BACKEND_DTYPE_2(reg_type, \
kernel_name, \
layout, \
meta_kernel_fn, \
args_def_fn, \
__VA_ARGS__))
/** PD_DECLARE_KERNEL
*
* Used to export the symbols of the file where the kernel is located,
......
......@@ -87,43 +87,9 @@ void ReshapeKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_GENERAL_KERNEL(reshape_infer,
CPU,
ALL_LAYOUT,
phi::ReshapeInferKernel<phi::CPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
reshape, CPU, ALL_LAYOUT, phi::ReshapeKernel<phi::CPUContext>, ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(reshape_infer,
GPU,
ALL_LAYOUT,
phi::ReshapeInferKernel<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
reshape, GPU, ALL_LAYOUT, phi::ReshapeKernel<phi::GPUContext>, ALL_DTYPE) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_GENERAL_KERNEL(reshape_infer,
XPU,
ALL_LAYOUT,
phi::ReshapeInferKernel<phi::XPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel<phi::XPUContext>, ALL_DTYPE) {}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_GENERAL_KERNEL(reshape_infer,
Custom,
ALL_LAYOUT,
phi::ReshapeInferKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(reshape,
Custom,
ALL_LAYOUT,
phi::ReshapeKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(reshape_infer,
ALL_LAYOUT,
phi::ReshapeInferKernel) {}
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(reshape,
ALL_LAYOUT,
phi::ReshapeKernel) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册