未验证 提交 158bf13f 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Rename kernel register macro (#38861)

* rename register macro

* fix error changing

* fix format error
上级 dccdc719
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
function(kernel_declare TARGET_LIST) function(kernel_declare TARGET_LIST)
foreach(kernel_path ${TARGET_LIST}) foreach(kernel_path ${TARGET_LIST})
file(READ ${kernel_path} kernel_impl) file(READ ${kernel_path} kernel_impl)
# TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL # TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name # NOTE(chenweihang): now we don't recommend to use digit in kernel name
string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}") string(REGEX MATCH "(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
if (NOT first_registry STREQUAL "") if (NOT first_registry STREQUAL "")
# parse the first kernel name # parse the first kernel name
string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}") string(REPLACE "PT_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}") string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}")
string(REPLACE "," "" kernel_name "${kernel_name}") string(REPLACE "," "" kernel_name "${kernel_name}")
string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}") string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}")
......
...@@ -213,20 +213,20 @@ struct KernelRegistrar { ...@@ -213,20 +213,20 @@ struct KernelRegistrar {
* pointer of the corresponding data type is automatically instantiated * pointer of the corresponding data type is automatically instantiated
* during registration. * during registration.
* *
* Note: `1TA` means `1 template argument` * Note: `2TA` means `2 template argument`
*/ */
#define PT_REGISTER_KERNEL( \ #define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \ "PT_REGISTER_KERNEL must be called in global namespace."); \
_PT_REGISTER_1TA_KERNEL( \ _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32 #ifndef _WIN32
#define _PT_REGISTER_1TA_KERNEL( \ #define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \ PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \ ::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \ PT_KERNEL_REGISTRAR_INIT( \
...@@ -252,7 +252,7 @@ struct KernelRegistrar { ...@@ -252,7 +252,7 @@ struct KernelRegistrar {
* *
* And msvc can work without template instantiation * And msvc can work without template instantiation
*/ */
#define _PT_REGISTER_1TA_KERNEL( \ #define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \ ::pten::Kernel*); \
...@@ -268,60 +268,76 @@ struct KernelRegistrar { ...@@ -268,60 +268,76 @@ struct KernelRegistrar {
::pten::Kernel* kernel) ::pten::Kernel* kernel)
#endif #endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \ #define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \ _PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ backend, \
cpp_dtype, \
__VA_ARGS__) __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, cpp_dtype, __VA_ARGS__) (meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype> template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ #define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, __VA_ARGS__)) template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, __VA_ARGS__)) #define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, cpp_dtype, ...) \ template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, __VA_ARGS__)) meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, cpp_dtype, ...) \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ #define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, __VA_ARGS__)) template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, __VA_ARGS__)) #define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \ template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, __VA_ARGS__)) meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, cpp_dtype, ...) \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ #define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, __VA_ARGS__)) template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, __VA_ARGS__)) #define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, cpp_dtype, ...) \ template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, __VA_ARGS__)) meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, cpp_dtype, ...) \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ #define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, __VA_ARGS__)) template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__)) #define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT( \ #define PT_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
...@@ -373,10 +389,11 @@ struct KernelRegistrar { ...@@ -373,10 +389,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \ backend, \
...@@ -393,10 +410,11 @@ struct KernelRegistrar { ...@@ -393,10 +410,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -419,10 +437,11 @@ struct KernelRegistrar { ...@@ -419,10 +437,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -445,10 +464,11 @@ struct KernelRegistrar { ...@@ -445,10 +464,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -471,10 +491,11 @@ struct KernelRegistrar { ...@@ -471,10 +491,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -497,10 +518,11 @@ struct KernelRegistrar { ...@@ -497,10 +518,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -523,10 +545,11 @@ struct KernelRegistrar { ...@@ -523,10 +545,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -549,10 +572,11 @@ struct KernelRegistrar { ...@@ -549,10 +572,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -575,10 +599,11 @@ struct KernelRegistrar { ...@@ -575,10 +599,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -601,10 +626,11 @@ struct KernelRegistrar { ...@@ -601,10 +626,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -627,10 +653,11 @@ struct KernelRegistrar { ...@@ -627,10 +653,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -653,10 +680,11 @@ struct KernelRegistrar { ...@@ -653,10 +680,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -679,10 +707,11 @@ struct KernelRegistrar { ...@@ -679,10 +707,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -705,10 +734,11 @@ struct KernelRegistrar { ...@@ -705,10 +734,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -731,10 +761,11 @@ struct KernelRegistrar { ...@@ -731,10 +761,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
backend, \ backend, \
layout, \ layout, \
...@@ -743,41 +774,6 @@ struct KernelRegistrar { ...@@ -743,41 +774,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
/** PT_REGISTER_NO_TEMPLATE_KERNEL
 *
 * Basic kernel registration macro, used to register a kernel function that
 * has no template arguments. Pass in the complete function pointer of the
 * kernel; this registration macro does not perform any automatic template
 * instantiation.
 *
 * The expansion consists of, in order:
 *   1. a PT_STATIC_ASSERT_GLOBAL_NAMESPACE check whose message states that the
 *      macro must be called in the global namespace;
 *   2. a forward declaration of the static function
 *      __PT_KERNEL_args_def_FN_<kernel_name>_<backend>_<layout>(::pten::Kernel*);
 *   3. a static const ::pten::KernelRegistrar named
 *      __reg_pt_kernel_<kernel_name>_<backend>_<layout>, constructed from the
 *      stringized kernel name, BACKEND(backend), DATALAYOUT(layout), the
 *      KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse function, the
 *      args-def function above, PT_KERNEL(kernel_fn) and
 *      PT_VARIADIC_KERNEL(kernel_fn);
 *   4. a definition of TouchKernelSymbolFor_<kernel_name>_<backend>_<layout>()
 *      that returns 0;
 *   5. the declarator of the args-def function, left without a body: the text
 *      that follows the macro invocation (presumably a `{ ... }` block that
 *      receives `::pten::Kernel* kernel`) completes the definition.
 *
 * Note: the `dtype` parameter is accepted but is not referenced anywhere in
 * the expansion.
 *
 * Note: a developer may register 2 kernels with the same name and backend but
 * different layouts, so the layout also needs to be part of the generated
 * symbol names. If a developer registers 2 kernels with the same name, backend
 * and layout but different dtypes, they should use the other registration
 * macro, PT_REGISTER_KERNEL.
 *
 * TODO(chenweihang): remove this macro later
 */
#define PT_REGISTER_NO_TEMPLATE_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \
return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
/** PT_REGISTER_GENERAL_KERNEL /** PT_REGISTER_GENERAL_KERNEL
* *
* Basic Kernel register marco, used to register a instantiated kernel function * Basic Kernel register marco, used to register a instantiated kernel function
...@@ -832,558 +828,6 @@ struct KernelRegistrar { ...@@ -832,558 +828,6 @@ struct KernelRegistrar {
::pten::Kernel* kernel) ::pten::Kernel* kernel)
#endif #endif
/** PT_REGISTER_CTX_KERNEL
 *
 * Used for kernel registration where the kernel function is a template that
 * takes both the data type and the device context as template parameters.
 *
 * Parameters:
 *   kernel_name     - identifier pasted into the generated symbol names.
 *   backend, layout - identifiers pasted into the generated symbol names and
 *                     forwarded to the registration helper.
 *   meta_kernel_fn  - the kernel function template to register.
 *   cpp_dtype, ...  - one or more C++ data types to register the kernel for.
 *
 * The expansion first emits a PT_STATIC_ASSERT_GLOBAL_NAMESPACE check (its
 * message requires the macro to be called in the global namespace) and then
 * forwards every argument unchanged to _PT_REGISTER_2TA_KERNEL.
 */
#define PT_REGISTER_CTX_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_ctx_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_CTX_KERNEL must be called in global namespace."); \
_PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
/** _PT_REGISTER_2TA_KERNEL
 *
 * Implementation helper behind PT_REGISTER_CTX_KERNEL. Two variants are
 * provided, selected by whether _WIN32 is defined. Both variants:
 *   - forward-declare the static args-def function
 *     __PT_KERNEL_args_def_FN_<kernel_name>_<backend>_<layout>(::pten::Kernel*);
 *   - invoke PT_KERNEL_REGISTRAR_INIT2 with the address of that function,
 *     the kernel function template and the list of data types;
 *   - end with the declarator of the args-def function, left without a body,
 *     so the text following the macro invocation (presumably a `{ ... }`
 *     block) completes the definition.
 */
#ifndef _WIN32
// Non-Windows variant: additionally emits PT_KERNEL_INSTANTIATION2 first, i.e.
// explicit template instantiations of meta_kernel_fn for every listed dtype.
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#else
// Windows variant: identical to the one above except that the explicit
// template instantiation step (PT_KERNEL_INSTANTIATION2) is omitted.
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#endif
/** PT_KERNEL_INSTANTIATION2
 *
 * Entry point for explicitly instantiating meta_kernel_fn for each data type
 * in `cpp_dtype, ...`. It passes PT_NARGS(cpp_dtype, __VA_ARGS__) as the
 * first argument of _PT_KERNEL_INSTANTIATION2 (PT_NARGS is defined elsewhere;
 * presumably it evaluates to the number of data types - confirm).
 */
#define PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
__VA_ARGS__)
/** _PT_KERNEL_INSTANTIATION2
 *
 * Selects the numbered helper by combining the `_PT_KERNEL_INSTANTIATION2_`
 * prefix with N via PT_CONCATENATE (defined elsewhere; presumably token
 * pasting after argument expansion - confirm), then invokes the selected
 * helper with the remaining arguments.
 */
#define _PT_KERNEL_INSTANTIATION2(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION2_, N) \
(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, __VA_ARGS__))
// Explicit-instantiation chain, 10-dtype step.
// Emits an explicit template instantiation of
// meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>, then passes the dtypes
// left in __VA_ARGS__ on to _PT_KERNEL_INSTANTIATION2_9.
#define _PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, __VA_ARGS__))
// Explicit-instantiation chain, 11-dtype step.
// Emits an explicit template instantiation of
// meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>, then passes the dtypes
// left in __VA_ARGS__ on to _PT_KERNEL_INSTANTIATION2_10.
#define _PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, __VA_ARGS__))
// Explicit-instantiation chain, 12-dtype step.
// Emits an explicit template instantiation of
// meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>, then passes the dtypes
// left in __VA_ARGS__ on to _PT_KERNEL_INSTANTIATION2_11.
#define _PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, __VA_ARGS__))
// Explicit-instantiation chain, 13-dtype step.
// Emits an explicit template instantiation of
// meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>, then passes the dtypes
// left in __VA_ARGS__ on to _PT_KERNEL_INSTANTIATION2_12.
#define _PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, __VA_ARGS__))
// Explicit-instantiation chain, 14-dtype step.
// Emits an explicit template instantiation of
// meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>, then passes the dtypes
// left in __VA_ARGS__ on to _PT_KERNEL_INSTANTIATION2_13.
#define _PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, __VA_ARGS__))
// Explicit-instantiation chain, 15-dtype step (the longest one visible in
// this part of the file).
// Emits an explicit template instantiation of
// meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>, then passes the dtypes
// left in __VA_ARGS__ on to _PT_KERNEL_INSTANTIATION2_14.
#define _PT_KERNEL_INSTANTIATION2_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, __VA_ARGS__))
// Entry point of the registrar-definition chain.
// Prepends PT_NARGS(cpp_dtype, __VA_ARGS__) to the argument list and forwards
// everything to _PT_KERNEL_REGISTRAR_INIT2, which uses that value to select
// the matching _PT_KERNEL_REGISTRAR_INIT2_<N> macro.
//   kernel_name    - kernel identifier (later pasted into symbol names)
//   backend/layout - tokens later wrapped in BACKEND()/DATALAYOUT()
//   args_def_fn    - passed through unchanged to each KernelRegistrar
//   meta_kernel_fn - kernel function template to register
//   cpp_dtype, ... - the C++ data types to register the kernel for
// NOTE(review): PT_NARGS is defined outside this view; presumably it expands
// to the number of arguments it receives - confirm.
#define PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_REGISTRAR_INIT2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format off
/* pre-commit always reformats this macro into the wrong layout, and
multi-line macros cannot be skipped with NOLINT, so clang-format is
switched off around the definition. */
/* Dispatcher: pastes N onto _PT_KERNEL_REGISTRAR_INIT2_ via PT_CONCATENATE
and invokes the resulting macro. PT_ID is inserted after `layout` as the
registrar_id argument of the selected macro; all other arguments are
forwarded in their original order.
NOTE(review): PT_ID and PT_CONCATENATE are defined outside this view;
presumably PT_ID yields a distinct value per expansion so that the
generated registrar variables get different names - confirm. */
#define _PT_KERNEL_REGISTRAR_INIT2(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT2_, N) ( \
kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format on
// Base case of the registrar chain (one dtype left).
// Defines a static const ::pten::KernelRegistrar whose variable name is
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id. The constructor receives, in order:
//   - the stringified kernel name,
//   - BACKEND(backend) and DATALAYOUT(layout),
//   - the data type obtained from CppTypeToDataType<cpp_dtype>::Type(),
//   - KernelArgsParseFunctor<...>::Parse for the instantiated kernel,
//   - args_def_fn,
//   - PT_KERNEL(...) and PT_VARIADIC_KERNEL(...) applied to
//     meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>.
// Unlike the higher-numbered steps it does not recurse; instead it also
// defines int TouchKernelSymbolFor_<kernel_name>_<backend>_<layout>(), which
// simply returns 0.
// NOTE(review): that function is presumably what PT_DECLARE_KERNEL refers to
// in order to pull this translation unit's symbols in - confirm.
// The trailing variadic parameter is accepted but not used here.
#define _PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
// Registrar chain, 2-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_1 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
// NOTE(review): PT_ID is defined outside this view; presumably it yields a
// distinct value per expansion so the variable names do not clash - confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 3-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_2 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 4-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_3 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 5-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_4 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 6-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_5 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 7-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_6 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 8-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_7 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 9-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_8 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 10-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_9 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 11-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_10 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 12-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_11 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 13-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_12 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 14-dtype step.
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_13 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Registrar chain, 15-dtype step (the longest one visible in this part of the
// file).
// Defines a static const ::pten::KernelRegistrar for cpp_dtype (variable name:
// __reg_pt_kernel_<kernel_name>_<backend>_<layout>_ concatenated with
// registrar_id), then expands _PT_KERNEL_REGISTRAR_INIT2_14 with PT_ID as the
// next registrar_id. cpp_dtype is not forwarded, so the first entry of
// __VA_ARGS__ becomes the cpp_dtype of the next step.
#define _PT_KERNEL_REGISTRAR_INIT2_15(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
/** PT_DECLARE_KERNEL /** PT_DECLARE_KERNEL
* *
* Used to export the symbols of the file where the kernel is located, * Used to export the symbols of the file where the kernel is located,
......
...@@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx, ...@@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(cast, PT_REGISTER_KERNEL(cast,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::CastKernel, pten::CastKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
int16_t, int16_t,
bool, bool,
uint8_t, uint8_t,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::bfloat16, paddle::platform::bfloat16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) { paddle::platform::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
} }
...@@ -21,13 +21,13 @@ ...@@ -21,13 +21,13 @@
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(conj, PT_REGISTER_KERNEL(conj,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::ConjKernel, pten::ConjKernel,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>, paddle::platform::complex<double>,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {}
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad, PT_REGISTER_KERNEL(dot_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DotGradKernel, pten::DotGradKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
...@@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx, ...@@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx,
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot, PT_REGISTER_KERNEL(dot,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DotKernel, pten::DotKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
...@@ -18,29 +18,29 @@ limitations under the License. */ ...@@ -18,29 +18,29 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h" #include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(full, PT_REGISTER_KERNEL(full,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FullKernel, pten::FullKernel,
float, float,
double, double,
uint8_t, uint8_t,
int16_t, int16_t,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::bfloat16, paddle::platform::bfloat16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(full_like, PT_REGISTER_KERNEL(full_like,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FullLikeKernel, pten::FullLikeKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16) {} paddle::platform::float16) {}
...@@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex<double>; ...@@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16; // using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_CTX_KERNEL( PT_REGISTER_KERNEL(
mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {} mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
PT_REGISTER_CTX_KERNEL(add, PT_REGISTER_KERNEL(add,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::AddKernel, pten::AddKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(subtract, PT_REGISTER_KERNEL(subtract,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::SubtractKernel, pten::SubtractKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(divide, PT_REGISTER_KERNEL(divide,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DivideKernel, pten::DivideKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(multiply, PT_REGISTER_KERNEL(multiply,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MultiplyKernel, pten::MultiplyKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
bool, bool,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(sum, PT_REGISTER_KERNEL(sum,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::SumKernel, pten::SumKernel,
bool, bool,
float, float,
double, double,
paddle::platform::float16, paddle::platform::float16,
int, int,
int64_t, int64_t,
complex64, complex64,
complex128) { complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
} }
...@@ -19,29 +19,29 @@ limitations under the License. */ ...@@ -19,29 +19,29 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad, PT_REGISTER_KERNEL(matmul_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulGradKernel, pten::MatmulGradKernel,
float, float,
double, double,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad, PT_REGISTER_KERNEL(matmul_double_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulDoubleGradKernel, pten::MatmulDoubleGradKernel,
float, float,
double, double,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad, PT_REGISTER_KERNEL(matmul_triple_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulTripleGradKernel, pten::MatmulTripleGradKernel,
float, float,
double, double,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
...@@ -20,11 +20,11 @@ limitations under the License. */ ...@@ -20,11 +20,11 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h" #include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul, PT_REGISTER_KERNEL(matmul,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulKernel, pten::MatmulKernel,
float, float,
double, double,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
...@@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx, ...@@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(scale, PT_REGISTER_KERNEL(scale,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::ScaleKernel, pten::ScaleKernel,
float, float,
double, double,
paddle::platform::bfloat16, paddle::platform::bfloat16,
uint8_t, uint8_t,
int8_t, int8_t,
int16_t, int16_t,
int, int,
int64_t) {} int64_t) {}
...@@ -21,5 +21,4 @@ limitations under the License. */ ...@@ -21,5 +21,4 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/bfloat16.h"
PT_REGISTER_CTX_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) { PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) {}
}
...@@ -34,66 +34,66 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) { ...@@ -34,66 +34,66 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(empty, PT_REGISTER_KERNEL(empty,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::EmptyKernel, pten::EmptyKernel,
float, float,
double, double,
uint8_t, uint8_t,
int16_t, int16_t,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::bfloat16, paddle::platform::bfloat16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like, PT_REGISTER_KERNEL(empty_like,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::EmptyLikeKernel, pten::EmptyLikeKernel,
float, float,
double, double,
uint8_t, uint8_t,
int16_t, int16_t,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::bfloat16, paddle::platform::bfloat16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(empty, PT_REGISTER_KERNEL(empty,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::EmptyKernel, pten::EmptyKernel,
float, float,
double, double,
uint8_t, uint8_t,
int16_t, int16_t,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like, PT_REGISTER_KERNEL(empty_like,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::EmptyLikeKernel, pten::EmptyLikeKernel,
float, float,
double, double,
uint8_t, uint8_t,
int16_t, int16_t,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
#endif #endif
...@@ -33,41 +33,41 @@ void FlattenGradKernel(const Context& dev_ctx, ...@@ -33,41 +33,41 @@ void FlattenGradKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(flatten_grad, PT_REGISTER_KERNEL(flatten_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenGradKernel, pten::FlattenGradKernel,
float, float,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(flatten_grad, PT_REGISTER_KERNEL(flatten_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenGradKernel, pten::FlattenGradKernel,
float, float,
paddle::platform::float16, paddle::platform::float16,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_REGISTER_CTX_KERNEL(flatten_grad, PT_REGISTER_KERNEL(flatten_grad,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenGradKernel, pten::FlattenGradKernel,
float, float,
paddle::platform::float16, paddle::platform::float16,
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
#endif #endif
...@@ -48,72 +48,72 @@ void FlattenWithXShape(const Context& dev_ctx, ...@@ -48,72 +48,72 @@ void FlattenWithXShape(const Context& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenKernel, pten::FlattenKernel,
float, float,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_CTX_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
float, float,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenKernel, pten::FlattenKernel,
float, float,
paddle::platform::float16, paddle::platform::float16,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_CTX_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
float, float,
paddle::platform::float16, paddle::platform::float16,
double, double,
uint8_t, uint8_t,
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_REGISTER_CTX_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenKernel, pten::FlattenKernel,
float, float,
paddle::platform::float16, paddle::platform::float16,
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_CTX_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
float, float,
paddle::platform::float16, paddle::platform::float16,
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
#endif #endif
...@@ -60,24 +60,24 @@ void CastKernel(const Context& dev_ctx, ...@@ -60,24 +60,24 @@ void CastKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ #define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_CTX_KERNEL(cast, \ PT_REGISTER_KERNEL(cast, \
GPU, \ GPU, \
ALL_LAYOUT, \ ALL_LAYOUT, \
pten::CastKernel, \ pten::CastKernel, \
float, \ float, \
double, \ double, \
int, \ int, \
int64_t, \ int64_t, \
int16_t, \ int16_t, \
bool, \ bool, \
uint8_t, \ uint8_t, \
paddle::platform::float16, \ paddle::platform::float16, \
paddle::platform::complex<float>, \ paddle::platform::complex<float>, \
paddle::platform::complex<double>, \ paddle::platform::complex<double>, \
##__VA_ARGS__) { \ ##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType( \ kernel->OutputAt(0).SetDataType( \
paddle::experimental::DataType::UNDEFINED); \ paddle::experimental::DataType::UNDEFINED); \
} }
#if !defined(PADDLE_WITH_HIP) #if !defined(PADDLE_WITH_HIP)
......
...@@ -21,14 +21,14 @@ ...@@ -21,14 +21,14 @@
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(conj, PT_REGISTER_KERNEL(conj,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::ConjKernel, pten::ConjKernel,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>, paddle::platform::complex<double>,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {}
...@@ -20,13 +20,13 @@ limitations under the License. */ ...@@ -20,13 +20,13 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad, PT_REGISTER_KERNEL(dot_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DotGradKernel, pten::DotGradKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
...@@ -52,13 +52,13 @@ void DotKernel(const Context& dev_ctx, ...@@ -52,13 +52,13 @@ void DotKernel(const Context& dev_ctx,
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot, PT_REGISTER_KERNEL(dot,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DotKernel, pten::DotKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
...@@ -18,28 +18,28 @@ limitations under the License. */ ...@@ -18,28 +18,28 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h" #include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(full, PT_REGISTER_KERNEL(full,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FullKernel, pten::FullKernel,
float, float,
double, double,
uint8_t, uint8_t,
int16_t, int16_t,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(full_like, PT_REGISTER_KERNEL(full_like,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FullLikeKernel, pten::FullLikeKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
bool, bool,
paddle::platform::float16) {} paddle::platform::float16) {}
...@@ -110,64 +110,64 @@ using float16 = paddle::platform::float16; ...@@ -110,64 +110,64 @@ using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL( PT_REGISTER_KERNEL(
mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {} mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
PT_REGISTER_CTX_KERNEL(add, PT_REGISTER_KERNEL(add,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::AddKernel, pten::AddKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(subtract, PT_REGISTER_KERNEL(subtract,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::SubtractKernel, pten::SubtractKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(divide, PT_REGISTER_KERNEL(divide,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DivideKernel, pten::DivideKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(multiply, PT_REGISTER_KERNEL(multiply,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MultiplyKernel, pten::MultiplyKernel,
float, float,
double, double,
int, int,
int64_t, int64_t,
bool, bool,
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(sum, PT_REGISTER_KERNEL(sum,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::SumKernel, pten::SumKernel,
bool, bool,
float, float,
double, double,
float16, float16,
int, int,
int64_t, int64_t,
complex64, complex64,
complex128) { complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
} }
...@@ -19,32 +19,32 @@ limitations under the License. */ ...@@ -19,32 +19,32 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad, PT_REGISTER_KERNEL(matmul_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulGradKernel, pten::MatmulGradKernel,
float, float,
double, double,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad, PT_REGISTER_KERNEL(matmul_double_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulDoubleGradKernel, pten::MatmulDoubleGradKernel,
float, float,
double, double,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad, PT_REGISTER_KERNEL(matmul_triple_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulTripleGradKernel, pten::MatmulTripleGradKernel,
float, float,
double, double,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
...@@ -20,12 +20,12 @@ limitations under the License. */ ...@@ -20,12 +20,12 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h" #include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul, PT_REGISTER_KERNEL(matmul,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulKernel, pten::MatmulKernel,
float, float,
double, double,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
...@@ -64,15 +64,15 @@ void ScaleKernel(const ContextT& dev_ctx, ...@@ -64,15 +64,15 @@ void ScaleKernel(const ContextT& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(scale, PT_REGISTER_KERNEL(scale,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::ScaleKernel, pten::ScaleKernel,
float, float,
double, double,
paddle::platform::float16, paddle::platform::float16,
uint8_t, uint8_t,
int8_t, int8_t,
int16_t, int16_t,
int, int,
int64_t) {} int64_t) {}
...@@ -23,5 +23,5 @@ limitations under the License. */ ...@@ -23,5 +23,5 @@ limitations under the License. */
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
PT_REGISTER_CTX_KERNEL( PT_REGISTER_KERNEL(
sign, GPU, ALL_LAYOUT, pten::SignKernel, float, double, float16) {} sign, GPU, ALL_LAYOUT, pten::SignKernel, float, double, float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册