Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
158bf13f
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
158bf13f
编写于
1月 13, 2022
作者:
C
Chen Weihang
提交者:
GitHub
1月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PTen] Rename kernel register marco (#38861)
* rename register marco * fix error changing * fix format error
上级
dccdc719
变更
25
显示空白变更内容
内联
并排
Showing
25 changed file
with
636 addition
and
1193 deletion
+636
-1193
cmake/pten_kernel.cmake
cmake/pten_kernel.cmake
+3
-3
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+132
-688
paddle/pten/kernels/cpu/cast_kernel.cc
paddle/pten/kernels/cpu/cast_kernel.cc
+15
-15
paddle/pten/kernels/cpu/complex_kernel.cc
paddle/pten/kernels/cpu/complex_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_grad_kernel.cc
paddle/pten/kernels/cpu/dot_grad_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_kernel.cc
paddle/pten/kernels/cpu/dot_kernel.cc
+10
-10
paddle/pten/kernels/cpu/full_kernel.cc
paddle/pten/kernels/cpu/full_kernel.cc
+25
-25
paddle/pten/kernels/cpu/math_kernel.cc
paddle/pten/kernels/cpu/math_kernel.cc
+54
-54
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
+26
-26
paddle/pten/kernels/cpu/matmul_kernel.cc
paddle/pten/kernels/cpu/matmul_kernel.cc
+8
-8
paddle/pten/kernels/cpu/scale_kernel.cc
paddle/pten/kernels/cpu/scale_kernel.cc
+12
-12
paddle/pten/kernels/cpu/sign_kernel.cc
paddle/pten/kernels/cpu/sign_kernel.cc
+1
-2
paddle/pten/kernels/empty_kernel.cc
paddle/pten/kernels/empty_kernel.cc
+58
-58
paddle/pten/kernels/flatten_grad_kernel.cc
paddle/pten/kernels/flatten_grad_kernel.cc
+30
-30
paddle/pten/kernels/flatten_kernel.cc
paddle/pten/kernels/flatten_kernel.cc
+60
-60
paddle/pten/kernels/gpu/cast_kernel.cu
paddle/pten/kernels/gpu/cast_kernel.cu
+18
-18
paddle/pten/kernels/gpu/complex_kernel.cu
paddle/pten/kernels/gpu/complex_kernel.cu
+11
-11
paddle/pten/kernels/gpu/dot_grad_kernel.cu
paddle/pten/kernels/gpu/dot_grad_kernel.cu
+10
-10
paddle/pten/kernels/gpu/dot_kernel.cu
paddle/pten/kernels/gpu/dot_kernel.cu
+10
-10
paddle/pten/kernels/gpu/full_kernel.cu
paddle/pten/kernels/gpu/full_kernel.cu
+24
-24
paddle/pten/kernels/gpu/math_kernel.cu
paddle/pten/kernels/gpu/math_kernel.cu
+58
-58
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
+29
-29
paddle/pten/kernels/gpu/matmul_kernel.cu
paddle/pten/kernels/gpu/matmul_kernel.cu
+9
-9
paddle/pten/kernels/gpu/scale_kernel.cu
paddle/pten/kernels/gpu/scale_kernel.cu
+12
-12
paddle/pten/kernels/gpu/sign_kernel.cu
paddle/pten/kernels/gpu/sign_kernel.cu
+1
-1
未找到文件。
cmake/pten_kernel.cmake
浏览文件 @
158bf13f
...
...
@@ -16,12 +16,12 @@
function
(
kernel_declare TARGET_LIST
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
file
(
READ
${
kernel_path
}
kernel_impl
)
# TODO(chenweihang): rename PT_REGISTER_
CTX_
KERNEL to PT_REGISTER_KERNEL
# TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
string
(
REGEX MATCH
"(PT_REGISTER_
CTX_
KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
string
(
REGEX MATCH
"(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
if
(
NOT first_registry STREQUAL
""
)
# parse the first kernel name
string
(
REPLACE
"PT_REGISTER_
CTX_
KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
...
...
paddle/pten/core/kernel_registry.h
浏览文件 @
158bf13f
...
...
@@ -213,20 +213,20 @@ struct KernelRegistrar {
* pointer of the corresponding data type is automatically instantiated
* during registration.
*
* Note: `
1TA` means `1
template argument`
* Note: `
2TA` means `2
template argument`
*/
#define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \
_PT_REGISTER_
1
TA_KERNEL( \
_PT_REGISTER_
2
TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32
#define _PT_REGISTER_
1
TA_KERNEL( \
#define _PT_REGISTER_
2
TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn,
cpp_dtype, __VA_ARGS__);
\
PT_KERNEL_INSTANTIATION(meta_kernel_fn,
backend, cpp_dtype, __VA_ARGS__);
\
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
...
...
@@ -252,7 +252,7 @@ struct KernelRegistrar {
*
* And msvc can work without template instantiation
*/
#define _PT_REGISTER_
1
TA_KERNEL( \
#define _PT_REGISTER_
2
TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
...
...
@@ -268,60 +268,76 @@ struct KernelRegistrar {
::pten::Kernel* kernel)
#endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn,
backend,
cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn,
backend,
cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, cpp_dtype, __VA_ARGS__)
(meta_kernel_fn,
backend,
cpp_dtype, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
...
...
@@ -373,10 +389,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
...
...
@@ -393,10 +410,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
backend, \
layout, \
...
...
@@ -419,10 +437,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
layout, \
...
...
@@ -445,10 +464,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
backend, \
layout, \
...
...
@@ -471,10 +491,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
backend, \
layout, \
...
...
@@ -497,10 +518,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
backend, \
layout, \
...
...
@@ -523,10 +545,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
backend, \
layout, \
...
...
@@ -549,10 +572,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
backend, \
layout, \
...
...
@@ -575,10 +599,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
backend, \
layout, \
...
...
@@ -601,10 +626,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
backend, \
layout, \
...
...
@@ -627,10 +653,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
backend, \
layout, \
...
...
@@ -653,10 +680,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
backend, \
layout, \
...
...
@@ -679,10 +707,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
backend, \
layout, \
...
...
@@ -705,10 +734,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
backend, \
layout, \
...
...
@@ -731,10 +761,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
backend, \
layout, \
...
...
@@ -743,41 +774,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
/** PT_REGISTER_NO_TEMPLATE_KERNEL
*
* Basic Kernel register marco, used to register a no template argument kernel
* function, pass in the complete function pointe of the kernel, this
* registration macro will not do automatic template instantiation.
*
* Note: developer maybe register 2 kernel with same name, backend and diff
* layout, so the layout also need to be a part of symbol var name. If developer
* register 2 kernel with same name, backend, layout and diff dtype, he should
* use another register marco PT_REGISTER_KERNEL.
*
* TODO(chenweihang): remove this marco later
*/
#define PT_REGISTER_NO_TEMPLATE_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \
return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
/** PT_REGISTER_GENERAL_KERNEL
*
* Basic Kernel register marco, used to register a instantiated kernel function
...
...
@@ -832,558 +828,6 @@ struct KernelRegistrar {
::pten::Kernel* kernel)
#endif
/** PT_REGISTER_CTX_KERNEL
*
* Used for kernel registration with device context and data type as
* template parameter.
*/
#define PT_REGISTER_CTX_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_ctx_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_CTX_KERNEL must be called in global namespace."); \
_PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#else
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#endif
#define PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION2(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION2_, N) \
(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_REGISTRAR_INIT2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format off
/* The =pre-commit always treats this macro into the wrong format,
and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT2(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT2_, N) ( \
kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format on
/** _PT_KERNEL_REGISTRAR_INIT2_<N>
 *
 * Recursive family that defines one static ::pten::KernelRegistrar per dtype.
 * Each macro takes:
 *   kernel_name     - kernel identifier; stringified for the registrar and
 *                     pasted into the generated variable name.
 *   backend, layout - passed through BACKEND() / DATALAYOUT(); `backend` is
 *                     also pasted into ::pten::<backend>Context.
 *   registrar_id    - suffix that makes the registrar variable name unique.
 *   args_def_fn     - forwarded unchanged to the registrar constructor.
 *   meta_kernel_fn  - kernel function template, instantiated as
 *                     meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>.
 *   cpp_dtype, ...  - first dtype, followed by the remaining dtypes.
 *
 * _<N> registers `cpp_dtype` and forwards the remaining dtypes to _<N-1>
 * with a fresh PT_ID. _1 ends the recursion and additionally defines
 * TouchKernelSymbolFor_<kernel_name>_<backend>_<layout>().
 *
 * The registrar definition itself lives in _PT_KERNEL_REGISTRAR_INIT2_IMPL so
 * that it is written once instead of being repeated in every numbered macro.
 */

// Defines the registrar object for a single dtype. Not variadic on purpose:
// only the numbered macros below deal with the dtype list. The expansion
// ends with the `;` that terminates the variable definition.
#define _PT_KERNEL_REGISTRAR_INIT2_IMPL( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype) \
  static const ::pten::KernelRegistrar PT_CONCATENATE( \
      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, \
      registrar_id)( \
      #kernel_name, \
      BACKEND(backend), \
      DATALAYOUT(layout), \
      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
      ::pten::KernelArgsParseFunctor<decltype( \
          &meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
      args_def_fn, \
      PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
      PT_VARIADIC_KERNEL( \
          meta_kernel_fn<cpp_dtype, ::pten::backend##Context>));

// Base case: registers the last dtype and defines the per-kernel touch
// function (emitted exactly once per kernel_name/backend/layout because only
// this macro ends the recursion).
#define _PT_KERNEL_REGISTRAR_INIT2_1( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT2_2( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_1( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_3( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_2( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_4( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_3( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_5( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_4( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_6( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_5( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_7( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_6( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_8( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_7( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_9( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_8( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_10( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_9( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_11( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_10( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_12( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_11( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_13( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_12( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_14( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_13( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_15( \
    kernel_name, backend, layout, registrar_id, args_def_fn, meta_kernel_fn, \
    cpp_dtype, ...) \
  _PT_KERNEL_REGISTRAR_INIT2_IMPL(kernel_name, backend, layout, registrar_id, \
                                  args_def_fn, meta_kernel_fn, cpp_dtype) \
  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_14( \
      kernel_name, backend, layout, PT_ID, args_def_fn, meta_kernel_fn, \
      __VA_ARGS__))
/** PT_DECLARE_KERNEL
*
* Used to export the symbols of the file where the kernel is located,
...
...
paddle/pten/kernels/cpu/cast_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -58,7 +58,7 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
cast
,
PT_REGISTER_KERNEL
(
cast
,
CPU
,
ALL_LAYOUT
,
pten
::
CastKernel
,
...
...
paddle/pten/kernels/cpu/complex_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -21,7 +21,7 @@
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
CPU
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
...
...
paddle/pten/kernels/cpu/dot_grad_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -20,7 +20,7 @@
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
...
...
paddle/pten/kernels/cpu/dot_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -49,7 +49,7 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
CPU
,
ALL_LAYOUT
,
pten
::
DotKernel
,
...
...
paddle/pten/kernels/cpu/full_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
CPU
,
ALL_LAYOUT
,
pten
::
FullKernel
,
...
...
@@ -34,7 +34,7 @@ PT_REGISTER_CTX_KERNEL(full,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
CPU
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
...
...
paddle/pten/kernels/cpu/math_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -118,9 +118,9 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
CPU
,
ALL_LAYOUT
,
pten
::
AddKernel
,
...
...
@@ -130,7 +130,7 @@ PT_REGISTER_CTX_KERNEL(add,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
CPU
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
...
...
@@ -140,7 +140,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
CPU
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
...
...
@@ -150,7 +150,7 @@ PT_REGISTER_CTX_KERNEL(divide,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
CPU
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
...
...
@@ -161,7 +161,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
bool
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
CPU
,
ALL_LAYOUT
,
pten
::
SumKernel
,
...
...
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
...
...
@@ -28,7 +28,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
...
...
@@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
...
...
paddle/pten/kernels/cpu/matmul_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
...
...
paddle/pten/kernels/cpu/scale_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -51,7 +51,7 @@ void ScaleKernel(const Context& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
CPU
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
...
...
paddle/pten/kernels/cpu/sign_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -21,5 +21,4 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
PT_REGISTER_CTX_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{
}
PT_REGISTER_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{}
paddle/pten/kernels/empty_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -34,7 +34,7 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
CPU
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
...
...
@@ -50,7 +50,7 @@ PT_REGISTER_CTX_KERNEL(empty,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
CPU
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
...
...
@@ -67,7 +67,7 @@ PT_REGISTER_CTX_KERNEL(empty_like,
paddle
::
platform
::
complex
<
double
>
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
GPU
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
...
...
@@ -82,7 +82,7 @@ PT_REGISTER_CTX_KERNEL(empty,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
GPU
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
...
...
paddle/pten/kernels/flatten_grad_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -33,7 +33,7 @@ void FlattenGradKernel(const Context& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
...
...
@@ -45,7 +45,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
...
...
@@ -60,7 +60,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
#endif
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
XPU
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
...
...
paddle/pten/kernels/flatten_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -48,7 +48,7 @@ void FlattenWithXShape(const Context& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
CPU
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
...
...
@@ -59,7 +59,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int
,
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
CPU
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
...
...
@@ -71,7 +71,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
GPU
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
...
...
@@ -83,7 +83,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int
,
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
GPU
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
...
...
@@ -97,7 +97,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
#endif
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
XPU
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
...
...
@@ -107,7 +107,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int
,
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
XPU
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
...
...
paddle/pten/kernels/gpu/cast_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -61,7 +61,7 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_
CTX_
KERNEL(cast, \
PT_REGISTER_KERNEL(cast, \
GPU, \
ALL_LAYOUT, \
pten::CastKernel, \
...
...
paddle/pten/kernels/gpu/complex_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -21,7 +21,7 @@
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
GPU
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
...
...
paddle/pten/kernels/gpu/dot_grad_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
...
...
paddle/pten/kernels/gpu/dot_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -52,7 +52,7 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
GPU
,
ALL_LAYOUT
,
pten
::
DotKernel
,
...
...
paddle/pten/kernels/gpu/full_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
GPU
,
ALL_LAYOUT
,
pten
::
FullKernel
,
...
...
@@ -33,7 +33,7 @@ PT_REGISTER_CTX_KERNEL(full,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
GPU
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
...
...
paddle/pten/kernels/gpu/math_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -110,9 +110,9 @@ using float16 = paddle::platform::float16;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
GPU
,
ALL_LAYOUT
,
pten
::
AddKernel
,
...
...
@@ -123,7 +123,7 @@ PT_REGISTER_CTX_KERNEL(add,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
GPU
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
...
...
@@ -134,7 +134,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
GPU
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
...
...
@@ -145,7 +145,7 @@ PT_REGISTER_CTX_KERNEL(divide,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
GPU
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
...
...
@@ -157,7 +157,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
GPU
,
ALL_LAYOUT
,
pten
::
SumKernel
,
...
...
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
...
...
@@ -29,7 +29,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
...
...
@@ -39,7 +39,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
...
...
paddle/pten/kernels/gpu/matmul_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
...
...
paddle/pten/kernels/gpu/scale_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -64,7 +64,7 @@ void ScaleKernel(const ContextT& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
GPU
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
...
...
paddle/pten/kernels/gpu/sign_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -23,5 +23,5 @@ limitations under the License. */
using
float16
=
paddle
::
platform
::
float16
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录