Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
158bf13f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
158bf13f
编写于
1月 13, 2022
作者:
C
Chen Weihang
提交者:
GitHub
1月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PTen] Rename kernel register marco (#38861)
* rename register marco * fix error changing * fix format error
上级
dccdc719
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
636 addition
and
1193 deletion
+636
-1193
cmake/pten_kernel.cmake
cmake/pten_kernel.cmake
+3
-3
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+132
-688
paddle/pten/kernels/cpu/cast_kernel.cc
paddle/pten/kernels/cpu/cast_kernel.cc
+15
-15
paddle/pten/kernels/cpu/complex_kernel.cc
paddle/pten/kernels/cpu/complex_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_grad_kernel.cc
paddle/pten/kernels/cpu/dot_grad_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_kernel.cc
paddle/pten/kernels/cpu/dot_kernel.cc
+10
-10
paddle/pten/kernels/cpu/full_kernel.cc
paddle/pten/kernels/cpu/full_kernel.cc
+25
-25
paddle/pten/kernels/cpu/math_kernel.cc
paddle/pten/kernels/cpu/math_kernel.cc
+54
-54
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
+26
-26
paddle/pten/kernels/cpu/matmul_kernel.cc
paddle/pten/kernels/cpu/matmul_kernel.cc
+8
-8
paddle/pten/kernels/cpu/scale_kernel.cc
paddle/pten/kernels/cpu/scale_kernel.cc
+12
-12
paddle/pten/kernels/cpu/sign_kernel.cc
paddle/pten/kernels/cpu/sign_kernel.cc
+1
-2
paddle/pten/kernels/empty_kernel.cc
paddle/pten/kernels/empty_kernel.cc
+58
-58
paddle/pten/kernels/flatten_grad_kernel.cc
paddle/pten/kernels/flatten_grad_kernel.cc
+30
-30
paddle/pten/kernels/flatten_kernel.cc
paddle/pten/kernels/flatten_kernel.cc
+60
-60
paddle/pten/kernels/gpu/cast_kernel.cu
paddle/pten/kernels/gpu/cast_kernel.cu
+18
-18
paddle/pten/kernels/gpu/complex_kernel.cu
paddle/pten/kernels/gpu/complex_kernel.cu
+11
-11
paddle/pten/kernels/gpu/dot_grad_kernel.cu
paddle/pten/kernels/gpu/dot_grad_kernel.cu
+10
-10
paddle/pten/kernels/gpu/dot_kernel.cu
paddle/pten/kernels/gpu/dot_kernel.cu
+10
-10
paddle/pten/kernels/gpu/full_kernel.cu
paddle/pten/kernels/gpu/full_kernel.cu
+24
-24
paddle/pten/kernels/gpu/math_kernel.cu
paddle/pten/kernels/gpu/math_kernel.cu
+58
-58
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
+29
-29
paddle/pten/kernels/gpu/matmul_kernel.cu
paddle/pten/kernels/gpu/matmul_kernel.cu
+9
-9
paddle/pten/kernels/gpu/scale_kernel.cu
paddle/pten/kernels/gpu/scale_kernel.cu
+12
-12
paddle/pten/kernels/gpu/sign_kernel.cu
paddle/pten/kernels/gpu/sign_kernel.cu
+1
-1
未找到文件。
cmake/pten_kernel.cmake
浏览文件 @
158bf13f
...
@@ -16,12 +16,12 @@
...
@@ -16,12 +16,12 @@
function
(
kernel_declare TARGET_LIST
)
function
(
kernel_declare TARGET_LIST
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
file
(
READ
${
kernel_path
}
kernel_impl
)
file
(
READ
${
kernel_path
}
kernel_impl
)
# TODO(chenweihang): rename PT_REGISTER_
CTX_
KERNEL to PT_REGISTER_KERNEL
# TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
string
(
REGEX MATCH
"(PT_REGISTER_
CTX_
KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
string
(
REGEX MATCH
"(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
if
(
NOT first_registry STREQUAL
""
)
if
(
NOT first_registry STREQUAL
""
)
# parse the first kernel name
# parse the first kernel name
string
(
REPLACE
"PT_REGISTER_
CTX_
KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
...
...
paddle/pten/core/kernel_registry.h
浏览文件 @
158bf13f
...
@@ -213,20 +213,20 @@ struct KernelRegistrar {
...
@@ -213,20 +213,20 @@ struct KernelRegistrar {
* pointer of the corresponding data type is automatically instantiated
* pointer of the corresponding data type is automatically instantiated
* during registration.
* during registration.
*
*
* Note: `
1TA` means `1
template argument`
* Note: `
2TA` means `2
template argument`
*/
*/
#define PT_REGISTER_KERNEL( \
#define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \
"PT_REGISTER_KERNEL must be called in global namespace."); \
_PT_REGISTER_
1
TA_KERNEL( \
_PT_REGISTER_
2
TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32
#ifndef _WIN32
#define _PT_REGISTER_
1
TA_KERNEL( \
#define _PT_REGISTER_
2
TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn,
cpp_dtype, __VA_ARGS__);
\
PT_KERNEL_INSTANTIATION(meta_kernel_fn,
backend, cpp_dtype, __VA_ARGS__);
\
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
PT_KERNEL_REGISTRAR_INIT( \
...
@@ -252,7 +252,7 @@ struct KernelRegistrar {
...
@@ -252,7 +252,7 @@ struct KernelRegistrar {
*
*
* And msvc can work without template instantiation
* And msvc can work without template instantiation
*/
*/
#define _PT_REGISTER_
1
TA_KERNEL( \
#define _PT_REGISTER_
2
TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
::pten::Kernel*); \
...
@@ -268,60 +268,76 @@ struct KernelRegistrar {
...
@@ -268,60 +268,76 @@ struct KernelRegistrar {
::pten::Kernel* kernel)
::pten::Kernel* kernel)
#endif
#endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
meta_kernel_fn, \
cpp_dtype, \
backend, \
cpp_dtype, \
__VA_ARGS__)
__VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn,
backend,
cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N)
\
(meta_kernel_fn, cpp_dtype, __VA_ARGS__)
(meta_kernel_fn,
backend,
cpp_dtype, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, cpp_dtype, ...) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, cpp_dtype, ...) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, __VA_ARGS__))
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, cpp_dtype, ...) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, __VA_ARGS__))
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, cpp_dtype, ...) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, __VA_ARGS__))
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, cpp_dtype, ...) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT( \
#define PT_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
...
@@ -373,10 +389,11 @@ struct KernelRegistrar {
...
@@ -373,10 +389,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
backend, \
...
@@ -393,10 +410,11 @@ struct KernelRegistrar {
...
@@ -393,10 +410,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -419,10 +437,11 @@ struct KernelRegistrar {
...
@@ -419,10 +437,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -445,10 +464,11 @@ struct KernelRegistrar {
...
@@ -445,10 +464,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -471,10 +491,11 @@ struct KernelRegistrar {
...
@@ -471,10 +491,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -497,10 +518,11 @@ struct KernelRegistrar {
...
@@ -497,10 +518,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -523,10 +545,11 @@ struct KernelRegistrar {
...
@@ -523,10 +545,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -549,10 +572,11 @@ struct KernelRegistrar {
...
@@ -549,10 +572,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -575,10 +599,11 @@ struct KernelRegistrar {
...
@@ -575,10 +599,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -601,10 +626,11 @@ struct KernelRegistrar {
...
@@ -601,10 +626,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -627,10 +653,11 @@ struct KernelRegistrar {
...
@@ -627,10 +653,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -653,10 +680,11 @@ struct KernelRegistrar {
...
@@ -653,10 +680,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -679,10 +707,11 @@ struct KernelRegistrar {
...
@@ -679,10 +707,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -705,10 +734,11 @@ struct KernelRegistrar {
...
@@ -705,10 +734,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -731,10 +761,11 @@ struct KernelRegistrar {
...
@@ -731,10 +761,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
backend, \
backend, \
layout, \
layout, \
...
@@ -743,41 +774,6 @@ struct KernelRegistrar {
...
@@ -743,41 +774,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
meta_kernel_fn, \
__VA_ARGS__))
__VA_ARGS__))
/** PT_REGISTER_NO_TEMPLATE_KERNEL
*
* Basic Kernel register marco, used to register a no template argument kernel
* function, pass in the complete function pointe of the kernel, this
* registration macro will not do automatic template instantiation.
*
* Note: developer maybe register 2 kernel with same name, backend and diff
* layout, so the layout also need to be a part of symbol var name. If developer
* register 2 kernel with same name, backend, layout and diff dtype, he should
* use another register marco PT_REGISTER_KERNEL.
*
* TODO(chenweihang): remove this marco later
*/
#define PT_REGISTER_NO_TEMPLATE_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \
return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
/** PT_REGISTER_GENERAL_KERNEL
/** PT_REGISTER_GENERAL_KERNEL
*
*
* Basic Kernel register marco, used to register a instantiated kernel function
* Basic Kernel register marco, used to register a instantiated kernel function
...
@@ -832,558 +828,6 @@ struct KernelRegistrar {
...
@@ -832,558 +828,6 @@ struct KernelRegistrar {
::pten::Kernel* kernel)
::pten::Kernel* kernel)
#endif
#endif
/** PT_REGISTER_CTX_KERNEL
*
* Used for kernel registration with device context and data type as
* template parameter.
*/
#define PT_REGISTER_CTX_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_ctx_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_CTX_KERNEL must be called in global namespace."); \
_PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#else
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#endif
#define PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION2(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION2_, N) \
(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_REGISTRAR_INIT2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format off
/* The =pre-commit always treats this macro into the wrong format,
and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT2(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT2_, N) ( \
kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format on
#define _PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_15(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
/** PT_DECLARE_KERNEL
/** PT_DECLARE_KERNEL
*
*
* Used to export the symbols of the file where the kernel is located,
* Used to export the symbols of the file where the kernel is located,
...
...
paddle/pten/kernels/cpu/cast_kernel.cc
浏览文件 @
158bf13f
...
@@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx,
...
@@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
cast
,
PT_REGISTER_KERNEL
(
cast
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
CastKernel
,
pten
::
CastKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
int16_t
,
int16_t
,
bool
,
bool
,
uint8_t
,
uint8_t
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{
paddle
::
platform
::
complex
<
double
>
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
}
paddle/pten/kernels/cpu/complex_kernel.cc
浏览文件 @
158bf13f
...
@@ -21,13 +21,13 @@
...
@@ -21,13 +21,13 @@
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
pten
::
ConjKernel
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
,
paddle
::
platform
::
complex
<
double
>
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
paddle/pten/kernels/cpu/dot_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -20,13 +20,13 @@
...
@@ -20,13 +20,13 @@
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
pten
::
DotGradKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/cpu/dot_kernel.cc
浏览文件 @
158bf13f
...
@@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx,
...
@@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotKernel
,
pten
::
DotKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
paddle/pten/kernels/cpu/full_kernel.cc
浏览文件 @
158bf13f
...
@@ -18,29 +18,29 @@ limitations under the License. */
...
@@ -18,29 +18,29 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullKernel
,
pten
::
FullKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
pten
::
FullLikeKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
)
{}
paddle
::
platform
::
float16
)
{}
paddle/pten/kernels/cpu/math_kernel.cc
浏览文件 @
158bf13f
...
@@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex<double>;
...
@@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
AddKernel
,
pten
::
AddKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
pten
::
SubtractKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
pten
::
DivideKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
pten
::
MultiplyKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SumKernel
,
pten
::
SumKernel
,
bool
,
bool
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{
complex128
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
}
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -19,29 +19,29 @@ limitations under the License. */
...
@@ -19,29 +19,29 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
pten
::
MatmulGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
pten
::
MatmulDoubleGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
pten
::
MatmulTripleGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/cpu/matmul_kernel.cc
浏览文件 @
158bf13f
...
@@ -20,11 +20,11 @@ limitations under the License. */
...
@@ -20,11 +20,11 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
pten
::
MatmulKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/cpu/scale_kernel.cc
浏览文件 @
158bf13f
...
@@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx,
...
@@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
pten
::
ScaleKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
paddle/pten/kernels/cpu/sign_kernel.cc
浏览文件 @
158bf13f
...
@@ -21,5 +21,4 @@ limitations under the License. */
...
@@ -21,5 +21,4 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/bfloat16.h"
PT_REGISTER_CTX_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{
PT_REGISTER_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{}
}
paddle/pten/kernels/empty_kernel.cc
浏览文件 @
158bf13f
...
@@ -34,66 +34,66 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
...
@@ -34,66 +34,66 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
pten
::
EmptyKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
pten
::
EmptyLikeKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
pten
::
EmptyKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
pten
::
EmptyLikeKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
#endif
#endif
paddle/pten/kernels/flatten_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -33,41 +33,41 @@ void FlattenGradKernel(const Context& dev_ctx,
...
@@ -33,41 +33,41 @@ void FlattenGradKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#endif
#endif
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#endif
#endif
paddle/pten/kernels/flatten_kernel.cc
浏览文件 @
158bf13f
...
@@ -48,72 +48,72 @@ void FlattenWithXShape(const Context& dev_ctx,
...
@@ -48,72 +48,72 @@ void FlattenWithXShape(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#endif
#endif
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#endif
#endif
paddle/pten/kernels/gpu/cast_kernel.cu
浏览文件 @
158bf13f
...
@@ -60,24 +60,24 @@ void CastKernel(const Context& dev_ctx,
...
@@ -60,24 +60,24 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...)
\
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_
CTX_
KERNEL(cast, \
PT_REGISTER_KERNEL(cast, \
GPU, \
GPU, \
ALL_LAYOUT, \
ALL_LAYOUT, \
pten::CastKernel, \
pten::CastKernel, \
float, \
float, \
double, \
double, \
int, \
int, \
int64_t, \
int64_t, \
int16_t, \
int16_t, \
bool, \
bool, \
uint8_t, \
uint8_t, \
paddle::platform::float16, \
paddle::platform::float16, \
paddle::platform::complex<float>, \
paddle::platform::complex<float>, \
paddle::platform::complex<double>, \
paddle::platform::complex<double>, \
##__VA_ARGS__) { \
##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType(
\
kernel->OutputAt(0).SetDataType( \
paddle::experimental::DataType::UNDEFINED);
\
paddle::experimental::DataType::UNDEFINED); \
}
}
#if !defined(PADDLE_WITH_HIP)
#if !defined(PADDLE_WITH_HIP)
...
...
paddle/pten/kernels/gpu/complex_kernel.cu
浏览文件 @
158bf13f
...
@@ -21,14 +21,14 @@
...
@@ -21,14 +21,14 @@
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
pten
::
ConjKernel
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
,
paddle
::
platform
::
complex
<
double
>
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
paddle/pten/kernels/gpu/dot_grad_kernel.cu
浏览文件 @
158bf13f
...
@@ -20,13 +20,13 @@ limitations under the License. */
...
@@ -20,13 +20,13 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
pten
::
DotGradKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/gpu/dot_kernel.cu
浏览文件 @
158bf13f
...
@@ -52,13 +52,13 @@ void DotKernel(const Context& dev_ctx,
...
@@ -52,13 +52,13 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotKernel
,
pten
::
DotKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
paddle/pten/kernels/gpu/full_kernel.cu
浏览文件 @
158bf13f
...
@@ -18,28 +18,28 @@ limitations under the License. */
...
@@ -18,28 +18,28 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullKernel
,
pten
::
FullKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
pten
::
FullLikeKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
)
{}
paddle
::
platform
::
float16
)
{}
paddle/pten/kernels/gpu/math_kernel.cu
浏览文件 @
158bf13f
...
@@ -110,64 +110,64 @@ using float16 = paddle::platform::float16;
...
@@ -110,64 +110,64 @@ using float16 = paddle::platform::float16;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
AddKernel
,
pten
::
AddKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
pten
::
SubtractKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
pten
::
DivideKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
pten
::
MultiplyKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SumKernel
,
pten
::
SumKernel
,
bool
,
bool
,
float
,
float
,
double
,
double
,
float16
,
float16
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{
complex128
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
}
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
浏览文件 @
158bf13f
...
@@ -19,32 +19,32 @@ limitations under the License. */
...
@@ -19,32 +19,32 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
pten
::
MatmulGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
pten
::
MatmulDoubleGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
pten
::
MatmulTripleGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/gpu/matmul_kernel.cu
浏览文件 @
158bf13f
...
@@ -20,12 +20,12 @@ limitations under the License. */
...
@@ -20,12 +20,12 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
pten
::
MatmulKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/gpu/scale_kernel.cu
浏览文件 @
158bf13f
...
@@ -64,15 +64,15 @@ void ScaleKernel(const ContextT& dev_ctx,
...
@@ -64,15 +64,15 @@ void ScaleKernel(const ContextT& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
pten
::
ScaleKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
paddle/pten/kernels/gpu/sign_kernel.cu
浏览文件 @
158bf13f
...
@@ -23,5 +23,5 @@ limitations under the License. */
...
@@ -23,5 +23,5 @@ limitations under the License. */
using
float16
=
paddle
::
platform
::
float16
;
using
float16
=
paddle
::
platform
::
float16
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录