未验证 提交 53b3f40f 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Fix detail bugs and append registry macro (#36866)


* fix several bugs

* fix elementwith override error
上级 8fb6e77b
...@@ -129,7 +129,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -129,7 +129,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor, const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const { const framework::OpKernelType &expected_kernel_type) const override {
if (framework::IsComplexType(expected_kernel_type.data_type_)) { if (framework::IsComplexType(expected_kernel_type.data_type_)) {
// only promote inputs’s types when contains complex input // only promote inputs’s types when contains complex input
return framework::OpKernelType(tensor.type(), tensor.place(), return framework::OpKernelType(tensor.type(), tensor.place(),
......
...@@ -28,7 +28,7 @@ uint32_t KernelKey::Hash::operator()(const KernelKey& key) const { ...@@ -28,7 +28,7 @@ uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
(static_cast<uint8_t>(key.layout()) << KernelKey::kBackendBitLength); (static_cast<uint8_t>(key.layout()) << KernelKey::kBackendBitLength);
hash_value |= hash_value |=
(static_cast<uint16_t>(key.dtype()) (static_cast<uint16_t>(key.dtype())
<< (KernelKey::kBackendBitLength + KernelKey::kDataTypeBitLength)); << (KernelKey::kBackendBitLength + KernelKey::kDataLayoutBitLength));
return hash_value; return hash_value;
} }
...@@ -60,7 +60,8 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( ...@@ -60,7 +60,8 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
auto kernel_iter = iter->second.find(kernel_key); auto kernel_iter = iter->second.find(kernel_key);
// TODO(chenweihang): polish refind impl here // TODO(chenweihang): polish refind impl here
if (kernel_key.layout() != pten::DataLayout::ANY) { if (kernel_iter == iter->second.end() &&
kernel_key.layout() != pten::DataLayout::ANY) {
pten::KernelKey any_layout_kernel_key( pten::KernelKey any_layout_kernel_key(
kernel_key.backend(), pten::DataLayout::ANY, kernel_key.dtype()); kernel_key.backend(), pten::DataLayout::ANY, kernel_key.dtype());
kernel_iter = iter->second.find(any_layout_kernel_key); kernel_iter = iter->second.find(any_layout_kernel_key);
......
...@@ -198,9 +198,11 @@ struct KernelRegistrar { ...@@ -198,9 +198,11 @@ struct KernelRegistrar {
*/ */
#define PT_NARGS(...) _PT_NARGS((__VA_ARGS__, _PT_RESQ_N())) #define PT_NARGS(...) _PT_NARGS((__VA_ARGS__, _PT_RESQ_N()))
#define _PT_NARGS(...) _PT_ARG_N(__VA_ARGS__) #define _PT_NARGS(...) _PT_ARG_N(__VA_ARGS__)
#define _PT_ARG_N_EXPAND(_1, _2, _3, _4, _5, _6, _7, _8, N, ...) N #define _PT_ARG_N_EXPAND( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \
N
#define _PT_ARG_N(args) _PT_ARG_N_EXPAND args #define _PT_ARG_N(args) _PT_ARG_N_EXPAND args
#define _PT_RESQ_N() 8, 7, 6, 5, 4, 3, 2, 1, 0 #define _PT_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
#define PT_REGISTER_KERNEL( \ #define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
...@@ -296,6 +298,27 @@ struct KernelRegistrar { ...@@ -296,6 +298,27 @@ struct KernelRegistrar {
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT(kernel_name, \ #define PT_KERNEL_REGISTRAR_INIT(kernel_name, \
func_id, \ func_id, \
...@@ -549,6 +572,195 @@ struct KernelRegistrar { ...@@ -549,6 +572,195 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define PT_REGISTER_KERNEL_STANDARD( \ #define PT_REGISTER_KERNEL_STANDARD( \
kernel_name, backend, layout, dtype, kernel_fn) \ kernel_name, backend, layout, dtype, kernel_fn) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册