未验证 提交 c9da845f 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Polish kernel register macro design (#38078)

* polish register macro

* resolve compile failed

* revert needless change

* revert eager related change

* revert eager related change

* change register macro name

* polish details
上级 206a33b3
...@@ -20,18 +20,18 @@ limitations under the License. */ ...@@ -20,18 +20,18 @@ limitations under the License. */
// the kernel declare statement is automatically generated according to the // the kernel declare statement is automatically generated according to the
// file name of the kernel, and this header file will be removed // file name of the kernel, and this header file will be removed
PT_DECLARE_KERNEL(full_like, CPU); PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CPU); PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CPU); PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CPU); PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA); PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CUDA); PT_DECLARE_KERNEL(dot, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CUDA); PT_DECLARE_KERNEL(flatten, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CUDA); PT_DECLARE_KERNEL(sign, CUDA, ALL_LAYOUT);
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(flatten, XPU); PT_DECLARE_KERNEL(flatten, XPU, ALL_LAYOUT);
#endif #endif
...@@ -25,14 +25,14 @@ limitations under the License. */ ...@@ -25,14 +25,14 @@ limitations under the License. */
#include "paddle/pten/include/core.h" #include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h" #include "paddle/pten/include/infermeta.h"
PT_DECLARE_KERNEL(copy, CPU); PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(copy, CUDA); PT_DECLARE_KERNEL(copy, CUDA, ALL_LAYOUT);
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(copy, XPU); PT_DECLARE_KERNEL(copy, XPU, ALL_LAYOUT);
#endif #endif
namespace paddle { namespace paddle {
......
...@@ -37,7 +37,6 @@ namespace experimental { ...@@ -37,7 +37,6 @@ namespace experimental {
* in the future * in the future
*/ */
enum class Backend : uint8_t { enum class Backend : uint8_t {
// kernel backend cannot be undefined
UNDEFINED = 0, UNDEFINED = 0,
// basic kernel backend // basic kernel backend
...@@ -54,6 +53,42 @@ enum class Backend : uint8_t { ...@@ -54,6 +53,42 @@ enum class Backend : uint8_t {
// end of backend types // end of backend types
NUM_BACKENDS, NUM_BACKENDS,
/**
* [ Why we need ALL in basic kernel key member? ]
*
* For Tensor, ALL represents an illegal Backend, but for Kernel, some
* kernels may be device-independent by nature, such as reshape, and
* some kernels are also device-independent when implemented based on
* primitive API.
*
* In this case, we need to provide a more concise registration method:
* instead of registering the kernels for each device with almost
* repetitive code, we need one registration that covers all situations,
* so we provide the ALL field to register the kernel in one statement.
*
* Of course, we have also considered solving this problem through different
* named macros, for example, if we define
*
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND
*
* Based on this design pattern, the dtype and layout also have the same
* requirements, which means we would need to define a series of macros
*
* PT_REGISTER_KERNEL_FOR_ALL_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT_AND_DTYPE
*
* It makes the system of registering macros more complicated, we think
* this is not a simple design, so we still adopt the design of providing
* the ALL field.
*
* Note: ALL_BACKEND is only used for Kernel registration and selection
*/
ALL_BACKEND = UNDEFINED,
}; };
inline std::ostream& operator<<(std::ostream& os, Backend backend) { inline std::ostream& operator<<(std::ostream& os, Backend backend) {
......
...@@ -45,7 +45,9 @@ enum class DataType { ...@@ -45,7 +45,9 @@ enum class DataType {
FLOAT64, FLOAT64,
COMPLEX64, COMPLEX64,
COMPLEX128, COMPLEX128,
NUM_DATA_TYPES NUM_DATA_TYPES,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_DTYPE = UNDEFINED,
}; };
inline size_t SizeOf(DataType data_type) { inline size_t SizeOf(DataType data_type) {
......
...@@ -20,11 +20,14 @@ namespace experimental { ...@@ -20,11 +20,14 @@ namespace experimental {
enum class DataLayout { enum class DataLayout {
UNDEFINED = 0, UNDEFINED = 0,
ANY, // TODO(chenweihang): keep ANY for compatibility, remove it later
ANY = UNDEFINED,
NHWC, NHWC,
NCHW, NCHW,
MKLDNN, MKLDNN,
NUM_DATA_LAYOUTS, NUM_DATA_LAYOUTS,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_LAYOUT = UNDEFINED,
}; };
inline std::ostream& operator<<(std::ostream& os, DataLayout layout) { inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
...@@ -32,9 +35,6 @@ inline std::ostream& operator<<(std::ostream& os, DataLayout layout) { ...@@ -32,9 +35,6 @@ inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
case DataLayout::UNDEFINED: case DataLayout::UNDEFINED:
os << "Undefined"; os << "Undefined";
break; break;
case DataLayout::ANY:
os << "Any";
break;
case DataLayout::NHWC: case DataLayout::NHWC:
os << "NHWC"; os << "NHWC";
break; break;
......
...@@ -93,6 +93,8 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -93,6 +93,8 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
} }
}; };
// TODO(chenweihang): Polish the kernel selection logic, support the selection
// of ALL_DTYPE kernel, and simplify the constructor
struct KernelRegistrar { struct KernelRegistrar {
public: public:
KernelRegistrar(const char* kernel_name_cstr, KernelRegistrar(const char* kernel_name_cstr,
...@@ -206,28 +208,33 @@ struct KernelRegistrar { ...@@ -206,28 +208,33 @@ struct KernelRegistrar {
* registration with only data type as template parameter, and the function * registration with only data type as template parameter, and the function
* pointer of the corresponding data type is automatically instantiated * pointer of the corresponding data type is automatically instantiated
* during registration. * during registration.
*
* Note: `1TA` means `1 template argument`
*/ */
#define PT_REGISTER_KERNEL( \ #define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_kernel_ns_check_##kernel_name, \ pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \ "PT_REGISTER_KERNEL must be called in global namespace."); \
_PT_REGISTER_KERNEL( \ _PT_REGISTER_1TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32 #ifndef _WIN32
#define _PT_REGISTER_KERNEL( \ #define _PT_REGISTER_1TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \ PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
PT_KERNEL_REGISTRAR_INIT(kernel_name, \ ::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \ backend, \
layout, \ layout, \
&__PT_KERNEL_args_def_FN_##kernel_name, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
__VA_ARGS__); \ __VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel) void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#else #else
/** /**
* `template decltype(fn) fn` can work on gcc and clang, * `template decltype(fn) fn` can work on gcc and clang,
...@@ -241,17 +248,20 @@ struct KernelRegistrar { ...@@ -241,17 +248,20 @@ struct KernelRegistrar {
* *
* And msvc can work without template instantiation * And msvc can work without template instantiation
*/ */
#define _PT_REGISTER_KERNEL( \ #define _PT_REGISTER_1TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
PT_KERNEL_REGISTRAR_INIT(kernel_name, \ ::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \ backend, \
layout, \ layout, \
&__PT_KERNEL_args_def_FN_##kernel_name, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
__VA_ARGS__); \ __VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel) void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#endif #endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \ #define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \
...@@ -334,9 +344,9 @@ struct KernelRegistrar { ...@@ -334,9 +344,9 @@ struct KernelRegistrar {
...) \ ...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \ kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...@@ -345,15 +355,15 @@ struct KernelRegistrar { ...@@ -345,15 +355,15 @@ struct KernelRegistrar {
// clang-format on // clang-format on
#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -363,17 +373,17 @@ struct KernelRegistrar { ...@@ -363,17 +373,17 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -384,22 +394,22 @@ struct KernelRegistrar { ...@@ -384,22 +394,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -410,22 +420,22 @@ struct KernelRegistrar { ...@@ -410,22 +420,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -436,22 +446,22 @@ struct KernelRegistrar { ...@@ -436,22 +446,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -462,22 +472,22 @@ struct KernelRegistrar { ...@@ -462,22 +472,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -488,22 +498,22 @@ struct KernelRegistrar { ...@@ -488,22 +498,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -514,22 +524,22 @@ struct KernelRegistrar { ...@@ -514,22 +524,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -540,22 +550,22 @@ struct KernelRegistrar { ...@@ -540,22 +550,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -566,22 +576,22 @@ struct KernelRegistrar { ...@@ -566,22 +576,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -592,22 +602,22 @@ struct KernelRegistrar { ...@@ -592,22 +602,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -618,22 +628,22 @@ struct KernelRegistrar { ...@@ -618,22 +628,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -644,22 +654,22 @@ struct KernelRegistrar { ...@@ -644,22 +654,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -670,22 +680,22 @@ struct KernelRegistrar { ...@@ -670,22 +680,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -696,22 +706,22 @@ struct KernelRegistrar { ...@@ -696,22 +706,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \
registrar_id, \
backend, \ backend, \
layout, \ layout, \
registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
...@@ -722,68 +732,55 @@ struct KernelRegistrar { ...@@ -722,68 +732,55 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
PT_ID, \
backend, \ backend, \
layout, \ layout, \
PT_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
/** PT_REGISTER_SINGLE_KERNEL /** PT_REGISTER_NO_TEMPLATE_KERNEL
* *
* Used to register a single kernel, pass in the complete function pointer * Basic Kernel register macro, used to register a no template argument kernel
* of the kernel, this registration macro will not do automatic template * function, pass in the complete function pointer of the kernel, this
* instantiation. * registration macro will not do automatic template instantiation.
*/
#define PT_REGISTER_SINGLE_KERNEL( \
kernel_name, backend, layout, dtype, kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_single_kernel_ns_check_##kernel_name, \
"PT_REGISTER_SINGLE_KERNEL must be called in global namespace."); \
static void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \
static const ::pten::KernelRegistrar __reg_pt_single_kernel_##kernel_name( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
DATATYPE(dtype), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
args_def_fn, \
PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \
void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*)
/** PT_REGISTER_KERNEL_ALL_DTYPE
* *
* Used to register a kernel that supports all data types, such as copy and * Note: a developer may register 2 kernels with the same name and backend but
* reshape that are not sensitive to data types. * different layouts, so the layout also needs to be part of the symbol variable
* name. If a developer registers 2 kernels with the same name, backend and
* layout but different dtypes, they should use another register macro,
* PT_REGISTER_KERNEL.
*/ */
#define PT_REGISTER_KERNEL_ALL_DTYPE(kernel_name, backend, layout, kernel_fn) \ #define PT_REGISTER_NO_TEMPLATE_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_kernel_all_dtype_ns_check_##kernel_name, \ pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL_ALL_DTYPE must be called in global namespace."); \ "PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \
static void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \ ::pten::Kernel*); \
static const ::pten::KernelRegistrar \ static const ::pten::KernelRegistrar \
__reg_pt_kernel_all_dtype_##kernel_name( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
PT_KERNEL(kernel_fn), \ PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \ PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \ int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \
void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name(::pten::Kernel* kernel) return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
/** PT_DECLARE_KERNEL /** PT_DECLARE_KERNEL
* *
* Used to export the symbols of the file where the kernel is located, * Used to export the symbols of the file where the kernel is located,
* to avoid being removed by linker * to avoid being removed by linker
*/ */
#define PT_DECLARE_KERNEL(kernel_name, backend) \ #define PT_DECLARE_KERNEL(kernel_name, backend, layout) \
extern int TouchKernelSymbolFor_##kernel_name##_##backend(); \ extern int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout(); \
UNUSED static int __declare_kernel_symbol_for_##kernel_name##_##backend = \ UNUSED static int \
TouchKernelSymbolFor_##kernel_name##_##backend() __declare_kernel_symbol_for_##kernel_name##_##backend##_##layout = \
TouchKernelSymbolFor_##kernel_name##_##backend##_##layout()
} // namespace pten } // namespace pten
...@@ -63,7 +63,7 @@ void FillConstant(const CPUContext& dev_ctx, ...@@ -63,7 +63,7 @@ void FillConstant(const CPUContext& dev_ctx,
PT_REGISTER_KERNEL(full_like, PT_REGISTER_KERNEL(full_like,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::FillAnyLike, pten::FillAnyLike,
float, float,
double, double,
...@@ -74,7 +74,7 @@ PT_REGISTER_KERNEL(full_like, ...@@ -74,7 +74,7 @@ PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL(full, PT_REGISTER_KERNEL(full,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::FillConstant, pten::FillConstant,
float, float,
double, double,
......
...@@ -75,7 +75,7 @@ using complex128 = ::paddle::platform::complex<double>; ...@@ -75,7 +75,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(dot, PT_REGISTER_KERNEL(dot,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Dot, pten::Dot,
float, float,
double, double,
...@@ -84,5 +84,11 @@ PT_REGISTER_KERNEL(dot, ...@@ -84,5 +84,11 @@ PT_REGISTER_KERNEL(dot,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL( PT_REGISTER_KERNEL(matmul,
matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {} CPU,
ALL_LAYOUT,
pten::Matmul,
float,
double,
complex64,
complex128) {}
...@@ -85,7 +85,7 @@ void Cast(const CPUContext& dev_ctx, ...@@ -85,7 +85,7 @@ void Cast(const CPUContext& dev_ctx,
PT_REGISTER_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Flatten, pten::Flatten,
float, float,
double, double,
...@@ -95,7 +95,7 @@ PT_REGISTER_KERNEL(flatten, ...@@ -95,7 +95,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
float, float,
double, double,
...@@ -106,7 +106,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape, ...@@ -106,7 +106,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
PT_REGISTER_KERNEL(cast, PT_REGISTER_KERNEL(cast,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Cast, pten::Cast,
float, float,
double, double,
...@@ -122,8 +122,7 @@ PT_REGISTER_KERNEL(cast, ...@@ -122,8 +122,7 @@ PT_REGISTER_KERNEL(cast,
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
} }
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {} PT_REGISTER_NO_TEMPLATE_KERNEL(
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
CPU, PT_REGISTER_NO_TEMPLATE_KERNEL(
ANY, reshape_with_xshape, CPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {}
pten::ReshapeWithXShape) {}
...@@ -111,11 +111,11 @@ using complex128 = ::paddle::platform::complex<double>; ...@@ -111,11 +111,11 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16; // using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {} PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {} PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale, PT_REGISTER_KERNEL(scale,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Scale, pten::Scale,
float, float,
double, double,
...@@ -127,7 +127,7 @@ PT_REGISTER_KERNEL(scale, ...@@ -127,7 +127,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(add, PT_REGISTER_KERNEL(add,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::ElementwiseAdd, pten::ElementwiseAdd,
float, float,
double, double,
...@@ -137,7 +137,7 @@ PT_REGISTER_KERNEL(add, ...@@ -137,7 +137,7 @@ PT_REGISTER_KERNEL(add,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(subtract, PT_REGISTER_KERNEL(subtract,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::ElementwiseSub, pten::ElementwiseSub,
float, float,
double, double,
...@@ -147,7 +147,7 @@ PT_REGISTER_KERNEL(subtract, ...@@ -147,7 +147,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(divide, PT_REGISTER_KERNEL(divide,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::ElementwiseDiv, pten::ElementwiseDiv,
float, float,
double, double,
...@@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(divide, ...@@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(multiply, PT_REGISTER_KERNEL(multiply,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::ElementwiseMul, pten::ElementwiseMul,
float, float,
double, double,
...@@ -168,7 +168,7 @@ PT_REGISTER_KERNEL(multiply, ...@@ -168,7 +168,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(sum, PT_REGISTER_KERNEL(sum,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Sum, pten::Sum,
bool, bool,
float, float,
......
...@@ -57,4 +57,4 @@ void Copy(const CPUContext& dev_ctx, ...@@ -57,4 +57,4 @@ void Copy(const CPUContext& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {} PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
...@@ -64,7 +64,7 @@ void FillConstant(const CUDAContext& dev_ctx, ...@@ -64,7 +64,7 @@ void FillConstant(const CUDAContext& dev_ctx,
PT_REGISTER_KERNEL(full_like, PT_REGISTER_KERNEL(full_like,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::FillAnyLike, pten::FillAnyLike,
float, float,
double, double,
...@@ -75,7 +75,7 @@ PT_REGISTER_KERNEL(full_like, ...@@ -75,7 +75,7 @@ PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL(full, PT_REGISTER_KERNEL(full,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::FillConstant, pten::FillConstant,
float, float,
double, double,
......
...@@ -60,7 +60,7 @@ using complex128 = ::paddle::platform::complex<double>; ...@@ -60,7 +60,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(dot, PT_REGISTER_KERNEL(dot,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Dot, pten::Dot,
float, float,
double, double,
...@@ -71,7 +71,7 @@ PT_REGISTER_KERNEL(dot, ...@@ -71,7 +71,7 @@ PT_REGISTER_KERNEL(dot,
PT_REGISTER_KERNEL(matmul, PT_REGISTER_KERNEL(matmul,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Matmul, pten::Matmul,
float, float,
double, double,
......
...@@ -86,7 +86,7 @@ using float16 = paddle::platform::float16; ...@@ -86,7 +86,7 @@ using float16 = paddle::platform::float16;
PT_REGISTER_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Flatten, pten::Flatten,
float, float,
float16, float16,
...@@ -97,7 +97,7 @@ PT_REGISTER_KERNEL(flatten, ...@@ -97,7 +97,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
float, float,
double, double,
...@@ -109,7 +109,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape, ...@@ -109,7 +109,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ #define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL(cast, \ PT_REGISTER_KERNEL(cast, \
CUDA, \ CUDA, \
ANY, \ ALL_LAYOUT, \
pten::Cast, \ pten::Cast, \
float, \ float, \
double, \ double, \
...@@ -132,8 +132,6 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16) ...@@ -132,8 +132,6 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif #endif
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {} PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, CUDA, ANY, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, PT_REGISTER_NO_TEMPLATE_KERNEL(
CUDA, reshape_with_xshape, CUDA, ANY, pten::ReshapeWithXShape, ALL_DTYPE) {}
ANY,
pten::ReshapeWithXShape) {}
...@@ -115,11 +115,12 @@ using float16 = paddle::platform::float16; ...@@ -115,11 +115,12 @@ using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(sign, CUDA, ANY, pten::Sign, float, double, float16) {} PT_REGISTER_KERNEL(sign, CUDA, ALL_LAYOUT, pten::Sign, float, double, float16) {
PT_REGISTER_KERNEL(mean, CUDA, ANY, pten::Mean, float, double, bool) {} }
PT_REGISTER_KERNEL(mean, CUDA, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale, PT_REGISTER_KERNEL(scale,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Scale, pten::Scale,
float, float,
double, double,
...@@ -131,7 +132,7 @@ PT_REGISTER_KERNEL(scale, ...@@ -131,7 +132,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(add, PT_REGISTER_KERNEL(add,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::ElementwiseAdd, pten::ElementwiseAdd,
float, float,
double, double,
...@@ -142,7 +143,7 @@ PT_REGISTER_KERNEL(add, ...@@ -142,7 +143,7 @@ PT_REGISTER_KERNEL(add,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(subtract, PT_REGISTER_KERNEL(subtract,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::ElementwiseSub, pten::ElementwiseSub,
float, float,
double, double,
...@@ -153,7 +154,7 @@ PT_REGISTER_KERNEL(subtract, ...@@ -153,7 +154,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(divide, PT_REGISTER_KERNEL(divide,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::ElementwiseDiv, pten::ElementwiseDiv,
float, float,
double, double,
...@@ -164,7 +165,7 @@ PT_REGISTER_KERNEL(divide, ...@@ -164,7 +165,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(multiply, PT_REGISTER_KERNEL(multiply,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::ElementwiseMul, pten::ElementwiseMul,
float, float,
double, double,
...@@ -176,7 +177,7 @@ PT_REGISTER_KERNEL(multiply, ...@@ -176,7 +177,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(sum, PT_REGISTER_KERNEL(sum,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Sum, pten::Sum,
bool, bool,
float, float,
......
...@@ -234,4 +234,4 @@ void Copy(const CUDAContext& dev_ctx, ...@@ -234,4 +234,4 @@ void Copy(const CUDAContext& dev_ctx,
} }
} // namespace pten } // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CUDA, ANY, pten::Copy) {} PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CUDA, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
...@@ -78,7 +78,7 @@ void ReshapeWithXShape(const XPUContext& dev_ctx, ...@@ -78,7 +78,7 @@ void ReshapeWithXShape(const XPUContext& dev_ctx,
PT_REGISTER_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
XPU, XPU,
ANY, ALL_LAYOUT,
pten::Flatten, pten::Flatten,
float, float,
paddle::platform::float16, paddle::platform::float16,
...@@ -90,7 +90,7 @@ PT_REGISTER_KERNEL(flatten, ...@@ -90,7 +90,7 @@ PT_REGISTER_KERNEL(flatten,
PT_REGISTER_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
XPU, XPU,
ANY, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
float, float,
paddle::platform::float16, paddle::platform::float16,
...@@ -100,4 +100,5 @@ PT_REGISTER_KERNEL(flatten_with_xshape, ...@@ -100,4 +100,5 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::Reshape) {} PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape, XPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
...@@ -76,4 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx, ...@@ -76,4 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, XPU, ANY, pten::Copy) {} PT_REGISTER_NO_TEMPLATE_KERNEL(copy, XPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
...@@ -28,7 +28,7 @@ TEST(DataLayout, OStream) { ...@@ -28,7 +28,7 @@ TEST(DataLayout, OStream) {
EXPECT_EQ(oss.str(), "Undefined"); EXPECT_EQ(oss.str(), "Undefined");
oss.str(""); oss.str("");
oss << pten::DataLayout::ANY; oss << pten::DataLayout::ANY;
EXPECT_EQ(oss.str(), "Any"); EXPECT_EQ(oss.str(), "Undefined");
oss.str(""); oss.str("");
oss << pten::DataLayout::NHWC; oss << pten::DataLayout::NHWC;
EXPECT_EQ(oss.str(), "NHWC"); EXPECT_EQ(oss.str(), "NHWC");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册