diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index ed06fac298a8fd2965f93993e796ef1ed1ba0c6e..155eb1ebbe3db3416aa6f14e1e7f6e083bfb54e6 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -555,10 +555,10 @@ class Reshape2Op : public ReshapeOp { const framework::ExecutionContext &ctx) const override { auto multi_inputs = ctx.MultiInput("ShapeTensor"); if (multi_inputs.size() > 0) { - return framework::KernelSignature("reshape.mulhost", {"X", "ShapeTensor"}, + return framework::KernelSignature("reshape_mulhost", {"X", "ShapeTensor"}, {}, {"Out"}); } else if (ctx.HasInput("Shape")) { - return framework::KernelSignature("reshape.host", {"X", "Shape"}, {}, + return framework::KernelSignature("reshape_host", {"X", "Shape"}, {}, {"Out"}); } else { return framework::KernelSignature("reshape", {"X"}, {"shape"}, {"Out"}); diff --git a/paddle/pten/api/lib/kernel_declare.h b/paddle/pten/api/lib/kernel_declare.h new file mode 100644 index 0000000000000000000000000000000000000000..8c21094a4af20257cb488a8b74d804dcba2bce87 --- /dev/null +++ b/paddle/pten/api/lib/kernel_declare.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/core/kernel_registry.h" + +// TODO(chenweihang) After the kernel is split into a single file, +// the kernel declare statement is automatically generated according to the +// file name of the kernel, and this header file will be removed + +PT_DECLARE_KERNEL(full_like, CPU); +PT_DECLARE_KERNEL(dot, CPU); +PT_DECLARE_KERNEL(flatten, CPU); +PT_DECLARE_KERNEL(sign, CPU); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_DECLARE_KERNEL(full_like, CUDA); +PT_DECLARE_KERNEL(dot, CUDA); +PT_DECLARE_KERNEL(flatten, CUDA); +PT_DECLARE_KERNEL(sign, CUDA); +#endif + +#ifdef PADDLE_WITH_XPU +PT_DECLARE_KERNEL(flatten, XPU); +#endif diff --git a/paddle/pten/api/lib/utils.cc b/paddle/pten/api/lib/utils.cc index e17b19d9f689e18eecbeeb6f0d0730cfde1aa3e5..bfde9b14b0020d9a726bdd1fb6f211274a4e39d2 100644 --- a/paddle/pten/api/lib/utils.cc +++ b/paddle/pten/api/lib/utils.cc @@ -25,10 +25,14 @@ limitations under the License. */ #include "paddle/pten/include/core.h" #include "paddle/pten/include/infermeta.h" -PT_DECLARE_MODULE(UtilsCPU); +PT_DECLARE_KERNEL(copy, CPU); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_DECLARE_MODULE(UtilsCUDA); +PT_DECLARE_KERNEL(copy, CUDA); +#endif + +#ifdef PADDLE_WITH_XPU +PT_DECLARE_KERNEL(copy, XPU); #endif namespace paddle { diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index cd6fa80906cfbdaee4514d6e4cfabca78bf0f26e..be624177dfb14cf3240ee935c5446e0597e0da9f 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -24,6 +25,8 @@ #include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/core/kernel_utils.h" +#include "paddle/fluid/platform/enforce.h" + namespace pten { #define BACKEND(arg__) pten::Backend::arg__ @@ -193,64 +196,35 @@ struct KernelRegistrar { #define _PT_ARG_N(args) _PT_ARG_N_EXPAND args #define _PT_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +/** PT_REGISTER_KERNEL + * + * The most frequently used kernel registration macro, used for kernel + * registration with only data type as template parameter, and the function + * pointer of the corresponding data type is automatically instantiated + * during registration. + */ #define PT_REGISTER_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - _PT_REGISTER_KERNEL(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__) + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_kernel_ns_check_##kernel_name, \ + "PT_REGISTER_KERNEL must be called in global namespace."); \ + _PT_REGISTER_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) + #ifndef _WIN32 -#define _PT_REGISTER_KERNEL( \ - kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ - "PT_REGISTER_KERNEL must be called in global namespace."); \ - PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \ - static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pten::Kernel*); \ - PT_KERNEL_REGISTRAR_INIT(kernel_name, \ - func_id, \ - backend, \ - layout, \ - &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__); \ - void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pten::Kernel * kernel) +#define _PT_REGISTER_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \ + static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT(kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel) #else -#define _PT_REGISTER_KERNEL( \ - kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ - "PT_REGISTER_KERNEL must be called in global namespace."); \ - static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pten::Kernel*); \ - PT_KERNEL_REGISTRAR_INIT(kernel_name, \ - func_id, \ - backend, \ - layout, \ - &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__); \ - void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pten::Kernel * kernel) -#endif - -#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \ - _PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__) - -#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \ - PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ - (meta_kernel_fn, cpp_dtype, __VA_ARGS__) - /** * `template decltype(fn) fn` can work on gcc and clang, * but msvc will failed, error like: @@ -261,8 +235,30 @@ struct KernelRegistrar { * * https://stackoverflow.com/questions/63989585/explicit-instantiation-of-function-using-decltype-work-on-g-but-not-on-visua * - * So we solve the explict instantiation of kernel by CMake + * And msvc can work without template instantiation */ +#define _PT_REGISTER_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT(kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel) +#endif + +#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \ + _PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__) + +#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \ + PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ + (meta_kernel_fn, cpp_dtype, __VA_ARGS__) #define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) \ template decltype(meta_kernel_fn) meta_kernel_fn @@ -309,22 +305,15 @@ struct KernelRegistrar { template decltype(meta_kernel_fn) meta_kernel_fn; \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__)) -#define PT_KERNEL_REGISTRAR_INIT(kernel_name, \ - func_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - _PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \ - kernel_name, \ - func_id, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ +#define PT_KERNEL_REGISTRAR_INIT( \ + kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ + _PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ __VA_ARGS__) // clang-format off @@ -333,7 +322,6 @@ struct KernelRegistrar { and multi-line macros cannot be skipped with NOLINT.*/ #define _PT_KERNEL_REGISTRAR_INIT(N, \ kernel_name, \ - func_id, \ backend, \ layout, \ args_def_fn, \ @@ -342,7 +330,6 @@ struct KernelRegistrar { ...) \ PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -354,7 +341,6 @@ struct KernelRegistrar { // clang-format on #define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -363,17 +349,17 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn)); + PT_KERNEL(meta_kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } #define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -382,8 +368,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -392,7 +378,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -400,7 +385,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -409,8 +393,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -419,7 +403,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -427,7 +410,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -436,8 +418,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -446,7 +428,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -454,7 +435,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -463,8 +443,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -473,7 +453,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -481,7 +460,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -490,8 +468,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -500,7 +478,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -508,7 +485,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -517,8 +493,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -527,7 +503,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -535,7 +510,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -544,8 +518,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -554,7 +528,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -562,7 +535,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -571,8 +543,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -581,7 +553,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -589,7 +560,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -598,8 +568,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -608,7 +578,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -616,7 +585,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -625,8 +593,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -635,7 +603,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -643,7 +610,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -652,8 +618,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -662,7 +628,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -670,7 +635,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -679,8 +643,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -689,7 +653,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -697,7 +660,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -706,8 +668,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -716,7 +678,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -724,7 +685,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) #define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \ - func_id, \ registrar_id, \ backend, \ layout, \ @@ -733,8 +693,8 @@ struct KernelRegistrar { cpp_dtype, \ ...) \ static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_op_kernel_##func_id##_, registrar_id)( \ - kernel_name, \ + __reg_pt_kernel_##kernel_name##_, registrar_id)( \ + #kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ @@ -743,7 +703,6 @@ struct KernelRegistrar { args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ - func_id, \ PT_ID, \ backend, \ layout, \ @@ -751,90 +710,59 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) -#define PT_REGISTER_KERNEL_STANDARD( \ - kernel_name, backend, layout, dtype, kernel_fn) \ - _PT_REGISTER_KERNEL_STANDARD( \ - kernel_name, PT_ID, backend, layout, dtype, kernel_fn) - -#define _PT_REGISTER_KERNEL_STANDARD( \ - kernel_name, func_id, backend, layout, dtype, kernel_fn) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ - "_PT_REGISTER_KERNEL_STANDARD must be called in global namespace."); \ - template decltype(kernel_fn) kernel_fn; \ - static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pten::Kernel*); \ - static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \ - func_id)( \ - kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - DATATYPE(dtype), \ - ::pten::KernelArgsParseFunctor::Parse, \ - args_def_fn, \ - PT_KERNEL(kernel_fn)); \ - void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id)(::pten::Kernel*) - -// use to declare symbol -#define PT_REGISTER_MODULE(name) \ - int RegisterSymbolsFor##name() { return 0; } - -#define PT_DECLARE_MODULE(name) \ - extern int RegisterSymbolsFor##name(); \ - UNUSED static int use_kernel_module_##name = RegisterSymbolsFor##name() - -// only used in cpp tests - -#define PT_REGISTER_KERNEL_FOR_TEST( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - _PT_REGISTER_KERNEL_FOR_TEST(kernel_name, \ - PT_ID, \ - backend, \ - layout, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__) - -#define _PT_REGISTER_KERNEL_FOR_TEST( \ - kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - PT_CONCATENATE(pt_op_kernel_for_test_ns_check_, func_id), \ - "PT_REGISTER_KERNEL must be called in global namespace."); \ - static void PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, \ - func_id)(::pten::Kernel*); \ - PT_KERNEL_REGISTRAR_INIT( \ - kernel_name, \ - func_id, \ - backend, \ - layout, \ - &PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, func_id), \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__); \ - void PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, \ - func_id)(::pten::Kernel * kernel) - -#define PT_REGISTER_KERNEL_WITH_NO_TYPE( \ - kernel_name, backend, layout, meta_kernel_fn) \ - _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ - kernel_name, PT_ID, backend, layout, meta_kernel_fn) - -#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ - kernel_name, func_id, backend, layout, meta_kernel_fn) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ - "PT_REGISTER_KERNEL must be called in global namespace."); \ - decltype(meta_kernel_fn) meta_kernel_fn; \ - static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pten::Kernel*); \ - static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \ - func_id)( \ - kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::pten::KernelArgsParseFunctor::Parse, \ - &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ - PT_KERNEL(meta_kernel_fn)); \ - void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pten::Kernel * kernel) +/** PT_REGISTER_SINGLE_KERNEL + * + * Used to register a single kernel, pass in the complete function pointer + * of the kernel, this registration macro will not do automatic template + * instantiation. + */ +#define PT_REGISTER_SINGLE_KERNEL( \ + kernel_name, backend, layout, dtype, kernel_fn) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_single_kernel_ns_check_##kernel_name, \ + "PT_REGISTER_SINGLE_KERNEL must be called in global namespace."); \ + static void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \ + static const ::pten::KernelRegistrar __reg_pt_single_kernel_##kernel_name( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + DATATYPE(dtype), \ + ::pten::KernelArgsParseFunctor::Parse, \ + args_def_fn, \ + PT_KERNEL(kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \ + void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*) + +/** PT_REGISTER_KERNEL_ALL_DTYPE + * + * Used to register a kernel that supports all data types, such as copy and + * reshape that are not sensitive to data types. + */ +#define PT_REGISTER_KERNEL_ALL_DTYPE(kernel_name, backend, layout, kernel_fn) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_kernel_all_dtype_ns_check_##kernel_name, \ + "PT_REGISTER_KERNEL_ALL_DTYPE must be called in global namespace."); \ + static void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name( \ + ::pten::Kernel*); \ + static const ::pten::KernelRegistrar \ + __reg_pt_kernel_all_dtype_##kernel_name( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::pten::KernelArgsParseFunctor::Parse, \ + &__PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name, \ + PT_KERNEL(kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \ + void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name(::pten::Kernel* kernel) + +/** PT_DECLARE_KERNEL + * + * Used to export the symbols of the file where the kernel is located, + * to avoid being removed by linker + */ +#define PT_DECLARE_KERNEL(kernel_name, backend) \ + extern int TouchKernelSymbolFor_##kernel_name##_##backend(); \ + UNUSED static int __declare_kernel_symbol_for_##kernel_name##_##backend = \ + TouchKernelSymbolFor_##kernel_name##_##backend() + } // namespace pten diff --git a/paddle/pten/kernels/cpu/creation.cc b/paddle/pten/kernels/cpu/creation.cc index 4f09fc489f8f67a8b2c53af7da16e716f4ad58fc..4175203410f8da092cb2aa7673d5651d17245fbb 100644 --- a/paddle/pten/kernels/cpu/creation.cc +++ b/paddle/pten/kernels/cpu/creation.cc @@ -61,9 +61,7 @@ void FillConstant(const CPUContext& dev_ctx, } // namespace pten -PT_REGISTER_MODULE(CreationCPU); - -PT_REGISTER_KERNEL("full_like", +PT_REGISTER_KERNEL(full_like, CPU, ANY, pten::FillAnyLike, @@ -74,7 +72,7 @@ PT_REGISTER_KERNEL("full_like", bool, paddle::platform::float16) {} -PT_REGISTER_KERNEL("full", +PT_REGISTER_KERNEL(full, CPU, ANY, pten::FillConstant, diff --git a/paddle/pten/kernels/cpu/linalg.cc b/paddle/pten/kernels/cpu/linalg.cc index 32411560b55168cc1caa2b95253014aea03eb98a..7ffac0537b60c075d9667836a31c24ccf268199b 100644 --- a/paddle/pten/kernels/cpu/linalg.cc +++ b/paddle/pten/kernels/cpu/linalg.cc @@ -70,12 +70,10 @@ void Matmul(const CPUContext& dev_ctx, } // namespace pten -PT_REGISTER_MODULE(LinalgCPU); - using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_KERNEL("dot", +PT_REGISTER_KERNEL(dot, CPU, ANY, pten::Dot, @@ -87,5 +85,4 @@ PT_REGISTER_KERNEL("dot", complex128) {} PT_REGISTER_KERNEL( - "matmul_v2", CPU, ANY, pten::Matmul, float, double, complex64, complex128) { -} + matmul_v2, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {} diff --git a/paddle/pten/kernels/cpu/manipulation.cc b/paddle/pten/kernels/cpu/manipulation.cc index e0e9cefbf671bf07c375a4452453a8ebef8cd982..61c6cb57a9f780010a5f83dd874cb790fb9879dd 100644 --- a/paddle/pten/kernels/cpu/manipulation.cc +++ b/paddle/pten/kernels/cpu/manipulation.cc @@ -130,12 +130,9 @@ void Cast(const CPUContext& dev_ctx, } // namespace pten -// TODO(chenweihang): replace by better impl -PT_REGISTER_MODULE(ManipulationCPU); - // TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel // architecture, kernel_name should be "flatten". -PT_REGISTER_KERNEL("flatten", +PT_REGISTER_KERNEL(flatten, CPU, ANY, pten::Flatten, @@ -145,8 +142,7 @@ PT_REGISTER_KERNEL("flatten", int8_t, int, int64_t) {} - -PT_REGISTER_KERNEL("flatten.mid", +PT_REGISTER_KERNEL(flatten_mid, CPU, ANY, pten::FlattenWithXShape, @@ -156,7 +152,8 @@ PT_REGISTER_KERNEL("flatten.mid", int8_t, int, int64_t) {} -PT_REGISTER_KERNEL("cast", + +PT_REGISTER_KERNEL(cast, CPU, ANY, pten::Cast, @@ -174,42 +171,33 @@ PT_REGISTER_KERNEL("cast", kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } -// TODO(yuanrisheng): "reshape2" is compatible with old kernel -// architecture, kernel_name should be "reshape". -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape", - CPU, - ANY, - pten::ReshapeFromVectorVal) {} - -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mid", - CPU, - ANY, - pten::ReshapeFromVectorValWithXShape) {} - -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host", CPU, ANY, pten::ReshapeFromDT) { +PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::ReshapeFromVectorVal) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid, + CPU, + ANY, + pten::ReshapeFromVectorValWithXShape) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CPU, ANY, pten::ReshapeFromDT) { kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); } - -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host.mid", - CPU, - ANY, - pten::ReshapeFromDTWithXShape) { +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid, + CPU, + ANY, + pten::ReshapeFromDTWithXShape) { kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); } -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost", - CPU, - ANY, - pten::ReshapeFromVectorDT) { +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost, + CPU, + ANY, + pten::ReshapeFromVectorDT) { kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); } - -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost.mid", - CPU, - ANY, - pten::ReshapeFromVectorDTWithXShape) { +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid, + CPU, + ANY, + pten::ReshapeFromVectorDTWithXShape) { kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); } diff --git a/paddle/pten/kernels/cpu/math.cc b/paddle/pten/kernels/cpu/math.cc index ddfb2f5f45854001c6d4d7000d64be23058153b5..2d556d96c2fcf7effa1dbe7f0f31adeded6a403a 100644 --- a/paddle/pten/kernels/cpu/math.cc +++ b/paddle/pten/kernels/cpu/math.cc @@ -106,18 +106,14 @@ DEFINE_CPU_ELEMENTWISE_OP(Mul) } // namespace pten -// TODO(chenweihang): replace by better impl -PT_REGISTER_MODULE(MathCPU); - using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // using bfloat16 = ::paddle::platform::bfloat16; - -PT_REGISTER_KERNEL("sign", CPU, ANY, pten::Sign, float, double) {} -PT_REGISTER_KERNEL("mean", CPU, ANY, pten::Mean, float, double, bool) {} -PT_REGISTER_KERNEL("scale", +PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {} +PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {} +PT_REGISTER_KERNEL(scale, CPU, ANY, pten::Scale, @@ -129,8 +125,7 @@ PT_REGISTER_KERNEL("scale", int16_t, int, int64_t) {} - -PT_REGISTER_KERNEL("add", +PT_REGISTER_KERNEL(add, CPU, ANY, pten::ElementwiseAdd, @@ -140,7 +135,7 @@ PT_REGISTER_KERNEL("add", int64_t, complex64, complex128) {} -PT_REGISTER_KERNEL("subtract", +PT_REGISTER_KERNEL(subtract, CPU, ANY, pten::ElementwiseSub, @@ -150,7 +145,7 @@ PT_REGISTER_KERNEL("subtract", int64_t, complex64, complex128) {} -PT_REGISTER_KERNEL("divide", +PT_REGISTER_KERNEL(divide, CPU, ANY, pten::ElementwiseDiv, @@ -160,7 +155,7 @@ PT_REGISTER_KERNEL("divide", int64_t, complex64, complex128) {} -PT_REGISTER_KERNEL("multiply", +PT_REGISTER_KERNEL(multiply, CPU, ANY, pten::ElementwiseMul, @@ -171,8 +166,7 @@ PT_REGISTER_KERNEL("multiply", bool, complex64, complex128) {} - -PT_REGISTER_KERNEL("sum", +PT_REGISTER_KERNEL(sum, CPU, ANY, pten::Sum, diff --git a/paddle/pten/kernels/cpu/utils.cc b/paddle/pten/kernels/cpu/utils.cc index b462ef70c2f06ecda8e21cc32a847f2331be978c..500b4664d638885d8faba393e267c2fb0e556577 100644 --- a/paddle/pten/kernels/cpu/utils.cc +++ b/paddle/pten/kernels/cpu/utils.cc @@ -57,7 +57,4 @@ void Copy(const CPUContext& dev_ctx, } // namespace pten -// TODO(chenweihang): replace by better impl -PT_REGISTER_MODULE(UtilsCPU); - -PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CPU, ANY, pten::Copy) {} +PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {} diff --git a/paddle/pten/kernels/cuda/creation.cu b/paddle/pten/kernels/cuda/creation.cu index 8bc23fb6af056fa3f32b20ec021151b084d7c19d..dd29fd5fbb84d2536ef63c37c8fb0fd59d285258 100644 --- a/paddle/pten/kernels/cuda/creation.cu +++ b/paddle/pten/kernels/cuda/creation.cu @@ -62,9 +62,7 @@ void FillConstant(const CUDAContext& dev_ctx, } // namespace pten -PT_REGISTER_MODULE(CreationCUDA); - -PT_REGISTER_KERNEL("full_like", +PT_REGISTER_KERNEL(full_like, CUDA, ANY, pten::FillAnyLike, @@ -75,7 +73,7 @@ PT_REGISTER_KERNEL("full_like", bool, paddle::platform::float16) {} -PT_REGISTER_KERNEL("full", +PT_REGISTER_KERNEL(full, CUDA, ANY, pten::FillConstant, diff --git a/paddle/pten/kernels/cuda/linalg.cu b/paddle/pten/kernels/cuda/linalg.cu index fe2ac6f184ff7353df5bb55b9b7ff18911d3bf1d..b08ed8f71ee6b274e7abd54fb682220328342661 100644 --- a/paddle/pten/kernels/cuda/linalg.cu +++ b/paddle/pten/kernels/cuda/linalg.cu @@ -54,13 +54,11 @@ void Matmul(const CUDAContext& dev_ctx, } // namespace pten -PT_REGISTER_MODULE(LinalgCUDA); - using float16 = paddle::platform::float16; using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_KERNEL("dot", +PT_REGISTER_KERNEL(dot, CUDA, ANY, pten::Dot, @@ -71,7 +69,7 @@ PT_REGISTER_KERNEL("dot", complex64, complex128) {} -PT_REGISTER_KERNEL("matmul_v2", +PT_REGISTER_KERNEL(matmul_v2, CUDA, ANY, pten::Matmul, diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu index acaf1ac2cc62b28002a6c15d564251cfe352f5dc..e668d1b04d7238e6bbcebc7710abb61f885b1659 100644 --- a/paddle/pten/kernels/cuda/manipulation.cu +++ b/paddle/pten/kernels/cuda/manipulation.cu @@ -129,13 +129,9 @@ void Cast(const CUDAContext& dev_ctx, } // namespace pten -// TODO(chenweihang): replace by better impl -PT_REGISTER_MODULE(ManipulationCUDA); - using float16 = paddle::platform::float16; -// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel -// architecture, kernel_name should be "flatten". -PT_REGISTER_KERNEL("flatten", + +PT_REGISTER_KERNEL(flatten, CUDA, ANY, pten::Flatten, @@ -146,8 +142,7 @@ PT_REGISTER_KERNEL("flatten", int8_t, int, int64_t) {} - -PT_REGISTER_KERNEL("flatten.mid", +PT_REGISTER_KERNEL(flatten_mid, CUDA, ANY, pten::FlattenWithXShape, @@ -159,7 +154,7 @@ PT_REGISTER_KERNEL("flatten.mid", int64_t) {} #define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ - PT_REGISTER_KERNEL("cast", \ + PT_REGISTER_KERNEL(cast, \ CUDA, \ ANY, \ pten::Cast, \ @@ -184,44 +179,33 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) #endif -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape", - CUDA, - ANY, - pten::ReshapeFromVectorVal) {} - -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mid", - CUDA, - ANY, - pten::ReshapeFromVectorValWithXShape) {} - -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host", - CUDA, - ANY, - pten::ReshapeFromDT) { +PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::ReshapeFromVectorVal) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid, + CUDA, + ANY, + pten::ReshapeFromVectorValWithXShape) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CUDA, ANY, pten::ReshapeFromDT) { kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); } - -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host.mid", - CUDA, - ANY, - pten::ReshapeFromDTWithXShape) { +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid, + CUDA, + ANY, + pten::ReshapeFromDTWithXShape) { kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); } - -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost", - CUDA, - ANY, - pten::ReshapeFromVectorDT) { +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost, + CUDA, + ANY, + pten::ReshapeFromVectorDT) { kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); } - -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost.mid", - CUDA, - ANY, - pten::ReshapeFromVectorDTWithXShape) { +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid, + CUDA, + ANY, + pten::ReshapeFromVectorDTWithXShape) { kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); } diff --git a/paddle/pten/kernels/cuda/math.cu b/paddle/pten/kernels/cuda/math.cu index 388d42942c10a4afbdca2b4dd5c8ed9a0b9c319d..66aaf14dcd0f62ac2b752d803b136994815963df 100644 --- a/paddle/pten/kernels/cuda/math.cu +++ b/paddle/pten/kernels/cuda/math.cu @@ -111,16 +111,13 @@ void Sum(const CUDAContext& dev_ctx, } // namespace pten -// TODO(chenweihang): replace by better impl -PT_REGISTER_MODULE(MathCUDA); - using float16 = paddle::platform::float16; using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_KERNEL("sign", CUDA, ANY, pten::Sign, float, double, float16) {} -PT_REGISTER_KERNEL("mean", CUDA, ANY, pten::Mean, float, double, bool) {} -PT_REGISTER_KERNEL("scale", +PT_REGISTER_KERNEL(sign, CUDA, ANY, pten::Sign, float, double, float16) {} +PT_REGISTER_KERNEL(mean, CUDA, ANY, pten::Mean, float, double, bool) {} +PT_REGISTER_KERNEL(scale, CUDA, ANY, pten::Scale, @@ -132,7 +129,7 @@ PT_REGISTER_KERNEL("scale", int16_t, int, int64_t) {} -PT_REGISTER_KERNEL("add", +PT_REGISTER_KERNEL(add, CUDA, ANY, pten::ElementwiseAdd, @@ -143,7 +140,7 @@ PT_REGISTER_KERNEL("add", float16, complex64, complex128) {} -PT_REGISTER_KERNEL("subtract", +PT_REGISTER_KERNEL(subtract, CUDA, ANY, pten::ElementwiseSub, @@ -154,7 +151,7 @@ PT_REGISTER_KERNEL("subtract", float16, complex64, complex128) {} -PT_REGISTER_KERNEL("divide", +PT_REGISTER_KERNEL(divide, CUDA, ANY, pten::ElementwiseDiv, @@ -165,7 +162,7 @@ PT_REGISTER_KERNEL("divide", float16, complex64, complex128) {} -PT_REGISTER_KERNEL("multiply", +PT_REGISTER_KERNEL(multiply, CUDA, ANY, pten::ElementwiseMul, @@ -177,7 +174,7 @@ PT_REGISTER_KERNEL("multiply", float16, complex64, complex128) {} -PT_REGISTER_KERNEL("sum", +PT_REGISTER_KERNEL(sum, CUDA, ANY, pten::Sum, diff --git a/paddle/pten/kernels/cuda/utils.cu b/paddle/pten/kernels/cuda/utils.cu index 24da650d1f3eb9664bd69923ba4247d4a341476c..49027e956b2d7dc28b77d72c2db684ae0007a15c 100644 --- a/paddle/pten/kernels/cuda/utils.cu +++ b/paddle/pten/kernels/cuda/utils.cu @@ -234,7 +234,4 @@ void Copy(const CUDAContext& dev_ctx, } } // namespace pten -// TODO(chenweihang): replace by better impl -PT_REGISTER_MODULE(UtilsCUDA); - -PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CUDA, ANY, pten::Copy) {} +PT_REGISTER_KERNEL_ALL_DTYPE(copy, CUDA, ANY, pten::Copy) {} diff --git a/paddle/pten/kernels/xpu/manipulation.cc b/paddle/pten/kernels/xpu/manipulation.cc index 5f1c0d42eb5a8faf2d6b488ebded52684abfab3c..f361933cad45a5e703b106588a8b3c5514269e78 100644 --- a/paddle/pten/kernels/xpu/manipulation.cc +++ b/paddle/pten/kernels/xpu/manipulation.cc @@ -95,12 +95,7 @@ void ReshapeFromVectorDT(const XPUContext& dev_ctx, } // namespace pten -// TODO(chenweihang): replace by better impl -PT_REGISTER_MODULE(ManipulationXPU); - -// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel -// architecture, kernel_name should be "flatten". -PT_REGISTER_KERNEL("flatten_contiguous_range", +PT_REGISTER_KERNEL(flatten, XPU, ANY, pten::Flatten, @@ -112,7 +107,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range", int, int64_t) {} -PT_REGISTER_KERNEL("flatten_contiguous_range.mid", +PT_REGISTER_KERNEL(flatten_mid, XPU, ANY, pten::FlattenWithXShape, @@ -124,9 +119,4 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", int, int64_t) {} -// TODO(yuanrisheng): "reshape2" is compatible with old kernel -// architecture, kernel_name should be "reshape". -PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2", - XPU, - ANY, - pten::ReshapeFromVectorVal) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::ReshapeFromVectorVal) {} diff --git a/paddle/pten/kernels/xpu/utils.cc b/paddle/pten/kernels/xpu/utils.cc index 329dc2baf87b58e03906aba0b83fbb2552fd56ca..5c98217f4ec2c41b39729f9df76ea4ffba277286 100644 --- a/paddle/pten/kernels/xpu/utils.cc +++ b/paddle/pten/kernels/xpu/utils.cc @@ -76,7 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx, } // namespace pten -// TODO(chenweihang): replace by better impl -PT_REGISTER_MODULE(UtilsXPU); - -PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", XPU, ANY, pten::Copy) {} +PT_REGISTER_KERNEL_ALL_DTYPE(copy, XPU, ANY, pten::Copy) {} diff --git a/paddle/pten/tests/api/test_reshape_api.cc b/paddle/pten/tests/api/test_reshape_api.cc index b6179f11b1019e26ac248e92a48d8be2971a2fac..227dcc6e9568d7c44730eeab70dca27cace1f482 100644 --- a/paddle/pten/tests/api/test_reshape_api.cc +++ b/paddle/pten/tests/api/test_reshape_api.cc @@ -21,12 +21,6 @@ limitations under the License. */ #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_registry.h" -PT_DECLARE_MODULE(ManipulationCPU); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_DECLARE_MODULE(ManipulationCUDA); -#endif - namespace paddle { namespace tests { diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 5506ee95bd7c9ed40f4b53a23109530f2e7a46cf..ed3bb1dc5f1f0122eb5be79ac07b5067c1288ed1 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -345,6 +345,7 @@ def source_include(header_file_path): #include "glog/logging.h" #include "paddle/pten/api/lib/api_registry.h" +#include "paddle/pten/api/lib/kernel_declare.h" #include "paddle/pten/api/lib/kernel_dispatch.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/kernel_registry.h" @@ -353,22 +354,6 @@ def source_include(header_file_path): """ -def module_declare(): - return """ -PT_DECLARE_MODULE(CreationCPU); -PT_DECLARE_MODULE(LinalgCPU); -PT_DECLARE_MODULE(ManipulationCPU); -PT_DECLARE_MODULE(MathCPU); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_DECLARE_MODULE(CreationCUDA); -PT_DECLARE_MODULE(LinalgCUDA); -PT_DECLARE_MODULE(ManipulationCUDA); -PT_DECLARE_MODULE(MathCUDA); -#endif -""" - - def api_register(): return """ PT_REGISTER_API(Creation); @@ -405,7 +390,6 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): include_header_file = "paddle/pten/api/include/api.h" source_file.write(source_include(include_header_file)) - source_file.write(module_declare()) source_file.write(namespace[0]) for api in apis: