diff --git a/paddle/pten/api/lib/kernel_declare.h b/paddle/pten/api/lib/kernel_declare.h index 8c21094a4af20257cb488a8b74d804dcba2bce87..fa11811178322fac4b8c8b618ffc9fe4506a283e 100644 --- a/paddle/pten/api/lib/kernel_declare.h +++ b/paddle/pten/api/lib/kernel_declare.h @@ -20,18 +20,18 @@ limitations under the License. */ // the kernel declare statement is automatically generated according to the // file name of the kernel, and this header file will be removed -PT_DECLARE_KERNEL(full_like, CPU); -PT_DECLARE_KERNEL(dot, CPU); -PT_DECLARE_KERNEL(flatten, CPU); -PT_DECLARE_KERNEL(sign, CPU); +PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_DECLARE_KERNEL(full_like, CUDA); -PT_DECLARE_KERNEL(dot, CUDA); -PT_DECLARE_KERNEL(flatten, CUDA); -PT_DECLARE_KERNEL(sign, CUDA); +PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT); +PT_DECLARE_KERNEL(dot, CUDA, ALL_LAYOUT); +PT_DECLARE_KERNEL(flatten, CUDA, ALL_LAYOUT); +PT_DECLARE_KERNEL(sign, CUDA, ALL_LAYOUT); #endif #ifdef PADDLE_WITH_XPU -PT_DECLARE_KERNEL(flatten, XPU); +PT_DECLARE_KERNEL(flatten, XPU, ALL_LAYOUT); #endif diff --git a/paddle/pten/api/lib/utils.cc b/paddle/pten/api/lib/utils.cc index bfde9b14b0020d9a726bdd1fb6f211274a4e39d2..683eb4e5b0c01df14e7777b260bf2f1b0f53188e 100644 --- a/paddle/pten/api/lib/utils.cc +++ b/paddle/pten/api/lib/utils.cc @@ -25,14 +25,14 @@ limitations under the License. 
*/ #include "paddle/pten/include/core.h" #include "paddle/pten/include/infermeta.h" -PT_DECLARE_KERNEL(copy, CPU); +PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_DECLARE_KERNEL(copy, CUDA); +PT_DECLARE_KERNEL(copy, CUDA, ALL_LAYOUT); #endif #ifdef PADDLE_WITH_XPU -PT_DECLARE_KERNEL(copy, XPU); +PT_DECLARE_KERNEL(copy, XPU, ALL_LAYOUT); #endif namespace paddle { diff --git a/paddle/pten/common/backend.h b/paddle/pten/common/backend.h index 94080701e28166c20952ba09d8eccb1394508923..95bbc88681a965edbee66f85c38bdc08cf461fd8 100644 --- a/paddle/pten/common/backend.h +++ b/paddle/pten/common/backend.h @@ -37,7 +37,6 @@ namespace experimental { * in the future */ enum class Backend : uint8_t { - // kernel backend cannot be undefined UNDEFINED = 0, // basic kernel backend @@ -54,6 +53,42 @@ enum class Backend : uint8_t { // end of backend types NUM_BACKENDS, + + /** + * [ Why we need ALL in baisc kernel key member? ] + * + * For Tensor, ALL represents an illegal Backend, but for Kernel, some + * kernels may be device-independent by nature, such as reshape, and + * some kernels are also device-independent when implemented based on + * the primitive API. + * + * In this case, we need to provide a more concise registration method: + * instead of registering the kernels for each device with almost + * repetitive code, we need one registration that covers all situations, + * so we provide the ALL field to register the kernel in a single statement. 
+ * + * Of course, we have also considered solving this problem through different + * named macros, for example, if we define + * + * PT_REGISTER_KERNEL_FOR_ALL_BACKEND + * + * Based on this design pattern, the dtype and layout also have the same + * requirements, which would require us to define a series of macros + * + * PT_REGISTER_KERNEL_FOR_ALL_DTYPE + * PT_REGISTER_KERNEL_FOR_ALL_LAYOUT + * PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT + * PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_DTYPE + * PT_REGISTER_KERNEL_FOR_ALL_LAYOUT_AND_DTYPE + * PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT_AND_DTYPE + * + * That makes the system of registration macros more complicated; we think + * this is not a simple design, so we still adopt the design of providing + * the ALL field. + * + * Note: ALL_BACKEND is only used for Kernel registration and selection + */ + ALL_BACKEND = UNDEFINED, }; inline std::ostream& operator<<(std::ostream& os, Backend backend) { diff --git a/paddle/pten/common/data_type.h b/paddle/pten/common/data_type.h index 1ddee0746d4d16e8d31c52f2442b31955709f3a3..a00d68c535415daf48e40d5fa9b4624459dde5d0 100644 --- a/paddle/pten/common/data_type.h +++ b/paddle/pten/common/data_type.h @@ -45,7 +45,9 @@ enum class DataType { FLOAT64, COMPLEX64, COMPLEX128, - NUM_DATA_TYPES + NUM_DATA_TYPES, + // See Note [ Why we need ALL in baisc kernel key member? ] + ALL_DTYPE = UNDEFINED, }; inline size_t SizeOf(DataType data_type) { diff --git a/paddle/pten/common/layout.h b/paddle/pten/common/layout.h index b93e2623ee7d0f8cbcfd3ecb410799d2a31cd2f1..b7c151e7e6a7c87ec5603d9cb4faf5e8125716e0 100644 --- a/paddle/pten/common/layout.h +++ b/paddle/pten/common/layout.h @@ -20,11 +20,14 @@ namespace experimental { enum class DataLayout { UNDEFINED = 0, - ANY, + // TODO(chenweihang): keep ANY for compatibility, remove it later + ANY = UNDEFINED, NHWC, NCHW, MKLDNN, NUM_DATA_LAYOUTS, + // See Note [ Why we need ALL in baisc kernel key member? 
] + ALL_LAYOUT = UNDEFINED, }; inline std::ostream& operator<<(std::ostream& os, DataLayout layout) { @@ -32,9 +35,6 @@ inline std::ostream& operator<<(std::ostream& os, DataLayout layout) { case DataLayout::UNDEFINED: os << "Undefined"; break; - case DataLayout::ANY: - os << "Any"; - break; case DataLayout::NHWC: os << "NHWC"; break; diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index 80ebd5b832a393d9a3b15bf682fbf915c37354df..f46e9afd3defbd2293bc86e62c2dbcb2e4cf74e1 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -93,6 +93,8 @@ struct KernelArgsParseFunctor { } }; +// TODO(chenweihang): Polish the kernel selection logic, support the selection +// of ALL_DTYPE kernel, and simplify the constructor struct KernelRegistrar { public: KernelRegistrar(const char* kernel_name_cstr, @@ -206,28 +208,33 @@ struct KernelRegistrar { * registration with only data type as template parameter, and the function * pointer of the corresponding data type is automatically instantiated * during registration. + * + * Note: `1TA` means `1 template argument` */ -#define PT_REGISTER_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - pt_register_kernel_ns_check_##kernel_name, \ - "PT_REGISTER_KERNEL must be called in global namespace."); \ - _PT_REGISTER_KERNEL( \ +#define PT_REGISTER_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PT_REGISTER_KERNEL must be called in global namespace."); \ + _PT_REGISTER_1TA_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) #ifndef _WIN32 -#define _PT_REGISTER_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) 
\ - PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \ - static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \ - PT_KERNEL_REGISTRAR_INIT(kernel_name, \ - backend, \ - layout, \ - &__PT_KERNEL_args_def_FN_##kernel_name, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__); \ - void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel) +#define _PT_REGISTER_1TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT( \ + kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) #else /** * `template decltype(fn) fn` can work on gcc and clang, @@ -241,17 +248,20 @@ struct KernelRegistrar { * * And msvc can work without template instantiation */ -#define _PT_REGISTER_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \ - PT_KERNEL_REGISTRAR_INIT(kernel_name, \ - backend, \ - layout, \ - &__PT_KERNEL_args_def_FN_##kernel_name, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__); \ - void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel) +#define _PT_REGISTER_1TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) 
\ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT( \ + kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) #endif #define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \ @@ -334,9 +344,9 @@ struct KernelRegistrar { ...) \ PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ kernel_name, \ - PT_ID, \ backend, \ layout, \ + PT_ID, \ args_def_fn, \ meta_kernel_fn, \ cpp_dtype, \ @@ -344,446 +354,433 @@ struct KernelRegistrar { // clang-format on -#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } -#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } +#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \ - registrar_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -/** PT_REGISTER_SINGLE_KERNEL +/** PT_REGISTER_NO_TEMPLATE_KERNEL * - * Used to register a single kernel, pass in the complete function pointer - * of the kernel, this registration macro will not do automatic template - * instantiation. - */ -#define PT_REGISTER_SINGLE_KERNEL( \ - kernel_name, backend, layout, dtype, kernel_fn) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - pt_register_single_kernel_ns_check_##kernel_name, \ - "PT_REGISTER_SINGLE_KERNEL must be called in global namespace."); \ - static void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \ - static const ::pten::KernelRegistrar __reg_pt_single_kernel_##kernel_name( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - DATATYPE(dtype), \ - ::pten::KernelArgsParseFunctor::Parse, \ - args_def_fn, \ - PT_KERNEL(kernel_fn), \ - PT_VARIADIC_KERNEL(kernel_fn)); \ - int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \ - void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*) - -/** PT_REGISTER_KERNEL_ALL_DTYPE + * Basic Kernel register marco, used to register a no template argument kernel + * function, pass in the complete function pointe of the kernel, this + * registration macro will not do automatic template instantiation. * - * Used to register a kernel that supports all data types, such as copy and - * reshape that are not sensitive to data types. 
+ * Note: a developer may register 2 kernels with the same name and backend but + * different layouts, so the layout also needs to be part of the symbol name. + * To register 2 kernels with the same name, backend and layout but different + * dtypes, use the other registration macro, PT_REGISTER_KERNEL, instead. */ -#define PT_REGISTER_KERNEL_ALL_DTYPE(kernel_name, backend, layout, kernel_fn) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - pt_register_kernel_all_dtype_ns_check_##kernel_name, \ - "PT_REGISTER_KERNEL_ALL_DTYPE must be called in global namespace."); \ - static void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name( \ - ::pten::Kernel*); \ - static const ::pten::KernelRegistrar \ - __reg_pt_kernel_all_dtype_##kernel_name( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::pten::KernelArgsParseFunctor::Parse, \ - &__PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name, \ - PT_KERNEL(kernel_fn), \ - PT_VARIADIC_KERNEL(kernel_fn)); \ - int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \ - void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name(::pten::Kernel* kernel) +#define PT_REGISTER_NO_TEMPLATE_KERNEL( \ + kernel_name, backend, layout, kernel_fn, dtype) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + static const ::pten::KernelRegistrar \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::pten::KernelArgsParseFunctor::Parse, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + PT_KERNEL(kernel_fn), \ + PT_VARIADIC_KERNEL(kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \ + return 0; \ + } \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) /** 
PT_DECLARE_KERNEL * * Used to export the symbols of the file where the kernel is located, * to avoid being removed by linker */ -#define PT_DECLARE_KERNEL(kernel_name, backend) \ - extern int TouchKernelSymbolFor_##kernel_name##_##backend(); \ - UNUSED static int __declare_kernel_symbol_for_##kernel_name##_##backend = \ - TouchKernelSymbolFor_##kernel_name##_##backend() +#define PT_DECLARE_KERNEL(kernel_name, backend, layout) \ + extern int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout(); \ + UNUSED static int \ + __declare_kernel_symbol_for_##kernel_name##_##backend##_##layout = \ + TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() } // namespace pten diff --git a/paddle/pten/kernels/cpu/creation.cc b/paddle/pten/kernels/cpu/creation.cc index 4175203410f8da092cb2aa7673d5651d17245fbb..f21c322e2db616226f3c2c52276446925ae1483e 100644 --- a/paddle/pten/kernels/cpu/creation.cc +++ b/paddle/pten/kernels/cpu/creation.cc @@ -63,7 +63,7 @@ void FillConstant(const CPUContext& dev_ctx, PT_REGISTER_KERNEL(full_like, CPU, - ANY, + ALL_LAYOUT, pten::FillAnyLike, float, double, @@ -74,7 +74,7 @@ PT_REGISTER_KERNEL(full_like, PT_REGISTER_KERNEL(full, CPU, - ANY, + ALL_LAYOUT, pten::FillConstant, float, double, diff --git a/paddle/pten/kernels/cpu/linalg.cc b/paddle/pten/kernels/cpu/linalg.cc index 9f4f1be18259a53a5f945302112b961c757e7036..87c4078896a18b66bfdd01ac5f2d5d7535fcb533 100644 --- a/paddle/pten/kernels/cpu/linalg.cc +++ b/paddle/pten/kernels/cpu/linalg.cc @@ -75,7 +75,7 @@ using complex128 = ::paddle::platform::complex; PT_REGISTER_KERNEL(dot, CPU, - ANY, + ALL_LAYOUT, pten::Dot, float, double, @@ -84,5 +84,11 @@ PT_REGISTER_KERNEL(dot, complex64, complex128) {} -PT_REGISTER_KERNEL( - matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {} +PT_REGISTER_KERNEL(matmul, + CPU, + ALL_LAYOUT, + pten::Matmul, + float, + double, + complex64, + complex128) {} diff --git a/paddle/pten/kernels/cpu/manipulation.cc 
b/paddle/pten/kernels/cpu/manipulation.cc index 9c34f84233731c0b895e12e236e995440fce5014..32bc8e4e35d7bc80bd9e3deab65ba98656393359 100644 --- a/paddle/pten/kernels/cpu/manipulation.cc +++ b/paddle/pten/kernels/cpu/manipulation.cc @@ -85,7 +85,7 @@ void Cast(const CPUContext& dev_ctx, PT_REGISTER_KERNEL(flatten, CPU, - ANY, + ALL_LAYOUT, pten::Flatten, float, double, @@ -95,7 +95,7 @@ PT_REGISTER_KERNEL(flatten, int64_t) {} PT_REGISTER_KERNEL(flatten_with_xshape, CPU, - ANY, + ALL_LAYOUT, pten::FlattenWithXShape, float, double, @@ -106,7 +106,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(cast, CPU, - ANY, + ALL_LAYOUT, pten::Cast, float, double, @@ -122,8 +122,7 @@ PT_REGISTER_KERNEL(cast, kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } -PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, - CPU, - ANY, - pten::ReshapeWithXShape) {} +PT_REGISTER_NO_TEMPLATE_KERNEL( + reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {} +PT_REGISTER_NO_TEMPLATE_KERNEL( + reshape_with_xshape, CPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {} diff --git a/paddle/pten/kernels/cpu/math.cc b/paddle/pten/kernels/cpu/math.cc index 2d556d96c2fcf7effa1dbe7f0f31adeded6a403a..616058d5ace1fc2ac4324afe2f176cd3bc44eac7 100644 --- a/paddle/pten/kernels/cpu/math.cc +++ b/paddle/pten/kernels/cpu/math.cc @@ -111,11 +111,11 @@ using complex128 = ::paddle::platform::complex; // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // using bfloat16 = ::paddle::platform::bfloat16; -PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {} -PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {} +PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {} +PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {} PT_REGISTER_KERNEL(scale, CPU, - ANY, + ALL_LAYOUT, pten::Scale, float, double, @@ -127,7 +127,7 @@ 
PT_REGISTER_KERNEL(scale, int64_t) {} PT_REGISTER_KERNEL(add, CPU, - ANY, + ALL_LAYOUT, pten::ElementwiseAdd, float, double, @@ -137,7 +137,7 @@ PT_REGISTER_KERNEL(add, complex128) {} PT_REGISTER_KERNEL(subtract, CPU, - ANY, + ALL_LAYOUT, pten::ElementwiseSub, float, double, @@ -147,7 +147,7 @@ PT_REGISTER_KERNEL(subtract, complex128) {} PT_REGISTER_KERNEL(divide, CPU, - ANY, + ALL_LAYOUT, pten::ElementwiseDiv, float, double, @@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(divide, complex128) {} PT_REGISTER_KERNEL(multiply, CPU, - ANY, + ALL_LAYOUT, pten::ElementwiseMul, float, double, @@ -168,7 +168,7 @@ PT_REGISTER_KERNEL(multiply, complex128) {} PT_REGISTER_KERNEL(sum, CPU, - ANY, + ALL_LAYOUT, pten::Sum, bool, float, diff --git a/paddle/pten/kernels/cpu/utils.cc b/paddle/pten/kernels/cpu/utils.cc index 500b4664d638885d8faba393e267c2fb0e556577..1ca20df4d92dcbc89008c5befa0c6bfb37c36de5 100644 --- a/paddle/pten/kernels/cpu/utils.cc +++ b/paddle/pten/kernels/cpu/utils.cc @@ -57,4 +57,4 @@ void Copy(const CPUContext& dev_ctx, } // namespace pten -PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {} +PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {} diff --git a/paddle/pten/kernels/cuda/creation.cu b/paddle/pten/kernels/cuda/creation.cu index dd29fd5fbb84d2536ef63c37c8fb0fd59d285258..95a561d0c94e97fb223af821408a05b1cba3ff4c 100644 --- a/paddle/pten/kernels/cuda/creation.cu +++ b/paddle/pten/kernels/cuda/creation.cu @@ -64,7 +64,7 @@ void FillConstant(const CUDAContext& dev_ctx, PT_REGISTER_KERNEL(full_like, CUDA, - ANY, + ALL_LAYOUT, pten::FillAnyLike, float, double, @@ -75,7 +75,7 @@ PT_REGISTER_KERNEL(full_like, PT_REGISTER_KERNEL(full, CUDA, - ANY, + ALL_LAYOUT, pten::FillConstant, float, double, diff --git a/paddle/pten/kernels/cuda/linalg.cu b/paddle/pten/kernels/cuda/linalg.cu index 2114bbcc71c75e7edd9d3b71a606d0d394aebd59..da6511e2c8708ab85961d1d7a9daed8331d1ea41 100644 --- a/paddle/pten/kernels/cuda/linalg.cu +++ 
b/paddle/pten/kernels/cuda/linalg.cu @@ -60,7 +60,7 @@ using complex128 = ::paddle::platform::complex; PT_REGISTER_KERNEL(dot, CUDA, - ANY, + ALL_LAYOUT, pten::Dot, float, double, @@ -71,7 +71,7 @@ PT_REGISTER_KERNEL(dot, PT_REGISTER_KERNEL(matmul, CUDA, - ANY, + ALL_LAYOUT, pten::Matmul, float, double, diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu index d3c0759698eea9ebfb06c1347937713a81dedeb2..49bbf1b61c9916cc8eb6dfc1c4c798930d48c8b0 100644 --- a/paddle/pten/kernels/cuda/manipulation.cu +++ b/paddle/pten/kernels/cuda/manipulation.cu @@ -86,7 +86,7 @@ using float16 = paddle::platform::float16; PT_REGISTER_KERNEL(flatten, CUDA, - ANY, + ALL_LAYOUT, pten::Flatten, float, float16, @@ -97,7 +97,7 @@ PT_REGISTER_KERNEL(flatten, int64_t) {} PT_REGISTER_KERNEL(flatten_with_xshape, CUDA, - ANY, + ALL_LAYOUT, pten::FlattenWithXShape, float, double, @@ -109,7 +109,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape, #define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) 
\ PT_REGISTER_KERNEL(cast, \ CUDA, \ - ANY, \ + ALL_LAYOUT, \ pten::Cast, \ float, \ double, \ @@ -132,8 +132,10 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) #endif -PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, - CUDA, - ANY, - pten::ReshapeWithXShape) {} +PT_REGISTER_NO_TEMPLATE_KERNEL( + reshape, CUDA, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {} +PT_REGISTER_NO_TEMPLATE_KERNEL(reshape_with_xshape, + CUDA, + ALL_LAYOUT, + pten::ReshapeWithXShape, + ALL_DTYPE) {} diff --git a/paddle/pten/kernels/cuda/math.cu b/paddle/pten/kernels/cuda/math.cu index 66aaf14dcd0f62ac2b752d803b136994815963df..b4a60340e00eb2898eb87a9c4fd767ca321a8c59 100644 --- a/paddle/pten/kernels/cuda/math.cu +++ b/paddle/pten/kernels/cuda/math.cu @@ -115,11 +115,12 @@ using float16 = paddle::platform::float16; using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_KERNEL(sign, CUDA, ANY, pten::Sign, float, double, float16) {} -PT_REGISTER_KERNEL(mean, CUDA, ANY, pten::Mean, float, double, bool) {} +PT_REGISTER_KERNEL(sign, CUDA, ALL_LAYOUT, pten::Sign, float, double, float16) { +} +PT_REGISTER_KERNEL(mean, CUDA, ALL_LAYOUT, pten::Mean, float, double, bool) {} PT_REGISTER_KERNEL(scale, CUDA, - ANY, + ALL_LAYOUT, pten::Scale, float, double, @@ -131,7 +132,7 @@ PT_REGISTER_KERNEL(scale, int64_t) {} PT_REGISTER_KERNEL(add, CUDA, - ANY, + ALL_LAYOUT, pten::ElementwiseAdd, float, double, @@ -142,7 +143,7 @@ PT_REGISTER_KERNEL(add, complex128) {} PT_REGISTER_KERNEL(subtract, CUDA, - ANY, + ALL_LAYOUT, pten::ElementwiseSub, float, double, @@ -153,7 +154,7 @@ PT_REGISTER_KERNEL(subtract, complex128) {} PT_REGISTER_KERNEL(divide, CUDA, - ANY, + ALL_LAYOUT, pten::ElementwiseDiv, float, double, @@ -164,7 +165,7 @@ PT_REGISTER_KERNEL(divide, complex128) {} PT_REGISTER_KERNEL(multiply, CUDA, - ANY, + ALL_LAYOUT, pten::ElementwiseMul, float, double, 
@@ -176,7 +177,7 @@ PT_REGISTER_KERNEL(multiply, complex128) {} PT_REGISTER_KERNEL(sum, CUDA, - ANY, + ALL_LAYOUT, pten::Sum, bool, float, diff --git a/paddle/pten/kernels/cuda/utils.cu b/paddle/pten/kernels/cuda/utils.cu index 49027e956b2d7dc28b77d72c2db684ae0007a15c..cf1407e7208de5539806fcfc3e9bc5bb34d4b4a6 100644 --- a/paddle/pten/kernels/cuda/utils.cu +++ b/paddle/pten/kernels/cuda/utils.cu @@ -234,4 +234,4 @@ void Copy(const CUDAContext& dev_ctx, } } // namespace pten -PT_REGISTER_KERNEL_ALL_DTYPE(copy, CUDA, ANY, pten::Copy) {} +PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CUDA, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {} diff --git a/paddle/pten/kernels/xpu/manipulation.cc b/paddle/pten/kernels/xpu/manipulation.cc index cee3e5ceedb6dadea24ba6cc96110d2d582b95cd..70ac70371e90acec6187f7575af2e57e3e5b66c0 100644 --- a/paddle/pten/kernels/xpu/manipulation.cc +++ b/paddle/pten/kernels/xpu/manipulation.cc @@ -78,7 +78,7 @@ void ReshapeWithXShape(const XPUContext& dev_ctx, PT_REGISTER_KERNEL(flatten, XPU, - ANY, + ALL_LAYOUT, pten::Flatten, float, paddle::platform::float16, @@ -90,7 +90,7 @@ PT_REGISTER_KERNEL(flatten, PT_REGISTER_KERNEL(flatten_with_xshape, XPU, - ANY, + ALL_LAYOUT, pten::FlattenWithXShape, float, paddle::platform::float16, @@ -100,4 +100,5 @@ PT_REGISTER_KERNEL(flatten_with_xshape, int, int64_t) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::Reshape) {} +PT_REGISTER_NO_TEMPLATE_KERNEL( + reshape, XPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {} diff --git a/paddle/pten/kernels/xpu/utils.cc b/paddle/pten/kernels/xpu/utils.cc index 5c98217f4ec2c41b39729f9df76ea4ffba277286..5ea3a359ef6d69b2fafaa9b76418c770ebfc9de7 100644 --- a/paddle/pten/kernels/xpu/utils.cc +++ b/paddle/pten/kernels/xpu/utils.cc @@ -76,4 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx, } // namespace pten -PT_REGISTER_KERNEL_ALL_DTYPE(copy, XPU, ANY, pten::Copy) {} +PT_REGISTER_NO_TEMPLATE_KERNEL(copy, XPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {} diff --git 
a/paddle/pten/tests/common/test_data_layout.cc b/paddle/pten/tests/common/test_data_layout.cc index 22c503b5917175b3e83ee5d17aa7ca9116921d7e..66b3e347538960f6b61eac928f4afb342219a98b 100644 --- a/paddle/pten/tests/common/test_data_layout.cc +++ b/paddle/pten/tests/common/test_data_layout.cc @@ -28,7 +28,7 @@ TEST(DataLayout, OStream) { EXPECT_EQ(oss.str(), "Undefined"); oss.str(""); oss << pten::DataLayout::ANY; - EXPECT_EQ(oss.str(), "Any"); + EXPECT_EQ(oss.str(), "Undefined"); oss.str(""); oss << pten::DataLayout::NHWC; EXPECT_EQ(oss.str(), "NHWC");