From 53b3f40fe758600b0213d9c6e8ff484518c5e488 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 2 Nov 2021 11:10:15 +0800 Subject: [PATCH] [PTen] Fix detail bugs and append registry macro (#36866) * fix several bugs * fix elementwith override error --- .../operators/elementwise/elementwise_op.h | 2 +- paddle/pten/core/kernel_factory.cc | 5 +- paddle/pten/core/kernel_registry.h | 216 +++++++++++++++++- 3 files changed, 218 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 13e4624ef7..651f0e3dc8 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -129,7 +129,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, - const framework::OpKernelType &expected_kernel_type) const { + const framework::OpKernelType &expected_kernel_type) const override { if (framework::IsComplexType(expected_kernel_type.data_type_)) { // only promote inputs’s types when contains complex input return framework::OpKernelType(tensor.type(), tensor.place(), diff --git a/paddle/pten/core/kernel_factory.cc b/paddle/pten/core/kernel_factory.cc index 729f137c08..aeefb7cfef 100644 --- a/paddle/pten/core/kernel_factory.cc +++ b/paddle/pten/core/kernel_factory.cc @@ -28,7 +28,7 @@ uint32_t KernelKey::Hash::operator()(const KernelKey& key) const { (static_cast(key.layout()) << KernelKey::kBackendBitLength); hash_value |= (static_cast(key.dtype()) - << (KernelKey::kBackendBitLength + KernelKey::kDataTypeBitLength)); + << (KernelKey::kBackendBitLength + KernelKey::kDataLayoutBitLength)); return hash_value; } @@ -60,7 +60,8 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( auto kernel_iter = iter->second.find(kernel_key); // TODO(chenweihang): polish refind impl here - if (kernel_key.layout() != pten::DataLayout::ANY) { + if (kernel_iter == iter->second.end() && + kernel_key.layout() != pten::DataLayout::ANY) { pten::KernelKey any_layout_kernel_key( kernel_key.backend(), pten::DataLayout::ANY, kernel_key.dtype()); kernel_iter = iter->second.find(any_layout_kernel_key); diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index adfe0d98b6..d3422d173a 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -198,9 +198,11 @@ struct KernelRegistrar { */ #define PT_NARGS(...) _PT_NARGS((__VA_ARGS__, _PT_RESQ_N())) #define _PT_NARGS(...) _PT_ARG_N(__VA_ARGS__) -#define _PT_ARG_N_EXPAND(_1, _2, _3, _4, _5, _6, _7, _8, N, ...) N +#define _PT_ARG_N_EXPAND( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \ + N #define _PT_ARG_N(args) _PT_ARG_N_EXPAND args -#define _PT_RESQ_N() 8, 7, 6, 5, 4, 3, 2, 1, 0 +#define _PT_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 #define PT_REGISTER_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ @@ -296,6 +298,27 @@ struct KernelRegistrar { #define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \ template decltype(meta_kernel_fn) meta_kernel_fn; \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__)) #define PT_KERNEL_REGISTRAR_INIT(kernel_name, \ func_id, \ @@ -549,6 +572,195 @@ struct KernelRegistrar { args_def_fn, \ meta_kernel_fn, \ __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ + func_id, \ + registrar_id, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_op_kernel_##func_id##_, registrar_id)( \ + kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ + func_id, \ + PT_ID, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ + func_id, \ + registrar_id, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_op_kernel_##func_id##_, registrar_id)( \ + kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ + func_id, \ + PT_ID, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ + func_id, \ + registrar_id, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_op_kernel_##func_id##_, registrar_id)( \ + kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ + func_id, \ + PT_ID, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ + func_id, \ + registrar_id, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_op_kernel_##func_id##_, registrar_id)( \ + kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ + func_id, \ + PT_ID, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ + func_id, \ + registrar_id, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_op_kernel_##func_id##_, registrar_id)( \ + kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ + func_id, \ + PT_ID, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ + func_id, \ + registrar_id, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_op_kernel_##func_id##_, registrar_id)( \ + kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ + func_id, \ + PT_ID, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \ + func_id, \ + registrar_id, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_op_kernel_##func_id##_, registrar_id)( \ + kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ + func_id, \ + PT_ID, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) #define PT_REGISTER_KERNEL_STANDARD( \ kernel_name, backend, layout, dtype, kernel_fn) \ -- GitLab