Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
53b3f40f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
53b3f40f
编写于
11月 02, 2021
作者:
C
Chen Weihang
提交者:
GitHub
11月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PTen] Fix detail bugs and append registry macro (#36866)
* fix several bugs * fix elementwith override error
上级
8fb6e77b
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
218 addition
and
5 deletion
+218
-5
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+1
-1
paddle/pten/core/kernel_factory.cc
paddle/pten/core/kernel_factory.cc
+3
-2
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+214
-2
未找到文件。
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
53b3f40f
...
@@ -129,7 +129,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
...
@@ -129,7 +129,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetKernelTypeForVar
(
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
{
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
if
(
framework
::
IsComplexType
(
expected_kernel_type
.
data_type_
))
{
if
(
framework
::
IsComplexType
(
expected_kernel_type
.
data_type_
))
{
// only promote inputs’s types when contains complex input
// only promote inputs’s types when contains complex input
return
framework
::
OpKernelType
(
tensor
.
type
(),
tensor
.
place
(),
return
framework
::
OpKernelType
(
tensor
.
type
(),
tensor
.
place
(),
...
...
paddle/pten/core/kernel_factory.cc
浏览文件 @
53b3f40f
...
@@ -28,7 +28,7 @@ uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
...
@@ -28,7 +28,7 @@ uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
(
static_cast
<
uint8_t
>
(
key
.
layout
())
<<
KernelKey
::
kBackendBitLength
);
(
static_cast
<
uint8_t
>
(
key
.
layout
())
<<
KernelKey
::
kBackendBitLength
);
hash_value
|=
hash_value
|=
(
static_cast
<
uint16_t
>
(
key
.
dtype
())
(
static_cast
<
uint16_t
>
(
key
.
dtype
())
<<
(
KernelKey
::
kBackendBitLength
+
KernelKey
::
kData
Type
BitLength
));
<<
(
KernelKey
::
kBackendBitLength
+
KernelKey
::
kData
Layout
BitLength
));
return
hash_value
;
return
hash_value
;
}
}
...
@@ -60,7 +60,8 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
...
@@ -60,7 +60,8 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
auto
kernel_iter
=
iter
->
second
.
find
(
kernel_key
);
auto
kernel_iter
=
iter
->
second
.
find
(
kernel_key
);
// TODO(chenweihang): polish refind impl here
// TODO(chenweihang): polish refind impl here
if
(
kernel_key
.
layout
()
!=
pten
::
DataLayout
::
ANY
)
{
if
(
kernel_iter
==
iter
->
second
.
end
()
&&
kernel_key
.
layout
()
!=
pten
::
DataLayout
::
ANY
)
{
pten
::
KernelKey
any_layout_kernel_key
(
pten
::
KernelKey
any_layout_kernel_key
(
kernel_key
.
backend
(),
pten
::
DataLayout
::
ANY
,
kernel_key
.
dtype
());
kernel_key
.
backend
(),
pten
::
DataLayout
::
ANY
,
kernel_key
.
dtype
());
kernel_iter
=
iter
->
second
.
find
(
any_layout_kernel_key
);
kernel_iter
=
iter
->
second
.
find
(
any_layout_kernel_key
);
...
...
paddle/pten/core/kernel_registry.h
浏览文件 @
53b3f40f
...
@@ -198,9 +198,11 @@ struct KernelRegistrar {
...
@@ -198,9 +198,11 @@ struct KernelRegistrar {
*/
*/
#define PT_NARGS(...) _PT_NARGS((__VA_ARGS__, _PT_RESQ_N()))
#define PT_NARGS(...) _PT_NARGS((__VA_ARGS__, _PT_RESQ_N()))
#define _PT_NARGS(...) _PT_ARG_N(__VA_ARGS__)
#define _PT_NARGS(...) _PT_ARG_N(__VA_ARGS__)
#define _PT_ARG_N_EXPAND(_1, _2, _3, _4, _5, _6, _7, _8, N, ...) N
#define _PT_ARG_N_EXPAND( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \
N
#define _PT_ARG_N(args) _PT_ARG_N_EXPAND args
#define _PT_ARG_N(args) _PT_ARG_N_EXPAND args
#define _PT_RESQ_N() 8, 7, 6, 5, 4, 3, 2, 1, 0
#define _PT_RESQ_N()
15, 14, 13, 12, 11, 10, 9,
8, 7, 6, 5, 4, 3, 2, 1, 0
#define PT_REGISTER_KERNEL( \
#define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
...
@@ -296,6 +298,27 @@ struct KernelRegistrar {
...
@@ -296,6 +298,27 @@ struct KernelRegistrar {
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT(kernel_name, \
#define PT_KERNEL_REGISTRAR_INIT(kernel_name, \
func_id, \
func_id, \
...
@@ -549,6 +572,195 @@ struct KernelRegistrar {
...
@@ -549,6 +572,195 @@ struct KernelRegistrar {
args_def_fn, \
args_def_fn, \
meta_kernel_fn, \
meta_kernel_fn, \
__VA_ARGS__))
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define PT_REGISTER_KERNEL_STANDARD( \
#define PT_REGISTER_KERNEL_STANDARD( \
kernel_name, backend, layout, dtype, kernel_fn) \
kernel_name, backend, layout, dtype, kernel_fn) \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录