diff --git a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc index e1e8f09ddff8fec6132546e9a97874c02a672199..e422774bf9cf0e9801bb41a38620624c8d25af4e 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc +++ b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc @@ -33,31 +33,31 @@ static void ScaleDeviceDispatch(const pten::DenseTensor& dense_tensor, pten::DenseTensor* dense_out) { switch (dense_tensor.dtype()) { case pten::DataType::FLOAT64: { - pten::Scale(dev_ctx, dense_tensor /* tensor */, scale /* scale */, - bias /* bias */, - bias_after_scale /* bias_after_scale */, - dense_out /* out tensor */); + pten::Scale( + dev_ctx, dense_tensor /* tensor */, scale /* scale */, + bias /* bias */, bias_after_scale /* bias_after_scale */, + dense_out /* out tensor */); break; } case pten::DataType::FLOAT32: { - pten::Scale(dev_ctx, dense_tensor /* tensor */, scale /* scale */, - bias /* bias */, - bias_after_scale /* bias_after_scale */, - dense_out /* out tensor */); + pten::Scale(dev_ctx, dense_tensor /* tensor */, + scale /* scale */, bias /* bias */, + bias_after_scale /* bias_after_scale */, + dense_out /* out tensor */); break; } case pten::DataType::INT64: { - pten::Scale(dev_ctx, dense_tensor /* tensor */, - scale /* scale */, bias /* bias */, - bias_after_scale /* bias_after_scale */, - dense_out /* out tensor */); + pten::Scale( + dev_ctx, dense_tensor /* tensor */, scale /* scale */, + bias /* bias */, bias_after_scale /* bias_after_scale */, + dense_out /* out tensor */); break; } case pten::DataType::INT32: { - pten::Scale(dev_ctx, dense_tensor /* tensor */, - scale /* scale */, bias /* bias */, - bias_after_scale /* bias_after_scale */, - dense_out /* out tensor */); + pten::Scale( + dev_ctx, dense_tensor /* tensor */, scale /* scale */, + bias /* bias */, bias_after_scale /* bias_after_scale */, + dense_out /* out tensor */); break; } default: { diff --git a/paddle/fluid/operators/scale_op.h b/paddle/fluid/operators/scale_op.h index a75c9fd4fd2450a86fb2e59ac492c897f68d1bb7..6011fe9a66b60b8cb3e5438f3233e4f25ca486fb 100644 --- a/paddle/fluid/operators/scale_op.h +++ b/paddle/fluid/operators/scale_op.h @@ -20,7 +20,7 @@ limitations under the License. */ // only can include the headers in paddle/top/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/include/core.h" -#include "paddle/pten/include/math.h" +#include "paddle/pten/kernels/scale_kernel.h" namespace paddle { namespace operators { diff --git a/paddle/pten/CMakeLists.txt b/paddle/pten/CMakeLists.txt index 4538016333e5e76663c328c7b3e78c3ceb2b7fda..8990def122601e99678b213aa16f3f69bf8b0ffc 100644 --- a/paddle/pten/CMakeLists.txt +++ b/paddle/pten/CMakeLists.txt @@ -24,6 +24,7 @@ add_subdirectory(tests) # make an unity target for compile deps set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context) +set(PTEN_DEPS scale_kernel_eigen) set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu) set(PTEN_DEPS ${PTEN_DEPS} nary unary binary) if(WITH_GPU OR WITH_ROCM) diff --git a/paddle/pten/api/lib/kernel_declare.h b/paddle/pten/api/lib/kernel_declare.h index fa11811178322fac4b8c8b618ffc9fe4506a283e..0f4f82b9d7c51f72c59e14b5f8bac287ea850f7f 100644 --- a/paddle/pten/api/lib/kernel_declare.h +++ b/paddle/pten/api/lib/kernel_declare.h @@ -24,6 +24,7 @@ PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT); diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index f46e9afd3defbd2293bc86e62c2dbcb2e4cf74e1..62a46e128e513a779b379c1c76f806169e1a426d 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -772,6 +772,558 @@ struct KernelRegistrar { void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ ::pten::Kernel* kernel) +/** PT_REGISTER_CTX_KERNEL + * + * Used for kernel registration with device context and data type as + * template parameter. + */ +#define PT_REGISTER_CTX_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_tp_ctx_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PT_REGISTER_CTX_KERNEL must be called in global namespace."); \ + _PT_REGISTER_2TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) + +#ifndef _WIN32 +#define _PT_REGISTER_2TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT2( \ + kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) +#else +#define _PT_REGISTER_2TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT2( \ + kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) +#endif + +#define PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, ...) \ + _PT_KERNEL_INSTANTIATION2(PT_NARGS(cpp_dtype, __VA_ARGS__), \ + meta_kernel_fn, \ + backend, \ + cpp_dtype, \ + __VA_ARGS__) + +#define _PT_KERNEL_INSTANTIATION2(N, meta_kernel_fn, backend, cpp_dtype, ...) \ + PT_CONCATENATE(_PT_KERNEL_INSTANTIATION2_, N) \ + (meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__) + +#define _PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn +#define _PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_15(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, __VA_ARGS__)) + +#define PT_KERNEL_REGISTRAR_INIT2( \ + kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ + _PT_KERNEL_REGISTRAR_INIT2(PT_NARGS(cpp_dtype, __VA_ARGS__), \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__) + +// clang-format off + +/* The =pre-commit always treats this macro into the wrong format, + and multi-line macros cannot be skipped with NOLINT.*/ +#define _PT_KERNEL_REGISTRAR_INIT2(N, \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT2_, N) ( \ + kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__) + +// clang-format on + +#define _PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } +#define _PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_15(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) + /** PT_DECLARE_KERNEL * * Used to export the symbols of the file where the kernel is located, diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h index 01dd0a7fda79a6bc459fa250ef1848a6bc38296b..3872f663fed7800be161dfceaebf319532b4ab54 100644 --- a/paddle/pten/include/math.h +++ b/paddle/pten/include/math.h @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/pten/include/infermeta.h" #include "paddle/pten/kernels/cpu/math.h" #include "paddle/pten/kernels/cuda/math.h" +#include "paddle/pten/kernels/scale_kernel.h" namespace pten { @@ -86,7 +87,7 @@ DenseTensor Scale(const ContextT& dev_ctx, pten::make_intrusive( dev_ctx.GetPlace()), std::move(out_meta)); - Scale(dev_ctx, x, scale, bias, bias_after_scale, &dense_out); + Scale(dev_ctx, x, scale, bias, bias_after_scale, &dense_out); return dense_out; } diff --git a/paddle/pten/kernels/cpu/math.cc b/paddle/pten/kernels/cpu/math.cc index 67493015c61e2e50bfbd46226f7183ac15071fc2..659a4d0e09686c7d05d372ecd9e8aaf5e5c13b03 100644 --- a/paddle/pten/kernels/cpu/math.cc +++ b/paddle/pten/kernels/cpu/math.cc @@ -17,7 +17,6 @@ #include "paddle/pten/api/ext/dispatch.h" #include "paddle/pten/kernels/hybird/cpu/elementwise.h" #include "paddle/pten/kernels/hybird/eigen/reduce.h" -#include "paddle/pten/kernels/hybird/eigen/scale.h" #include "paddle/pten/kernels/hybird/eigen/sign.h" #include "paddle/pten/kernels/hybird/general/elementwise_functor.h" #include "paddle/pten/kernels/hybird/general/reduce_impl.h" @@ -47,17 +46,6 @@ void Mean(const CPUContext& dev_ctx, dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } -template -void Scale(const CPUContext& dev_ctx, - const DenseTensor& x, - const Scalar& scale, - float bias, - bool bias_after_scale, - DenseTensor* out) { - eigen::Scale( - dev_ctx, x, scale.to(), bias, bias_after_scale, out); -} - template void Divide(const CPUContext& dev_ctx, const DenseTensor& x, @@ -113,18 +101,6 @@ using complex128 = ::paddle::platform::complex; // using bfloat16 = ::paddle::platform::bfloat16; PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {} PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {} -PT_REGISTER_KERNEL(scale, - CPU, - ALL_LAYOUT, - pten::Scale, - float, - double, - paddle::platform::bfloat16, - uint8_t, - int8_t, - int16_t, - int, - int64_t) {} PT_REGISTER_KERNEL(add, CPU, ALL_LAYOUT, diff --git a/paddle/pten/kernels/cpu/math.h b/paddle/pten/kernels/cpu/math.h index 5ee0f9f8956b16a6aa960656bcf82dfd530060f4..c53e659cf83200c48b750b8085d98795e820ed2f 100644 --- a/paddle/pten/kernels/cpu/math.h +++ b/paddle/pten/kernels/cpu/math.h @@ -38,14 +38,6 @@ void Mean(const CPUContext& dev_ctx, DataType out_dtype, DenseTensor* out); -template -void Scale(const CPUContext& dev_ctx, - const DenseTensor& x, - const Scalar& scale, - float bias, - bool bias_after_scale, - DenseTensor* out); - template void Add(const CPUContext& dev_ctx, const DenseTensor& x, diff --git a/paddle/pten/kernels/cuda/math.cu b/paddle/pten/kernels/cuda/math.cu index e0974181dc833de241f40e8b4487b2309e8107d4..27d1ba1e043fe88f8edd0afd6e81f99ef41cffed 100644 --- a/paddle/pten/kernels/cuda/math.cu +++ b/paddle/pten/kernels/cuda/math.cu @@ -17,7 +17,6 @@ limitations under the License. */ #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h" #include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h" -#include "paddle/pten/kernels/hybird/eigen/scale.h" #include "paddle/pten/kernels/hybird/eigen/sign.h" #include "paddle/pten/kernels/hybird/general/elementwise_functor.h" #include "paddle/pten/kernels/hybird/general/reduce_impl.h" @@ -76,17 +75,6 @@ void Mean(const CUDAContext& dev_ctx, dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } -template -void Scale(const CUDAContext& dev_ctx, - const DenseTensor& x, - const Scalar& scale, - float bias, - bool bias_after_scale, - DenseTensor* out) { - eigen::Scale( - dev_ctx, x, scale.to(), bias, bias_after_scale, out); -} - // Create the definition of Add DEFINE_CUDA_ELEMENTWISE_OP(Add) // Create the definition of Subtract @@ -118,18 +106,6 @@ using complex128 = ::paddle::platform::complex; PT_REGISTER_KERNEL(sign, CUDA, ALL_LAYOUT, pten::Sign, float, double, float16) { } PT_REGISTER_KERNEL(mean, CUDA, ALL_LAYOUT, pten::Mean, float, double, bool) {} -PT_REGISTER_KERNEL(scale, - CUDA, - ALL_LAYOUT, - pten::Scale, - float, - double, - float16, - uint8_t, - int8_t, - int16_t, - int, - int64_t) {} PT_REGISTER_KERNEL(add, CUDA, ALL_LAYOUT, diff --git a/paddle/pten/kernels/eigen/CMakeLists.txt b/paddle/pten/kernels/eigen/CMakeLists.txt index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..27e757256881b276d752b25c42e597c34609c930 100644 --- a/paddle/pten/kernels/eigen/CMakeLists.txt +++ b/paddle/pten/kernels/eigen/CMakeLists.txt @@ -0,0 +1,7 @@ +if(WITH_GPU) + nv_library(scale_kernel_eigen SRCS scale_kernel.cc DEPS dense_tensor eigen_function) +elseif(WITH_ROCM) + hip_library(scale_kernel_eigen SRCS scale_kernel.cc DEPS dense_tensor eigen_function) +else() + cc_library(scale_kernel_eigen SRCS scale_kernel.cc DEPS dense_tensor eigen_function) +endif() diff --git a/paddle/pten/kernels/hybird/eigen/scale.h b/paddle/pten/kernels/eigen/scale_kernel.cc similarity index 50% rename from paddle/pten/kernels/hybird/eigen/scale.h rename to paddle/pten/kernels/eigen/scale_kernel.cc index 111f6c22cc35e951825efa707ec3bc8ebe8f9662..5ec27be3af9015b423f837ff16dcb0e19b38e9d2 100644 --- a/paddle/pten/kernels/hybird/eigen/scale.h +++ b/paddle/pten/kernels/eigen/scale_kernel.cc @@ -12,21 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once +#include "paddle/pten/kernels/scale_kernel.h" -#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/kernels/hybird/eigen/common.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/platform/bfloat16.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/float16.h" namespace pten { -namespace eigen { -template -void Scale(const DevCtx& dev_ctx, +template +void Scale(const ContextT& dev_ctx, const DenseTensor& x, - float scale, + const Scalar& scale, float bias, bool bias_after_scale, DenseTensor* out) { @@ -42,10 +44,41 @@ void Scale(const DevCtx& dev_ctx, dev, eigen_out, eigen_x, - static_cast(scale), + scale.to(), static_cast(bias), bias_after_scale); } -} // namespace eigen } // namespace pten + +// TODO(chenweihang): Use EigenContext to specialize the ContextT parameter, +// and only register the backend as Eigen's kernel during registration, +// instead of using macros to register the CPU and CUDA kernels separately + +PT_REGISTER_CTX_KERNEL(scale, + CPU, + ALL_LAYOUT, + pten::Scale, + float, + double, + paddle::platform::bfloat16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_REGISTER_CTX_KERNEL(scale, + CUDA, + ALL_LAYOUT, + pten::Scale, + float, + double, + paddle::platform::float16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} +#endif diff --git a/paddle/pten/kernels/scale_kernel.h b/paddle/pten/kernels/scale_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..bb3c1968fce9e4487d3eeb2295995d1493cc6467 --- /dev/null +++ b/paddle/pten/kernels/scale_kernel.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void Scale(const ContextT& dev_ctx, + const DenseTensor& x, + const Scalar& scale, + float bias, + bool bias_after_scale, + DenseTensor* out); + +} // namespace pten diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index 88faa773dfaffb36ac2d188f1171811fdd938ef5..a230b6a4181875dff7b625c3277a23088341f4a2 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -21,4 +21,4 @@ cc_test(test_to_api SRCS test_to_api.cc DEPS pten_tensor pten_api pten_api_utils cc_test(test_slice_api SRCS test_slice_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_sum_api SRCS test_sum_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api_utils) -cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils) +cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils scale_kernel_eigen) diff --git a/paddle/pten/tests/api/scale_api.h b/paddle/pten/tests/api/scale_api.h index 5668cbe29439c454ef249343a90669de7d210480..1defbd02ddd1104b286d23b5bdbaf6ef62fe924e 100644 --- a/paddle/pten/tests/api/scale_api.h +++ b/paddle/pten/tests/api/scale_api.h @@ -25,8 +25,7 @@ #include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/include/core.h" #include "paddle/pten/include/infermeta.h" -#include "paddle/pten/kernels/cpu/math.h" -#include "paddle/pten/kernels/cuda/math.h" +#include "paddle/pten/kernels/scale_kernel.h" namespace paddle { namespace experimental {