未验证 提交 c9da845f 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Polish kernel register macro design (#38078)

* polish register macro

* resolve compile failed

* revert needless change

* revert eager related change

* revert eager related change

* change register macro name

* polish details
上级 206a33b3
......@@ -20,18 +20,18 @@ limitations under the License. */
// the kernel declare statement is automatically generated according to the
// file name of the kernel, and this header file will be removed
PT_DECLARE_KERNEL(full_like, CPU);
PT_DECLARE_KERNEL(dot, CPU);
PT_DECLARE_KERNEL(flatten, CPU);
PT_DECLARE_KERNEL(sign, CPU);
PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA);
PT_DECLARE_KERNEL(dot, CUDA);
PT_DECLARE_KERNEL(flatten, CUDA);
PT_DECLARE_KERNEL(sign, CUDA);
PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CUDA, ALL_LAYOUT);
#endif
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(flatten, XPU);
PT_DECLARE_KERNEL(flatten, XPU, ALL_LAYOUT);
#endif
......@@ -25,14 +25,14 @@ limitations under the License. */
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h"
PT_DECLARE_KERNEL(copy, CPU);
PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(copy, CUDA);
PT_DECLARE_KERNEL(copy, CUDA, ALL_LAYOUT);
#endif
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(copy, XPU);
PT_DECLARE_KERNEL(copy, XPU, ALL_LAYOUT);
#endif
namespace paddle {
......
......@@ -37,7 +37,6 @@ namespace experimental {
* in the future
*/
enum class Backend : uint8_t {
// kernel backend cannot be undefined
UNDEFINED = 0,
// basic kernel backend
......@@ -54,6 +53,42 @@ enum class Backend : uint8_t {
// end of backend types
NUM_BACKENDS,
/**
* [ Why we need ALL in baisc kernel key member? ]
*
* For Tensor, ALL represents an illegal Backend, but for Kernel, some
* kernels may be device-independent by nature, such as reshape; and
* some kernels are also device-independent when implemented based on
* the primitive API.
*
* In this case, we need to provide a more concise registration method:
* instead of registering the kernels for each device with almost
* repetitive code, we need one registration that covers all situations,
* so we provide the ALL field, which lets a single registration
* statement do this.
*
* Of course, we have also considered solving this problem through
* differently named macros, for example, if we define
*
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND
*
* Based on this design pattern, the dtype and layout also have the same
* requirements, which would require us to define a series of macros
*
* PT_REGISTER_KERNEL_FOR_ALL_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT_AND_DTYPE
*
* This makes the system of registration macros more complicated. We think
* this is not a simple design, so we still adopt the design of providing
* the ALL field.
*
* Note: ALL_BACKEND is only used for Kernel registration and selection
*/
ALL_BACKEND = UNDEFINED,
};
inline std::ostream& operator<<(std::ostream& os, Backend backend) {
......
......@@ -45,7 +45,9 @@ enum class DataType {
FLOAT64,
COMPLEX64,
COMPLEX128,
NUM_DATA_TYPES
NUM_DATA_TYPES,
// See Note [ Why we need ALL in baisc kernel key member? ]
ALL_DTYPE = UNDEFINED,
};
inline size_t SizeOf(DataType data_type) {
......
......@@ -20,11 +20,14 @@ namespace experimental {
enum class DataLayout {
UNDEFINED = 0,
ANY,
// TODO(chenweihang): keep ANY for compatibility, remove it later
ANY = UNDEFINED,
NHWC,
NCHW,
MKLDNN,
NUM_DATA_LAYOUTS,
// See Note [ Why we need ALL in baisc kernel key member? ]
ALL_LAYOUT = UNDEFINED,
};
inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
......@@ -32,9 +35,6 @@ inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
case DataLayout::UNDEFINED:
os << "Undefined";
break;
case DataLayout::ANY:
os << "Any";
break;
case DataLayout::NHWC:
os << "NHWC";
break;
......
......@@ -93,6 +93,8 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
}
};
// TODO(chenweihang): Polish the kernel selection logic, support the selection
// of ALL_DTYPE kernel, and simplify the constructor
struct KernelRegistrar {
public:
KernelRegistrar(const char* kernel_name_cstr,
......@@ -206,28 +208,33 @@ struct KernelRegistrar {
* registration with only data type as template parameter, and the function
* pointer of the corresponding data type is automatically instantiated
* during registration.
*
* Note: `1TA` means `1 template argument`
*/
#define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_kernel_ns_check_##kernel_name, \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \
_PT_REGISTER_KERNEL( \
_PT_REGISTER_1TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32
#define _PT_REGISTER_KERNEL( \
#define _PT_REGISTER_1TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT(kernel_name, \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel)
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#else
/**
* `template decltype(fn) fn` can work on gcc and clang,
......@@ -241,17 +248,20 @@ struct KernelRegistrar {
*
* And msvc can work without template instantiation
*/
#define _PT_REGISTER_KERNEL( \
#define _PT_REGISTER_1TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT(kernel_name, \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel)
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \
......@@ -334,9 +344,9 @@ struct KernelRegistrar {
...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
......@@ -345,15 +355,15 @@ struct KernelRegistrar {
// clang-format on
#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -363,17 +373,17 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; }
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -384,22 +394,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -410,22 +420,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -436,22 +446,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -462,22 +472,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -488,22 +498,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -514,22 +524,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -540,22 +550,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -566,22 +576,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -592,22 +602,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -618,22 +628,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -644,22 +654,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -670,22 +680,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -696,22 +706,22 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \
registrar_id, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
......@@ -722,68 +732,55 @@ struct KernelRegistrar {
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
PT_ID, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
/** PT_REGISTER_SINGLE_KERNEL
/** PT_REGISTER_NO_TEMPLATE_KERNEL
*
* Used to register a single kernel, pass in the complete function pointer
* of the kernel, this registration macro will not do automatic template
* instantiation.
*/
#define PT_REGISTER_SINGLE_KERNEL( \
kernel_name, backend, layout, dtype, kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_single_kernel_ns_check_##kernel_name, \
"PT_REGISTER_SINGLE_KERNEL must be called in global namespace."); \
static void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \
static const ::pten::KernelRegistrar __reg_pt_single_kernel_##kernel_name( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
DATATYPE(dtype), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
args_def_fn, \
PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \
void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*)
/** PT_REGISTER_KERNEL_ALL_DTYPE
* Basic kernel register macro, used to register a kernel function that has
* no template argument: pass in the complete function pointer of the kernel;
* this registration macro will not do automatic template instantiation.
*
* Used to register a kernel that supports all data types, such as copy and
* reshape that are not sensitive to data types.
* Note: a developer may register 2 kernels with the same name and backend but
* different layouts, so the layout also needs to be part of the symbol
* variable name. If a developer registers 2 kernels with the same name,
* backend and layout but different dtypes, they should use the other register
* macro, PT_REGISTER_KERNEL.
*/
#define PT_REGISTER_KERNEL_ALL_DTYPE(kernel_name, backend, layout, kernel_fn) \
#define PT_REGISTER_NO_TEMPLATE_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_kernel_all_dtype_ns_check_##kernel_name, \
"PT_REGISTER_KERNEL_ALL_DTYPE must be called in global namespace."); \
static void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name( \
pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
static const ::pten::KernelRegistrar \
__reg_pt_kernel_all_dtype_##kernel_name( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \
void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name(::pten::Kernel* kernel)
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \
return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
/** PT_DECLARE_KERNEL
*
* Used to export the symbols of the file where the kernel is located,
* to avoid being removed by linker
*/
#define PT_DECLARE_KERNEL(kernel_name, backend) \
extern int TouchKernelSymbolFor_##kernel_name##_##backend(); \
UNUSED static int __declare_kernel_symbol_for_##kernel_name##_##backend = \
TouchKernelSymbolFor_##kernel_name##_##backend()
#define PT_DECLARE_KERNEL(kernel_name, backend, layout) \
extern int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout(); \
UNUSED static int \
__declare_kernel_symbol_for_##kernel_name##_##backend##_##layout = \
TouchKernelSymbolFor_##kernel_name##_##backend##_##layout()
} // namespace pten
......@@ -63,7 +63,7 @@ void FillConstant(const CPUContext& dev_ctx,
PT_REGISTER_KERNEL(full_like,
CPU,
ANY,
ALL_LAYOUT,
pten::FillAnyLike,
float,
double,
......@@ -74,7 +74,7 @@ PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL(full,
CPU,
ANY,
ALL_LAYOUT,
pten::FillConstant,
float,
double,
......
......@@ -75,7 +75,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(dot,
CPU,
ANY,
ALL_LAYOUT,
pten::Dot,
float,
double,
......@@ -84,5 +84,11 @@ PT_REGISTER_KERNEL(dot,
complex64,
complex128) {}
PT_REGISTER_KERNEL(
matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
PT_REGISTER_KERNEL(matmul,
CPU,
ALL_LAYOUT,
pten::Matmul,
float,
double,
complex64,
complex128) {}
......@@ -85,7 +85,7 @@ void Cast(const CPUContext& dev_ctx,
PT_REGISTER_KERNEL(flatten,
CPU,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
double,
......@@ -95,7 +95,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape,
CPU,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
double,
......@@ -106,7 +106,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
PT_REGISTER_KERNEL(cast,
CPU,
ANY,
ALL_LAYOUT,
pten::Cast,
float,
double,
......@@ -122,8 +122,7 @@ PT_REGISTER_KERNEL(cast,
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CPU,
ANY,
pten::ReshapeWithXShape) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_with_xshape, CPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {}
......@@ -111,11 +111,11 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause a redefinition conflict with
// the xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CPU,
ANY,
ALL_LAYOUT,
pten::Scale,
float,
double,
......@@ -127,7 +127,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {}
PT_REGISTER_KERNEL(add,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseAdd,
float,
double,
......@@ -137,7 +137,7 @@ PT_REGISTER_KERNEL(add,
complex128) {}
PT_REGISTER_KERNEL(subtract,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseSub,
float,
double,
......@@ -147,7 +147,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {}
PT_REGISTER_KERNEL(divide,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseDiv,
float,
double,
......@@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {}
PT_REGISTER_KERNEL(multiply,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseMul,
float,
double,
......@@ -168,7 +168,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {}
PT_REGISTER_KERNEL(sum,
CPU,
ANY,
ALL_LAYOUT,
pten::Sum,
bool,
float,
......
......@@ -57,4 +57,4 @@ void Copy(const CPUContext& dev_ctx,
} // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
......@@ -64,7 +64,7 @@ void FillConstant(const CUDAContext& dev_ctx,
PT_REGISTER_KERNEL(full_like,
CUDA,
ANY,
ALL_LAYOUT,
pten::FillAnyLike,
float,
double,
......@@ -75,7 +75,7 @@ PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL(full,
CUDA,
ANY,
ALL_LAYOUT,
pten::FillConstant,
float,
double,
......
......@@ -60,7 +60,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(dot,
CUDA,
ANY,
ALL_LAYOUT,
pten::Dot,
float,
double,
......@@ -71,7 +71,7 @@ PT_REGISTER_KERNEL(dot,
PT_REGISTER_KERNEL(matmul,
CUDA,
ANY,
ALL_LAYOUT,
pten::Matmul,
float,
double,
......
......@@ -86,7 +86,7 @@ using float16 = paddle::platform::float16;
PT_REGISTER_KERNEL(flatten,
CUDA,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
float16,
......@@ -97,7 +97,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape,
CUDA,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
double,
......@@ -109,7 +109,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL(cast, \
CUDA, \
ANY, \
ALL_LAYOUT, \
pten::Cast, \
float, \
double, \
......@@ -132,8 +132,6 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CUDA,
ANY,
pten::ReshapeWithXShape) {}
// Register under ALL_LAYOUT rather than the legacy ANY alias: the layout token
// is pasted into the generated symbol names, so it has to be spelled the same
// way as in the other reshape registrations and in any
// PT_DECLARE_KERNEL(..., CUDA, ALL_LAYOUT) declaration.
PT_REGISTER_NO_TEMPLATE_KERNEL(
    reshape, CUDA, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(reshape_with_xshape,
                               CUDA,
                               ALL_LAYOUT,
                               pten::ReshapeWithXShape,
                               ALL_DTYPE) {}
......@@ -115,11 +115,12 @@ using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(sign, CUDA, ANY, pten::Sign, float, double, float16) {}
PT_REGISTER_KERNEL(mean, CUDA, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(sign, CUDA, ALL_LAYOUT, pten::Sign, float, double, float16) {
}
PT_REGISTER_KERNEL(mean, CUDA, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CUDA,
ANY,
ALL_LAYOUT,
pten::Scale,
float,
double,
......@@ -131,7 +132,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {}
PT_REGISTER_KERNEL(add,
CUDA,
ANY,
ALL_LAYOUT,
pten::ElementwiseAdd,
float,
double,
......@@ -142,7 +143,7 @@ PT_REGISTER_KERNEL(add,
complex128) {}
PT_REGISTER_KERNEL(subtract,
CUDA,
ANY,
ALL_LAYOUT,
pten::ElementwiseSub,
float,
double,
......@@ -153,7 +154,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {}
PT_REGISTER_KERNEL(divide,
CUDA,
ANY,
ALL_LAYOUT,
pten::ElementwiseDiv,
float,
double,
......@@ -164,7 +165,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {}
PT_REGISTER_KERNEL(multiply,
CUDA,
ANY,
ALL_LAYOUT,
pten::ElementwiseMul,
float,
double,
......@@ -176,7 +177,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {}
PT_REGISTER_KERNEL(sum,
CUDA,
ANY,
ALL_LAYOUT,
pten::Sum,
bool,
float,
......
......@@ -234,4 +234,4 @@ void Copy(const CUDAContext& dev_ctx,
}
} // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CUDA, ANY, pten::Copy) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CUDA, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
......@@ -78,7 +78,7 @@ void ReshapeWithXShape(const XPUContext& dev_ctx,
PT_REGISTER_KERNEL(flatten,
XPU,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
paddle::platform::float16,
......@@ -90,7 +90,7 @@ PT_REGISTER_KERNEL(flatten,
PT_REGISTER_KERNEL(flatten_with_xshape,
XPU,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
paddle::platform::float16,
......@@ -100,4 +100,5 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
int,
int64_t) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::Reshape) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape, XPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
......@@ -76,4 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx,
} // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, XPU, ANY, pten::Copy) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(copy, XPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
......@@ -28,7 +28,7 @@ TEST(DataLayout, OStream) {
EXPECT_EQ(oss.str(), "Undefined");
oss.str("");
oss << pten::DataLayout::ANY;
EXPECT_EQ(oss.str(), "Any");
EXPECT_EQ(oss.str(), "Undefined");
oss.str("");
oss << pten::DataLayout::NHWC;
EXPECT_EQ(oss.str(), "NHWC");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册