Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
158bf13f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
158bf13f
编写于
1月 13, 2022
作者:
C
Chen Weihang
提交者:
GitHub
1月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PTen] Rename kernel register marco (#38861)
* rename register marco * fix error changing * fix format error
上级
dccdc719
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
636 addition
and
1193 deletion
+636
-1193
cmake/pten_kernel.cmake
cmake/pten_kernel.cmake
+3
-3
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+132
-688
paddle/pten/kernels/cpu/cast_kernel.cc
paddle/pten/kernels/cpu/cast_kernel.cc
+15
-15
paddle/pten/kernels/cpu/complex_kernel.cc
paddle/pten/kernels/cpu/complex_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_grad_kernel.cc
paddle/pten/kernels/cpu/dot_grad_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_kernel.cc
paddle/pten/kernels/cpu/dot_kernel.cc
+10
-10
paddle/pten/kernels/cpu/full_kernel.cc
paddle/pten/kernels/cpu/full_kernel.cc
+25
-25
paddle/pten/kernels/cpu/math_kernel.cc
paddle/pten/kernels/cpu/math_kernel.cc
+54
-54
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
+26
-26
paddle/pten/kernels/cpu/matmul_kernel.cc
paddle/pten/kernels/cpu/matmul_kernel.cc
+8
-8
paddle/pten/kernels/cpu/scale_kernel.cc
paddle/pten/kernels/cpu/scale_kernel.cc
+12
-12
paddle/pten/kernels/cpu/sign_kernel.cc
paddle/pten/kernels/cpu/sign_kernel.cc
+1
-2
paddle/pten/kernels/empty_kernel.cc
paddle/pten/kernels/empty_kernel.cc
+58
-58
paddle/pten/kernels/flatten_grad_kernel.cc
paddle/pten/kernels/flatten_grad_kernel.cc
+30
-30
paddle/pten/kernels/flatten_kernel.cc
paddle/pten/kernels/flatten_kernel.cc
+60
-60
paddle/pten/kernels/gpu/cast_kernel.cu
paddle/pten/kernels/gpu/cast_kernel.cu
+18
-18
paddle/pten/kernels/gpu/complex_kernel.cu
paddle/pten/kernels/gpu/complex_kernel.cu
+11
-11
paddle/pten/kernels/gpu/dot_grad_kernel.cu
paddle/pten/kernels/gpu/dot_grad_kernel.cu
+10
-10
paddle/pten/kernels/gpu/dot_kernel.cu
paddle/pten/kernels/gpu/dot_kernel.cu
+10
-10
paddle/pten/kernels/gpu/full_kernel.cu
paddle/pten/kernels/gpu/full_kernel.cu
+24
-24
paddle/pten/kernels/gpu/math_kernel.cu
paddle/pten/kernels/gpu/math_kernel.cu
+58
-58
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
+29
-29
paddle/pten/kernels/gpu/matmul_kernel.cu
paddle/pten/kernels/gpu/matmul_kernel.cu
+9
-9
paddle/pten/kernels/gpu/scale_kernel.cu
paddle/pten/kernels/gpu/scale_kernel.cu
+12
-12
paddle/pten/kernels/gpu/sign_kernel.cu
paddle/pten/kernels/gpu/sign_kernel.cu
+1
-1
未找到文件。
cmake/pten_kernel.cmake
浏览文件 @
158bf13f
...
...
@@ -16,12 +16,12 @@
function
(
kernel_declare TARGET_LIST
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
file
(
READ
${
kernel_path
}
kernel_impl
)
# TODO(chenweihang): rename PT_REGISTER_
CTX_
KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): PT_REGISTER_CTX_KERNEL has been renamed to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
string
(
REGEX MATCH
"(PT_REGISTER_
CTX_
KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
string
(
REGEX MATCH
"(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
if
(
NOT first_registry STREQUAL
""
)
# parse the first kernel name
string
(
REPLACE
"PT_REGISTER_
CTX_
KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
...
...
paddle/pten/core/kernel_registry.h
浏览文件 @
158bf13f
...
...
@@ -213,20 +213,20 @@ struct KernelRegistrar {
* pointer of the corresponding data type is automatically instantiated
* during registration.
*
* Note: `
1TA` means `1
template argument`
* Note: `
2TA` means `2
template argument`
*/
#define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \
_PT_REGISTER_
1
TA_KERNEL( \
_PT_REGISTER_
2
TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32
#define _PT_REGISTER_
1
TA_KERNEL( \
#define _PT_REGISTER_
2
TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn,
cpp_dtype, __VA_ARGS__);
\
PT_KERNEL_INSTANTIATION(meta_kernel_fn,
backend, cpp_dtype, __VA_ARGS__);
\
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
...
...
@@ -252,7 +252,7 @@ struct KernelRegistrar {
*
* And msvc can work without template instantiation
*/
#define _PT_REGISTER_
1
TA_KERNEL( \
#define _PT_REGISTER_
2
TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
...
...
@@ -268,60 +268,76 @@ struct KernelRegistrar {
::pten::Kernel* kernel)
#endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
cpp_dtype, \
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
__VA_ARGS__)
// Dispatcher for the one-template-argument instantiation chain: pastes N onto
// `_PT_KERNEL_INSTANTIATION_` to select the N-ary helper, then invokes it with
// the kernel function template and the dtype list.
// N is supplied by the caller (PT_NARGS of the dtype list in the wrapper macro).
// PT_CONCATENATE is defined outside this view.
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, cpp_dtype, __VA_ARGS__)
// Dispatcher for the backend-aware instantiation chain: pastes N onto
// `_PT_KERNEL_INSTANTIATION_` to select the N-ary helper, then invokes it with
// the kernel function template, the backend name and the dtype list.
// N is the number of dtypes being instantiated (the caller passes PT_NARGS of
// the dtype list). PT_CONCATENATE is defined outside this block.
// Every physical line except the last must end in a backslash so the whole
// definition stays one logical preprocessor line.
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \
  PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N)                               \
  (meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__)
// _PT_KERNEL_INSTANTIATION_1 .. _15 (one-template-argument form).
//
// Each helper emits an explicit instantiation of `meta_kernel_fn<cpp_dtype>`
// for the first dtype in its argument list, then hands the remaining dtypes
// (`__VA_ARGS__`) to the helper numbered one lower, so helper N expands to N
// instantiations in total.
// The terminal `_1` form emits no trailing semicolon; the call site is
// expected to supply it.
// PT_EXPAND is defined outside this view (presumably an extra expansion pass
// needed for variadic forwarding on some preprocessors; confirm).
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__))
// _PT_KERNEL_INSTANTIATION_1 .. _15 (dtype + device-context form).
//
// Each helper emits an explicit instantiation of
// `meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>` for the first dtype in
// its argument list, where the context type name is built by pasting
// `Context` onto the `backend` token. It then forwards `backend` and the
// remaining dtypes (`__VA_ARGS__`) to the helper numbered one lower, so
// helper N expands to N instantiations in total.
// The terminal `_1` form emits no trailing semicolon; the call site is
// expected to supply it.
// PT_EXPAND is defined outside this view.
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
...
...
@@ -373,10 +389,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
...
...
@@ -393,10 +410,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
backend, \
layout, \
...
...
@@ -419,10 +437,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
backend, \
layout, \
...
...
@@ -445,10 +464,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
backend, \
layout, \
...
...
@@ -471,10 +491,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
backend, \
layout, \
...
...
@@ -497,10 +518,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
backend, \
layout, \
...
...
@@ -523,10 +545,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
backend, \
layout, \
...
...
@@ -549,10 +572,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
backend, \
layout, \
...
...
@@ -575,10 +599,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
backend, \
layout, \
...
...
@@ -601,10 +626,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
backend, \
layout, \
...
...
@@ -627,10 +653,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
backend, \
layout, \
...
...
@@ -653,10 +680,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
backend, \
layout, \
...
...
@@ -679,10 +707,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
backend, \
layout, \
...
...
@@ -705,10 +734,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
backend, \
layout, \
...
...
@@ -731,10 +761,11 @@ struct KernelRegistrar {
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype
>)>::Parse,
\
&meta_kernel_fn<cpp_dtype
, ::pten::backend##Context>)>::Parse,
\
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
backend, \
layout, \
...
...
@@ -743,41 +774,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
/** PT_REGISTER_NO_TEMPLATE_KERNEL
*
* Basic kernel registration macro, used to register a kernel function that
* has no template arguments. Pass in the complete function pointer of the
* kernel; this registration macro does not perform any automatic template
* instantiation.
*
* Note: a developer may register two kernels with the same name and backend
* but different layouts, so the layout also needs to be part of the generated
* symbol names. To register two kernels with the same name, backend and
* layout but different dtypes, use the other registration macro,
* PT_REGISTER_KERNEL.
*
* Expansion, in order:
*  - a global-namespace check via PT_STATIC_ASSERT_GLOBAL_NAMESPACE;
*  - a forward declaration of the per-kernel args-def function;
*  - a static ::pten::KernelRegistrar built from the stringized kernel name,
*    BACKEND(backend), DATALAYOUT(layout), the parsed argument signature of
*    kernel_fn, the args-def function, and the PT_KERNEL / PT_VARIADIC_KERNEL
*    wrappers of kernel_fn;
*  - an `int TouchKernelSymbolFor_<name>_<backend>_<layout>()` function that
*    returns 0;
*  - the header of the args-def function definition. The macro ends there, so
*    the invocation must be followed by the function body `{ ... }`.
*
* NOTE(review): the `dtype` parameter is accepted but never referenced in the
* expansion.
*
* TODO(chenweihang): remove this macro later
*/
#define PT_REGISTER_NO_TEMPLATE_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \
return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
/** PT_REGISTER_GENERAL_KERNEL
*
* Basic Kernel register marco, used to register a instantiated kernel function
...
...
@@ -832,558 +828,6 @@ struct KernelRegistrar {
::pten::Kernel* kernel)
#endif
/** PT_REGISTER_CTX_KERNEL
*
* Used for kernel registration with device context and data type as
* template parameters.
*
* Parameters:
*  - kernel_name, backend, layout: tokens pasted into the generated symbol
*    names and forwarded to the registrar.
*  - meta_kernel_fn: the kernel function template.
*  - cpp_dtype, ...: one or more C++ data types to register the kernel for.
*
* The macro first checks that it is used in the global namespace (via
* PT_STATIC_ASSERT_GLOBAL_NAMESPACE) and then delegates everything else to
* _PT_REGISTER_2TA_KERNEL.
*/
#define PT_REGISTER_CTX_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_ctx_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_CTX_KERNEL must be called in global namespace."); \
_PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
// _PT_REGISTER_2TA_KERNEL: implementation of registration for kernels whose
// template takes two arguments (data type and device context).
//
// Both variants forward-declare the per-kernel args-def function, create the
// registrars through PT_KERNEL_REGISTRAR_INIT2, and end with the header of
// the args-def function definition, so the invocation must be followed by
// the function body `{ ... }`.
//
// Non-Windows variant: additionally emits explicit template instantiations
// for every listed dtype through PT_KERNEL_INSTANTIATION2.
#ifndef _WIN32
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#else
// Windows variant: identical except that no explicit template instantiation
// is emitted.
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#endif
// Entry point of the two-template-argument instantiation chain: counts the
// dtypes with PT_NARGS and passes that count, together with the kernel
// function template, backend and dtype list, to _PT_KERNEL_INSTANTIATION2.
// PT_NARGS is defined outside this view.
#define PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
__VA_ARGS__)
// Dispatcher: pastes N onto `_PT_KERNEL_INSTANTIATION2_` to select the N-ary
// helper and invokes it with the kernel function template, backend and dtype
// list. PT_CONCATENATE is defined outside this view.
#define _PT_KERNEL_INSTANTIATION2(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION2_, N) \
(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__)
// _PT_KERNEL_INSTANTIATION2_1 .. _15.
//
// Each helper emits an explicit instantiation of
// `meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>` for the first dtype in
// its argument list, where the context type name is built by pasting
// `Context` onto the `backend` token. It then forwards `backend` and the
// remaining dtypes (`__VA_ARGS__`) to the helper numbered one lower, so
// helper N expands to N instantiations in total.
// The terminal `_1` form emits no trailing semicolon; the call site is
// expected to supply it.
// PT_EXPAND is defined outside this view.
#define _PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, __VA_ARGS__))
// Entry point of the registrar-creation chain for two-template-argument
// kernels: counts the dtypes with PT_NARGS and forwards that count plus all
// arguments unchanged to _PT_KERNEL_REGISTRAR_INIT2.
// `args_def_fn` is the address of the per-kernel args-def function that is
// handed to each registrar. PT_NARGS is defined outside this view.
#define PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_REGISTRAR_INIT2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format off
/* pre-commit always reformats this macro incorrectly, and multi-line macros
   cannot be skipped with NOLINT, so clang-format is disabled around it.

   Dispatcher: pastes N onto `_PT_KERNEL_REGISTRAR_INIT2_` to select the
   N-ary helper, and inserts PT_ID as the `registrar_id` argument between
   `layout` and `args_def_fn`. PT_ID and PT_CONCATENATE are defined outside
   this view (PT_ID presumably yields a unique id per expansion; confirm). */
#define _PT_KERNEL_REGISTRAR_INIT2(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT2_, N) ( \
kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format on
// Terminal step of the registrar chain: defines one static
// ::pten::KernelRegistrar for `cpp_dtype`, named
// `__reg_pt_kernel_<name>_<backend>_<layout>_` with `registrar_id` appended.
// The registrar receives the stringized kernel name, BACKEND(backend),
// DATALAYOUT(layout), the DataType mapped from cpp_dtype, the parsed argument
// signature of `meta_kernel_fn<cpp_dtype, ::pten::<backend>Context>`,
// `args_def_fn`, and the PT_KERNEL / PT_VARIADIC_KERNEL wrappers of that
// instantiation.
// Only this terminal step also defines
// `int TouchKernelSymbolFor_<name>_<backend>_<layout>()`, so that function is
// emitted once per registration regardless of how many dtypes are listed.
#define _PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
// Arity-2 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_1 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-3 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_2 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-4 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_3 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-5 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_4 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-6 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_5 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-7 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_6 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-8 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_7 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-9 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_8 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-10 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_9 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-11 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_10 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-12 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_11 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-13 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_12 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-14 step of the kernel registrar chain.
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_13 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
// Arity-15 step of the kernel registrar chain (the largest arity defined in
// the visible part of this file).
// Defines one `static const ::pten::KernelRegistrar` for the leading
// `cpp_dtype` (variable name suffixed with `registrar_id`), then expands
// _PT_KERNEL_REGISTRAR_INIT2_14 on the remaining dtypes in __VA_ARGS__,
// passing PT_ID as the registrar id of the next step.
// NOTE(review): PT_ID and PT_EXPAND are defined outside this view; PT_ID
// presumably yields a distinct token per expansion -- confirm.
#define _PT_KERNEL_REGISTRAR_INIT2_15(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
/** PT_DECLARE_KERNEL
*
* Used to export the symbols of the file where the kernel is located,
...
...
paddle/pten/kernels/cpu/cast_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
cast
,
CPU
,
ALL_LAYOUT
,
pten
::
CastKernel
,
float
,
double
,
int
,
int64_t
,
int16_t
,
bool
,
uint8_t
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{
PT_REGISTER_KERNEL
(
cast
,
CPU
,
ALL_LAYOUT
,
pten
::
CastKernel
,
float
,
double
,
int
,
int64_t
,
int16_t
,
bool
,
uint8_t
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
paddle/pten/kernels/cpu/complex_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -21,13 +21,13 @@
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
CPU
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
,
float
,
double
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
conj
,
CPU
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
,
float
,
double
,
int
,
int64_t
)
{}
paddle/pten/kernels/cpu/dot_grad_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -20,13 +20,13 @@
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
float
,
double
,
int
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
dot_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
float
,
double
,
int
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/cpu/dot_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
CPU
,
ALL_LAYOUT
,
pten
::
DotKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
dot
,
CPU
,
ALL_LAYOUT
,
pten
::
DotKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
paddle/pten/kernels/cpu/full_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -18,29 +18,29 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
CPU
,
ALL_LAYOUT
,
pten
::
FullKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
full
,
CPU
,
ALL_LAYOUT
,
pten
::
FullKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
CPU
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
)
{}
PT_REGISTER_KERNEL
(
full_like
,
CPU
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
)
{}
paddle/pten/kernels/cpu/math_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
CPU
,
ALL_LAYOUT
,
pten
::
AddKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
CPU
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
CPU
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
CPU
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
CPU
,
ALL_LAYOUT
,
pten
::
SumKernel
,
bool
,
float
,
double
,
paddle
::
platform
::
float16
,
int
,
int64_t
,
complex64
,
complex128
)
{
PT_REGISTER_KERNEL
(
add
,
CPU
,
ALL_LAYOUT
,
pten
::
AddKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
subtract
,
CPU
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
divide
,
CPU
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
multiply
,
CPU
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
sum
,
CPU
,
ALL_LAYOUT
,
pten
::
SumKernel
,
bool
,
float
,
double
,
paddle
::
platform
::
float16
,
int
,
int64_t
,
complex64
,
complex128
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -19,29 +19,29 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
float
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
float
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
float
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
matmul_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
float
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
matmul_double_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
float
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
float
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/cpu/matmul_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -20,11 +20,11 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
float
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
matmul
,
CPU
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
float
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/cpu/scale_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
CPU
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
float
,
double
,
paddle
::
platform
::
bfloat16
,
uint8_t
,
int8_t
,
int16_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
scale
,
CPU
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
float
,
double
,
paddle
::
platform
::
bfloat16
,
uint8_t
,
int8_t
,
int16_t
,
int
,
int64_t
)
{}
paddle/pten/kernels/cpu/sign_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -21,5 +21,4 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
PT_REGISTER_CTX_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{
}
PT_REGISTER_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{}
paddle/pten/kernels/empty_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -34,66 +34,66 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
empty
,
CPU
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
empty
,
CPU
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
CPU
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
empty_like
,
CPU
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
empty
,
GPU
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
empty
,
GPU
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
GPU
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
empty_like
,
GPU
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
#endif
paddle/pten/kernels/flatten_grad_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -33,41 +33,41 @@ void FlattenGradKernel(const Context& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
float
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
flatten_grad
,
CPU
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
float
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
float
,
paddle
::
platform
::
float16
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
flatten_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
float
,
paddle
::
platform
::
float16
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
#endif
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
XPU
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
float
,
paddle
::
platform
::
float16
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
flatten_grad
,
XPU
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
float
,
paddle
::
platform
::
float16
,
int8_t
,
int
,
int64_t
)
{}
#endif
paddle/pten/kernels/flatten_kernel.cc
浏览文件 @
158bf13f
...
...
@@ -48,72 +48,72 @@ void FlattenWithXShape(const Context& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
CPU
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
float
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
flatten
,
CPU
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
float
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
CPU
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
float
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
CPU
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
float
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
GPU
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
float
,
paddle
::
platform
::
float16
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
flatten
,
GPU
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
float
,
paddle
::
platform
::
float16
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
GPU
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
float
,
paddle
::
platform
::
float16
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
GPU
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
float
,
paddle
::
platform
::
float16
,
double
,
uint8_t
,
int8_t
,
int
,
int64_t
)
{}
#endif
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
XPU
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
float
,
paddle
::
platform
::
float16
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
flatten
,
XPU
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
float
,
paddle
::
platform
::
float16
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
XPU
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
float
,
paddle
::
platform
::
float16
,
int8_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
XPU
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
float
,
paddle
::
platform
::
float16
,
int8_t
,
int
,
int64_t
)
{}
#endif
paddle/pten/kernels/gpu/cast_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -60,24 +60,24 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...)
\
PT_REGISTER_
CTX_
KERNEL(cast, \
GPU, \
ALL_LAYOUT, \
pten::CastKernel, \
float, \
double, \
int, \
int64_t, \
int16_t, \
bool, \
uint8_t, \
paddle::platform::float16, \
paddle::platform::complex<float>, \
paddle::platform::complex<double>, \
##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType(
\
paddle::experimental::DataType::UNDEFINED);
\
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL(cast, \
GPU, \
ALL_LAYOUT, \
pten::CastKernel, \
float, \
double, \
int, \
int64_t, \
int16_t, \
bool, \
uint8_t, \
paddle::platform::float16, \
paddle::platform::complex<float>, \
paddle::platform::complex<double>, \
##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType( \
paddle::experimental::DataType::UNDEFINED); \
}
#if !defined(PADDLE_WITH_HIP)
...
...
paddle/pten/kernels/gpu/complex_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -21,14 +21,14 @@
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
GPU
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
,
float
,
double
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
conj
,
GPU
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
,
float
,
double
,
int
,
int64_t
)
{}
paddle/pten/kernels/gpu/dot_grad_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -20,13 +20,13 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
float
,
double
,
int
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
dot_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
float
,
double
,
int
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/gpu/dot_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -52,13 +52,13 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
GPU
,
ALL_LAYOUT
,
pten
::
DotKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
dot
,
GPU
,
ALL_LAYOUT
,
pten
::
DotKernel
,
float
,
double
,
int
,
int64_t
,
complex64
,
complex128
)
{}
paddle/pten/kernels/gpu/full_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -18,28 +18,28 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
GPU
,
ALL_LAYOUT
,
pten
::
FullKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
full
,
GPU
,
ALL_LAYOUT
,
pten
::
FullKernel
,
float
,
double
,
uint8_t
,
int16_t
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
GPU
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
)
{}
PT_REGISTER_KERNEL
(
full_like
,
GPU
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
paddle
::
platform
::
float16
)
{}
paddle/pten/kernels/gpu/math_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -110,64 +110,64 @@ using float16 = paddle::platform::float16;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
GPU
,
ALL_LAYOUT
,
pten
::
AddKernel
,
float
,
double
,
int
,
int64_t
,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
GPU
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
float
,
double
,
int
,
int64_t
,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
GPU
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
float
,
double
,
int
,
int64_t
,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
GPU
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
GPU
,
ALL_LAYOUT
,
pten
::
SumKernel
,
bool
,
float
,
double
,
float16
,
int
,
int64_t
,
complex64
,
complex128
)
{
PT_REGISTER_KERNEL
(
add
,
GPU
,
ALL_LAYOUT
,
pten
::
AddKernel
,
float
,
double
,
int
,
int64_t
,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
subtract
,
GPU
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
float
,
double
,
int
,
int64_t
,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
divide
,
GPU
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
float
,
double
,
int
,
int64_t
,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
multiply
,
GPU
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
float
,
double
,
int
,
int64_t
,
bool
,
float16
,
complex64
,
complex128
)
{}
PT_REGISTER_KERNEL
(
sum
,
GPU
,
ALL_LAYOUT
,
pten
::
SumKernel
,
bool
,
float
,
double
,
float16
,
int
,
int64_t
,
complex64
,
complex128
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -19,32 +19,32 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
matmul_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
matmul_double_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/gpu/matmul_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -20,12 +20,12 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_KERNEL
(
matmul
,
GPU
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/gpu/scale_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -64,15 +64,15 @@ void ScaleKernel(const ContextT& dev_ctx,
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
GPU
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
uint8_t
,
int8_t
,
int16_t
,
int
,
int64_t
)
{}
PT_REGISTER_KERNEL
(
scale
,
GPU
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
float
,
double
,
paddle
::
platform
::
float16
,
uint8_t
,
int8_t
,
int16_t
,
int
,
int64_t
)
{}
paddle/pten/kernels/gpu/sign_kernel.cu
浏览文件 @
158bf13f
...
...
@@ -23,5 +23,5 @@ limitations under the License. */
using
float16
=
paddle
::
platform
::
float16
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录