Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
9fd67ffe
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9fd67ffe
编写于
2月 15, 2022
作者:
C
Chen Weihang
提交者:
GitHub
2月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PTen] Fix single dtype register errror (#39506)
* fix single dtype reg errror * fix windows failed
上级
b81358d1
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
44 addition
and
54 deletion
+44
-54
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+44
-54
未找到文件。
paddle/pten/core/kernel_registry.h
浏览文件 @
9fd67ffe
...
@@ -219,18 +219,17 @@ struct KernelRegistrar {
...
@@ -219,18 +219,17 @@ struct KernelRegistrar {
*
*
* Note: `2TA` means `2 template argument`
* Note: `2TA` means `2 template argument`
*/
*/
#define PT_REGISTER_KERNEL( \
#define PT_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \
"PT_REGISTER_KERNEL must be called in global namespace."); \
PT_EXPAND(_PT_REGISTER_2TA_KERNEL( \
_PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, __VA_ARGS__))
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32
#ifndef _WIN32
#define _PT_REGISTER_2TA_KERNEL( \
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn,
cpp_dtype, ...)
\
kernel_name, backend, layout, meta_kernel_fn,
...)
\
PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend,
cpp_dtype, __VA_ARGS__);
\
PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend,
__VA_ARGS__);
\
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
PT_KERNEL_REGISTRAR_INIT( \
...
@@ -239,7 +238,6 @@ struct KernelRegistrar {
...
@@ -239,7 +238,6 @@ struct KernelRegistrar {
layout, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
::pten::Kernel* kernel)
...
@@ -257,34 +255,30 @@ struct KernelRegistrar {
...
@@ -257,34 +255,30 @@ struct KernelRegistrar {
* And msvc can work without template instantiation
* And msvc can work without template instantiation
*/
*/
#define _PT_REGISTER_2TA_KERNEL( \
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn,
cpp_dtype, ...)
\
kernel_name, backend, layout, meta_kernel_fn,
...)
\
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
::pten::Kernel*); \
PT_
KERNEL_REGISTRAR_INIT(
\
PT_
EXPAND(PT_KERNEL_REGISTRAR_INIT(
\
kernel_name, \
kernel_name, \
backend, \
backend, \
layout, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)); \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
::pten::Kernel* kernel)
#endif
#endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
_PT_KERNEL_INSTANTIATION( \
meta_kernel_fn, \
PT_NARGS(__VA_ARGS__), meta_kernel_fn, backend, __VA_ARGS__)
backend, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend,
cpp_dtype,
...) \
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N)
\
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, backend,
cpp_dtype,
__VA_ARGS__)
(meta_kernel_fn, backend, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype
, ...)
\
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype
)
\
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)
\
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
...
@@ -343,38 +337,35 @@ struct KernelRegistrar {
...
@@ -343,38 +337,35 @@ struct KernelRegistrar {
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT( \
#define PT_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, ...) \
_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(__VA_ARGS__), \
kernel_name, \
kernel_name, \
backend, \
backend, \
layout, \
layout, \
args_def_fn, \
args_def_fn, \
meta_kernel_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__))
__VA_ARGS__)
// clang-format off
// clang-format off
/* The =pre-commit always treats this macro into the wrong format,
/* The =pre-commit always treats this macro into the wrong format,
and multi-line macros cannot be skipped with NOLINT.*/
and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT(N, \
#define _PT_KERNEL_REGISTRAR_INIT(N, \
kernel_name, \
kernel_name, \
backend, \
backend, \
layout, \
layout, \
args_def_fn, \
args_def_fn, \
meta_kernel_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
...) \
PT_EXPAND(PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \
kernel_name, \
backend, \
backend, \
layout, \
layout, \
PT_ID, \
PT_ID, \
args_def_fn, \
args_def_fn, \
meta_kernel_fn, \
meta_kernel_fn, \
__VA_ARGS__))
cpp_dtype, \
__VA_ARGS__)
// clang-format on
// clang-format on
...
@@ -384,8 +375,7 @@ struct KernelRegistrar {
...
@@ -384,8 +375,7 @@ struct KernelRegistrar {
registrar_id, \
registrar_id, \
args_def_fn, \
args_def_fn, \
meta_kernel_fn, \
meta_kernel_fn, \
cpp_dtype, \
cpp_dtype) \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
#kernel_name, \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录