From 8895379a0a9d1223480071d97befe71876272623 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 25 Feb 2022 10:32:27 +0800 Subject: [PATCH] [Phi] Support cudnn kernel moving & move softmax kernels (#39547) * support cudnn kernel moving * polish cmake rules * add unittest for coverage * remove orig kernel * remove softmax cudnn kernel * fix softmax test failed * fix npu func error * resolve conflict * rename gpu dnn kernels * fix name rule error * fix compile error * update fp16 namespace --- cmake/phi.cmake | 89 ++-- .../ir/mkldnn/mkldnn_inplace_pass_tester.cc | 2 +- paddle/fluid/framework/pten_utils.cc | 4 +- paddle/fluid/framework/pten_utils_test.cc | 37 +- .../inference/tensorrt/convert/softmax_op.cc | 2 +- .../tensorrt/convert/test_softmax_op.cc | 2 +- .../c_softmax_with_cross_entropy_op.cu | 9 +- .../c_softmax_with_cross_entropy_op.h | 1 - paddle/fluid/operators/fused/fmha_ref.h | 22 +- .../operators/margin_cross_entropy_op.cu | 9 +- .../fluid/operators/margin_cross_entropy_op.h | 1 - paddle/fluid/operators/math/softmax.cc | 8 + paddle/fluid/operators/math/softmax.cu | 11 + .../operators/mkldnn/softmax_mkldnn_op.cc | 9 +- .../operators/mkldnn/test_mkldnn_caching.cc | 2 +- .../mkldnn/test_mkldnn_op_inplace.cc | 2 +- paddle/fluid/operators/softmax_cudnn_op.cu | 72 --- paddle/fluid/operators/softmax_op.cc | 10 +- paddle/fluid/operators/softmax_op.cu.cc | 27 -- paddle/fluid/operators/softmax_op.h | 114 ----- paddle/fluid/operators/softmax_op_npu.cc | 5 +- paddle/fluid/operators/softmax_op_npu_test.cc | 2 +- paddle/fluid/operators/softmax_op_xpu.cc | 6 +- .../softmax_with_cross_entropy_op.cc | 4 +- .../softmax_with_cross_entropy_op.cu | 33 +- .../operators/softmax_with_cross_entropy_op.h | 21 +- .../softmax_with_cross_entropy_op_npu.cc | 22 +- .../softmax_with_cross_entropy_op_xpu.cc | 12 +- .../test_common_infer_shape_functions.cc | 2 +- paddle/phi/backends/gpu/gpu_context.h | 7 + paddle/phi/common/backend.h | 10 +- paddle/phi/common/float16.h | 12 + paddle/phi/core/compat/convert_utils.cc | 2 +- paddle/phi/kernels/CMakeLists.txt | 11 +- paddle/phi/kernels/cpu/softmax_grad_kernel.cc | 22 + paddle/phi/kernels/cpu/softmax_kernel.cc | 22 + paddle/phi/kernels/funcs/axis_utils.h | 54 +++ paddle/phi/kernels/funcs/concat_funcs.h | 2 +- paddle/phi/kernels/funcs/eigen/elementwise.cu | 2 +- paddle/phi/kernels/gpu/softmax_grad_kernel.cu | 28 ++ paddle/phi/kernels/gpu/softmax_kernel.cu | 28 ++ .../kernels/gpudnn/softmax_gpudnn.h} | 444 ++++++++++++------ .../gpudnn/softmax_grad_kernel_gpudnn.cu | 50 ++ .../kernels/gpudnn/softmax_kernel_gpudnn.cu | 49 ++ .../kernels/impl/softmax_grad_kernel_impl.h | 51 ++ paddle/phi/kernels/impl/softmax_kernel_impl.h | 48 ++ paddle/phi/kernels/softmax_grad_kernel.h | 29 ++ paddle/phi/kernels/softmax_kernel.h | 38 ++ paddle/phi/ops/compat/softmax_sig.cc | 34 ++ paddle/phi/tests/common/test_backend.cc | 6 +- 50 files changed, 996 insertions(+), 493 deletions(-) delete mode 100644 paddle/fluid/operators/softmax_cudnn_op.cu delete mode 100644 paddle/fluid/operators/softmax_op.cu.cc delete mode 100644 paddle/fluid/operators/softmax_op.h create mode 100644 paddle/phi/kernels/cpu/softmax_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/softmax_kernel.cc create mode 100644 paddle/phi/kernels/funcs/axis_utils.h create mode 100644 paddle/phi/kernels/gpu/softmax_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/softmax_kernel.cu rename paddle/{fluid/operators/softmax_cudnn_op.cu.h => phi/kernels/gpudnn/softmax_gpudnn.h} (63%) create mode 100644 paddle/phi/kernels/gpudnn/softmax_grad_kernel_gpudnn.cu create mode 100644 paddle/phi/kernels/gpudnn/softmax_kernel_gpudnn.cu create mode 100644 paddle/phi/kernels/impl/softmax_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/softmax_kernel_impl.h create mode 100644 paddle/phi/kernels/softmax_grad_kernel.h create mode 100644 paddle/phi/kernels/softmax_kernel.h create mode 100644 paddle/phi/ops/compat/softmax_sig.cc diff --git a/cmake/phi.cmake b/cmake/phi.cmake index f1a6f8e45a..d9132b8445 100644 --- a/cmake/phi.cmake +++ b/cmake/phi.cmake @@ -81,6 +81,8 @@ function(kernel_declare TARGET_LIST) file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, GPU, ALL_LAYOUT);\n") elseif (${kernel_path} MATCHES "./xpu\/") file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n") + elseif (${kernel_path} MATCHES "./gpudnn\/") + file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, GPUDNN, ALL_LAYOUT);\n") else () # deal with device independent kernel, now we use CPU temporaary file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n") @@ -94,6 +96,7 @@ function(kernel_library TARGET) set(cpu_srcs) set(gpu_srcs) set(xpu_srcs) + set(gpudnn_srcs) set(selected_rows_srcs) # parse and save the deps kerenl targets set(all_srcs) @@ -101,6 +104,8 @@ function(kernel_library TARGET) set(oneValueArgs SUB_DIR) set(multiValueArgs SRCS DEPS) + set(target_build_flag 1) + cmake_parse_arguments(kernel_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -123,6 +128,9 @@ function(kernel_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc) list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc) endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu) + list(APPEND gpudnn_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu) + endif() endif() if (WITH_XPU) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/xpu/${TARGET}.cc) @@ -141,6 +149,7 @@ function(kernel_library TARGET) list(APPEND all_srcs ${cpu_srcs}) list(APPEND all_srcs ${gpu_srcs}) list(APPEND all_srcs ${xpu_srcs}) + list(APPEND all_srcs ${gpudnn_srcs}) foreach(src ${all_srcs}) file(READ ${src} target_content) string(REGEX MATCHALL "#include \"paddle\/phi\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content}) @@ -166,21 +175,22 @@ function(kernel_library TARGET) list(LENGTH cpu_srcs cpu_srcs_len) list(LENGTH gpu_srcs gpu_srcs_len) list(LENGTH xpu_srcs xpu_srcs_len) + list(LENGTH gpudnn_srcs gpudnn_srcs_len) list(LENGTH selected_rows_srcs selected_rows_srcs_len) # Build Target according different src organization if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR - ${xpu_srcs_len} GREATER 0) AND (${common_srcs_len} GREATER 0 OR - ${selected_rows_srcs_len} GREATER 0)) + ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) AND + (${common_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0)) # If the common_srcs/selected_rows_srcs depends on specific device srcs, build target using this rule. if (WITH_GPU) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) + nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) nv_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part) endif() elseif (WITH_ROCM) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) + hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part) endif() else() @@ -190,14 +200,14 @@ function(kernel_library TARGET) endif() endif() # If there are only specific device srcs, build target using this rule. - elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) if (WITH_GPU) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) + nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() elseif (WITH_ROCM) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) + hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() else() if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) @@ -234,35 +244,40 @@ function(kernel_library TARGET) cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() else() - message(FATAL_ERROR "Cannot find any implementation for ${TARGET}") + set(target_build_flag 0) endif() - if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR - ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR - ${selected_rows_srcs_len} GREATER 0) - # append target into PHI_KERNELS property - get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS) - set(phi_kernels ${phi_kernels} ${TARGET}) - set_property(GLOBAL PROPERTY PHI_KERNELS ${phi_kernels}) - endif() + if (${target_build_flag} EQUAL 1) + if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR + ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR + ${gpudnn_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0) + # append target into PHI_KERNELS property + get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS) + set(phi_kernels ${phi_kernels} ${TARGET}) + set_property(GLOBAL PROPERTY PHI_KERNELS ${phi_kernels}) + endif() - # parse kernel name and auto generate kernel declaration - # here, we don't need to check WITH_XXX, because if not WITH_XXX, the - # xxx_srcs_len will be equal to 0 - if (${common_srcs_len} GREATER 0) - kernel_declare(${common_srcs}) - endif() - if (${cpu_srcs_len} GREATER 0) - kernel_declare(${cpu_srcs}) - endif() - if (${gpu_srcs_len} GREATER 0) - kernel_declare(${gpu_srcs}) - endif() - if (${xpu_srcs_len} GREATER 0) - kernel_declare(${xpu_srcs}) - endif() - if (${selected_rows_srcs_len} GREATER 0) - kernel_declare(${selected_rows_srcs}) + # parse kernel name and auto generate kernel declaration + # here, we don't need to check WITH_XXX, because if not WITH_XXX, the + # xxx_srcs_len will be equal to 0 + if (${common_srcs_len} GREATER 0) + kernel_declare(${common_srcs}) + endif() + if (${cpu_srcs_len} GREATER 0) + kernel_declare(${cpu_srcs}) + endif() + if (${gpu_srcs_len} GREATER 0) + kernel_declare(${gpu_srcs}) + endif() + if (${xpu_srcs_len} GREATER 0) + kernel_declare(${xpu_srcs}) + endif() + if (${gpudnn_srcs_len} GREATER 0) + kernel_declare(${gpudnn_srcs}) + endif() + if (${selected_rows_srcs_len} GREATER 0) + kernel_declare(${selected_rows_srcs}) + endif() endif() endfunction() diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc index ea335e9bd6..0a95444f85 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc @@ -20,7 +20,7 @@ #include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/op_registry.h" -USE_OP(softmax); +USE_OP_ITSELF(softmax); USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_ITSELF(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index 0ecc04dbd6..af9d62ff7a 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -67,7 +67,7 @@ OpKernelType TransPtenKernelKeyToOpKernelType( LibraryType library_type = LibraryType::kPlain; if (kernel_key.backend() == phi::Backend::MKLDNN) { library_type = LibraryType::kMKLDNN; - } else if (kernel_key.backend() == phi::Backend::CUDNN) { + } else if (kernel_key.backend() == phi::Backend::GPUDNN) { library_type = LibraryType::kCUDNN; } else { // do nothing @@ -82,7 +82,7 @@ phi::KernelKey TransOpKernelTypeToPtenKernelKey( if (kernel_type.library_type_ == LibraryType::kMKLDNN) { backend = phi::Backend::MKLDNN; } else if (kernel_type.library_type_ == LibraryType::kCUDNN) { - backend = phi::Backend::CUDNN; + backend = phi::Backend::GPUDNN; } else { // do } diff --git a/paddle/fluid/framework/pten_utils_test.cc b/paddle/fluid/framework/pten_utils_test.cc index 3c86372e6e..da1431c0ef 100644 --- a/paddle/fluid/framework/pten_utils_test.cc +++ b/paddle/fluid/framework/pten_utils_test.cc @@ -42,7 +42,7 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) { #endif #ifdef PADDLE_WITH_CUDA - phi::KernelKey kernel_key_cudnn(phi::Backend::CUDNN, phi::DataLayout::NCHW, + phi::KernelKey kernel_key_cudnn(phi::Backend::GPUDNN, phi::DataLayout::NCHW, phi::DataType::FLOAT32); op_kernel_type = paddle::framework::TransPtenKernelKeyToOpKernelType(kernel_key_cudnn); @@ -53,3 +53,38 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) { paddle::framework::LibraryType::kCUDNN); #endif } + +TEST(PtenUtils, TransOpKernelTypeToPtenKernelKey) { + paddle::framework::OpKernelType op_kernel_type( + paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(), + paddle::framework::DataLayout::kNCHW); + auto kernel_key = + paddle::framework::TransOpKernelTypeToPtenKernelKey(op_kernel_type); + ASSERT_EQ(kernel_key.dtype(), phi::DataType::FLOAT32); + ASSERT_EQ(kernel_key.layout(), phi::DataLayout::NCHW); + ASSERT_EQ(kernel_key.backend(), phi::Backend::CPU); + +#ifdef PADDLE_WITH_MKLDNN + paddle::framework::OpKernelType op_kernel_type_mkldnn( + paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(), + paddle::framework::DataLayout::kMKLDNN, + paddle::framework::LibraryType::kMKLDNN); + auto kernel_key_mkldnn = paddle::framework::TransOpKernelTypeToPtenKernelKey( + op_kernel_type_mkldnn); + ASSERT_EQ(kernel_key_mkldnn.dtype(), phi::DataType::FLOAT32); + ASSERT_EQ(kernel_key_mkldnn.layout(), phi::DataLayout::MKLDNN); + ASSERT_EQ(kernel_key_mkldnn.backend(), phi::Backend::MKLDNN); +#endif + +#ifdef PADDLE_WITH_CUDA + paddle::framework::OpKernelType op_kernel_type_cudnn( + paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(), + paddle::framework::DataLayout::kNCHW, + paddle::framework::LibraryType::kCUDNN); + auto kernel_key_cudnn = + paddle::framework::TransOpKernelTypeToPtenKernelKey(op_kernel_type_cudnn); + ASSERT_EQ(kernel_key_cudnn.dtype(), phi::DataType::FLOAT32); + ASSERT_EQ(kernel_key_cudnn.layout(), phi::DataLayout::NCHW); + ASSERT_EQ(kernel_key_cudnn.backend(), phi::Backend::GPUDNN); +#endif +} diff --git a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc index 9cefb24751..46e6c18bfb 100644 --- a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc @@ -88,5 +88,5 @@ class SoftMaxOpConverter : public OpConverter { } // namespace inference } // namespace paddle -USE_OP(softmax); +USE_OP_ITSELF(softmax); REGISTER_TRT_OP_CONVERTER(softmax, SoftMaxOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_softmax_op.cc b/paddle/fluid/inference/tensorrt/convert/test_softmax_op.cc index b6fdcddf30..9cd5e81141 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_softmax_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_softmax_op.cc @@ -45,4 +45,4 @@ TEST(SoftMaxOpConverter, main) { } // namespace inference } // namespace paddle -USE_OP(softmax); +USE_OP_ITSELF(softmax); diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu index 4f1f1ec651..b5beb77090 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/fluid/string/string_helper.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" namespace paddle { namespace operators { @@ -98,8 +99,8 @@ class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel { const auto& labels_dims = labels->dims(); const int axis = logits_dims.size() - 1; - const int N = SizeToAxis(axis, logits_dims); - const int D = SizeFromAxis(axis, logits_dims); + const int N = phi::funcs::SizeToAxis(axis, logits_dims); + const int D = phi::funcs::SizeFromAxis(axis, logits_dims); Tensor logits_2d, softmax_2d, loss_2d; logits_2d.ShareDataWith(*logits).Resize({N, D}); @@ -220,8 +221,8 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { } const auto sofrmax_dims = softmax->dims(); const int axis = sofrmax_dims.size() - 1; - const int N = SizeToAxis(axis, sofrmax_dims); - const int D = SizeFromAxis(axis, sofrmax_dims); + const int N = phi::funcs::SizeToAxis(axis, sofrmax_dims); + const int D = phi::funcs::SizeFromAxis(axis, sofrmax_dims); Tensor logit_grad_2d; logit_grad_2d.ShareDataWith(*logit_grad).Resize({N, D}); diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h index c7cfd41fa2..f5399e3215 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h @@ -23,7 +23,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/softmax.h" -#include "paddle/fluid/operators/softmax_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 31fff4b668..0202776757 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -14,8 +14,8 @@ limitations under the License. */ #include "paddle/fluid/operators/dropout_impl.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -#include "paddle/fluid/operators/softmax_cudnn_op.cu.h" #include "paddle/fluid/operators/transpose_op.cu.h" +#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" namespace paddle { namespace operators { @@ -123,11 +123,11 @@ class FMHARef { T, T>( dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor()); - SoftmaxForwardCUDAKernelDriver(dev_ctx_, *src_mask_out_tensor, - softmax_axis, softmax_out_tensor); + phi::SoftmaxForwardCUDAKernelDriver(dev_ctx_, *src_mask_out_tensor, + softmax_axis, softmax_out_tensor); } else { - SoftmaxForwardCUDAKernelDriver(dev_ctx_, *qk_out_tensor, softmax_axis, - softmax_out_tensor); + phi::SoftmaxForwardCUDAKernelDriver(dev_ctx_, *qk_out_tensor, + softmax_axis, softmax_out_tensor); } transB = CblasNoTrans; @@ -251,9 +251,9 @@ class FMHARef { } if (src_mask_tensor != nullptr) { - SoftmaxBackwardCUDAKernelDriver(dev_ctx_, softmax_out_tensor, - *softmax_out_grad_tensor, softmax_axis, - src_mask_out_grad_tensor); + phi::SoftmaxBackwardCUDAKernelDriver( + dev_ctx_, softmax_out_tensor, *softmax_out_grad_tensor, softmax_axis, + src_mask_out_grad_tensor); // recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out + // src_mask @@ -272,9 +272,9 @@ class FMHARef { } } else { - SoftmaxBackwardCUDAKernelDriver(dev_ctx_, softmax_out_tensor, - *softmax_out_grad_tensor, softmax_axis, - qk_out_grad_tensor); + phi::SoftmaxBackwardCUDAKernelDriver(dev_ctx_, softmax_out_tensor, + *softmax_out_grad_tensor, + softmax_axis, qk_out_grad_tensor); } T* qk_out_grad_data = qk_out_grad_tensor->data(); diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cu b/paddle/fluid/operators/margin_cross_entropy_op.cu index c6405f65ee..a2e34d9846 100644 --- a/paddle/fluid/operators/margin_cross_entropy_op.cu +++ b/paddle/fluid/operators/margin_cross_entropy_op.cu @@ -26,6 +26,7 @@ namespace cub = hipcub; #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" #include "paddle/fluid/string/string_helper.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) @@ -246,8 +247,8 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel { const auto& labels_dims = labels->dims(); const int axis = logits_dims.size() - 1; - const int N = SizeToAxis(axis, logits_dims); - const int D = SizeFromAxis(axis, logits_dims); + const int N = phi::funcs::SizeToAxis(axis, logits_dims); + const int D = phi::funcs::SizeFromAxis(axis, logits_dims); int blocks = NumBlocks(N); int threads = kNumCUDAThreads; @@ -401,8 +402,8 @@ class MarginCrossEntropyGradCUDAKernel : public framework::OpKernel { const auto sofrmax_dims = softmax->dims(); const int axis = sofrmax_dims.size() - 1; - const int N = SizeToAxis(axis, sofrmax_dims); - const int D = SizeFromAxis(axis, sofrmax_dims); + const int N = phi::funcs::SizeToAxis(axis, sofrmax_dims); + const int D = phi::funcs::SizeFromAxis(axis, sofrmax_dims); if (return_softmax) { framework::TensorCopy(*softmax, context.GetPlace(), diff --git a/paddle/fluid/operators/margin_cross_entropy_op.h b/paddle/fluid/operators/margin_cross_entropy_op.h index fe0dab5d47..9261c84c85 100644 --- a/paddle/fluid/operators/margin_cross_entropy_op.h +++ b/paddle/fluid/operators/margin_cross_entropy_op.h @@ -22,7 +22,6 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/softmax.h" -#include "paddle/fluid/operators/softmax_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/math/softmax.cc b/paddle/fluid/operators/math/softmax.cc index fa2018178f..c855cb763a 100644 --- a/paddle/fluid/operators/math/softmax.cc +++ b/paddle/fluid/operators/math/softmax.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/operators/math/softmax_impl.h" +#include "paddle/phi/backends/cpu/cpu_context.h" namespace paddle { namespace operators { @@ -26,6 +27,13 @@ template class SoftmaxFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxGradFunctor; +template class SoftmaxGradFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 692a077f10..fd879e9e6f 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/operators/math/softmax_impl.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -139,6 +140,16 @@ template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxGradFunctor; +template class SoftmaxGradFunctor; +template class SoftmaxGradFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index 2effcbf9f4..a0e50aa297 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" namespace paddle { namespace operators { @@ -70,7 +71,8 @@ class SoftmaxMKLDNNHandler out_grad->dims(), in_x_grad->dims())); auto dims = out_grad->dims(); // input and output share the same shape - const int axis = CanonicalAxis(ctx.Attr("axis"), dims.size()); + const int axis = + phi::funcs::CanonicalAxis(ctx.Attr("axis"), dims.size()); auto softmax_tz = phi::vectorize(dims); auto data_softmax_md = MKLDNNMemDesc( @@ -96,7 +98,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { Tensor* output = ctx.Output("Out"); bool is_inplaced = input->IsSharedBufferWith(*output); - const int axis = CanonicalAxis(ctx.Attr("axis"), input->dims().size()); + const int axis = + phi::funcs::CanonicalAxis(ctx.Attr("axis"), input->dims().size()); SoftmaxMKLDNNHandler handler(mkldnn_engine, ctx.GetPlace(), input, output, axis); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc index 9c5bad8627..2fdeecf893 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc @@ -31,7 +31,7 @@ USE_OP(elementwise_mul); USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN); USE_OP(relu); USE_OP_DEVICE_KERNEL(relu, MKLDNN); -USE_OP(softmax); +USE_OP_ITSELF(softmax); USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP(conv2d); USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc index 92c58ae0a7..c776cf2a7c 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc @@ -29,7 +29,7 @@ USE_OP_ITSELF(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP(relu); USE_OP_DEVICE_KERNEL(relu, MKLDNN); -USE_OP(softmax); +USE_OP_ITSELF(softmax); USE_OP_DEVICE_KERNEL(softmax, MKLDNN); namespace paddle { diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu b/paddle/fluid/operators/softmax_cudnn_op.cu deleted file mode 100644 index 72c2e97c17..0000000000 --- a/paddle/fluid/operators/softmax_cudnn_op.cu +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/softmax_cudnn_op.cu.h" - -namespace paddle { -namespace operators { - -template -class SoftmaxCUDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - - int input_axis = ctx.Attr("axis"); - auto& dev_ctx = ctx.template device_context(); - SoftmaxForwardCUDAKernelDriver(dev_ctx, *x, input_axis, out); - } -}; - -template -class SoftmaxGradCUDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* out = ctx.Input("Out"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - dx->mutable_data(ctx.GetPlace()); - - int input_axis = ctx.Attr("axis"); - auto& dev_ctx = ctx.template device_context(); - SoftmaxBackwardCUDAKernelDriver(dev_ctx, *out, *dout, input_axis, dx); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -#ifdef PADDLE_WITH_HIP -// MIOPEN do not support double -REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, - ops::SoftmaxCUDNNKernel, - ops::SoftmaxCUDNNKernel); -REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel, - ops::SoftmaxGradCUDNNKernel); -#else -REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, - ops::SoftmaxCUDNNKernel, - ops::SoftmaxCUDNNKernel, - ops::SoftmaxCUDNNKernel); -REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel, - ops::SoftmaxGradCUDNNKernel, - ops::SoftmaxGradCUDNNKernel); -#endif diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index cb97a0bb27..3749920966 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -12,12 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/softmax_op.h" - #include #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #ifdef PADDLE_WITH_MKLDNN @@ -251,10 +250,3 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, ops::SoftmaxOpGradMaker, ops::SoftmaxInplaceInferer); REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad); -REGISTER_OP_CPU_KERNEL( - softmax, ops::SoftmaxKernel, - ops::SoftmaxKernel); -REGISTER_OP_CPU_KERNEL( - softmax_grad, - ops::SoftmaxGradKernel, - ops::SoftmaxGradKernel); diff --git a/paddle/fluid/operators/softmax_op.cu.cc b/paddle/fluid/operators/softmax_op.cu.cc deleted file mode 100644 index 19359b7eef..0000000000 --- a/paddle/fluid/operators/softmax_op.cu.cc +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/softmax_op.h" -#include "paddle/fluid/platform/float16.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - softmax, ops::SoftmaxKernel, - ops::SoftmaxKernel, - ops::SoftmaxKernel); -REGISTER_OP_CUDA_KERNEL( - softmax_grad, ops::SoftmaxGradKernel, - ops::SoftmaxGradKernel, - ops::SoftmaxGradKernel); diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h deleted file mode 100644 index 497bbb06da..0000000000 --- a/paddle/fluid/operators/softmax_op.h +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/softmax.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using DDim = framework::DDim; - -static inline int CanonicalAxis(const int axis, const int rank) { - if (axis < 0) { - return axis + rank; - } - return axis; -} - -static inline int SizeToAxis(const int axis, DDim dims) { - int size = 1; - for (int i = 0; i < axis; i++) { - size *= dims[i]; - } - return size; -} - -static inline int SizeFromAxis(const int axis, DDim dims) { - int size = 1; - for (int i = axis; i < dims.size(); i++) { - size *= dims[i]; - } - return size; -} - -static inline int SizeOutAxis(const int axis, DDim dims) { - int size = 1; - for (int i = axis + 1; i < dims.size(); i++) { - size *= dims[i]; - } - return size; -} - -template -class SoftmaxKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Out = context.Output("Out"); - const int rank = X->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = X->dims()[axis]; - - // allocate memory on device. - Out->mutable_data(context.GetPlace()); - if (Out->numel() == 0) { - return; - } - - const int n = SizeToAxis(axis, X->dims()); - const int d = SizeFromAxis(axis, X->dims()); - Tensor X_2d, Out_2d; - X_2d.ShareDataWith(*X).Resize({n, d}); - Out_2d.ShareDataWith(*Out).Resize({n, d}); - math::SoftmaxFunctor()( - context.template device_context(), axis_dim, &X_2d, - &Out_2d); - } -}; - -template -class SoftmaxGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* Out = context.Input("Out"); - auto* dOut = context.Input(framework::GradVarName("Out")); - auto* dX = context.Output(framework::GradVarName("X")); - const int rank = dX->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = dX->dims()[axis]; - - // allocate memory on device. - dX->mutable_data(context.GetPlace()); - if (dX->numel() == 0) { - return; - } - - const int n = SizeToAxis(axis, dX->dims()); - const int d = SizeFromAxis(axis, dX->dims()); - Tensor dX_2d, Out_2d, dOut_2d; - dX_2d.ShareDataWith(*dX).Resize({n, d}); - Out_2d.ShareDataWith(*Out).Resize({n, d}); - dOut_2d.ShareDataWith(*dOut).Resize({n, d}); - - math::SoftmaxGradFunctor()( - context.template device_context(), axis_dim, &Out_2d, - &dOut_2d, &dX_2d); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/softmax_op_npu.cc b/paddle/fluid/operators/softmax_op_npu.cc index 07e74354bf..152c8d0a88 100644 --- a/paddle/fluid/operators/softmax_op_npu.cc +++ b/paddle/fluid/operators/softmax_op_npu.cc @@ -12,8 +12,9 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" namespace paddle { namespace operators { @@ -51,7 +52,7 @@ class SoftmaxGradNPUKernel : public framework::OpKernel { auto dims = dX->dims(); const int rank = dims.size(); - const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + const int axis = phi::funcs::CanonicalAxis(ctx.Attr("axis"), rank); int64_t first_dim = 1; int64_t sec_dim = 1; for (int i = 0; i < axis; i++) { diff --git a/paddle/fluid/operators/softmax_op_npu_test.cc b/paddle/fluid/operators/softmax_op_npu_test.cc index defda1a3b0..3bc55fafd8 100644 --- a/paddle/fluid/operators/softmax_op_npu_test.cc +++ b/paddle/fluid/operators/softmax_op_npu_test.cc @@ -29,7 +29,7 @@ limitations under the License. */ namespace f = paddle::framework; namespace p = paddle::platform; -USE_OP(softmax); +USE_OP_ITSELF(softmax); USE_OP_DEVICE_KERNEL(softmax, NPU); template diff --git a/paddle/fluid/operators/softmax_op_xpu.cc b/paddle/fluid/operators/softmax_op_xpu.cc index a29804e505..1ed13c8bd1 100644 --- a/paddle/fluid/operators/softmax_op_xpu.cc +++ b/paddle/fluid/operators/softmax_op_xpu.cc @@ -11,8 +11,8 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/softmax_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" namespace paddle { namespace operators { @@ -29,7 +29,7 @@ class SoftmaxXPUKernel : public framework::OpKernel { auto* x = context.Input("X"); auto* out = context.Output("Out"); const int rank = x->dims().size(); - int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); // allocate memory on device. out->mutable_data(context.GetPlace()); @@ -88,7 +88,7 @@ class SoftmaxGradXPUKernel : public framework::OpKernel { auto* dout = context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); const int rank = dx->dims().size(); - int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); // allocate memory on device. dx->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index cba779d0a7..6f0881e9fc 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -153,7 +153,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { "Attr(axis) value should be in range [-R, R-1], " "R is the rank of Input(Logits).")); - axis = CanonicalAxis(axis, logits_rank); + axis = phi::funcs::CanonicalAxis(axis, logits_rank); for (int i = 0; i < logits_rank; i++) { if (i != axis) { if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) { @@ -250,7 +250,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { "Attr(axis) value should be in range [-R, R-1], " "R is the rank of Input(Logits).")); - axis = CanonicalAxis(axis, softmax_rank); + axis = phi::funcs::CanonicalAxis(axis, softmax_rank); for (int i = 0; i < softmax_rank; i++) { if (i != axis) { if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) { diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 2bbacef596..fd035df768 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -17,12 +17,12 @@ namespace cub = hipcub; #endif #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/math/cross_entropy.h" -#include "paddle/fluid/operators/softmax_cudnn_op.cu.h" #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" namespace paddle { namespace operators { @@ -236,7 +236,7 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src, max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax; } } - WarpReduceMax(max_value); + phi::WarpReduceMax(max_value); // compute sum: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } AccT sum[kBatchSize]; @@ -276,7 +276,7 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src, } } } - WarpReduceSum(sum); + phi::WarpReduceSum(sum); // write data #pragma unroll @@ -566,7 +566,7 @@ __global__ void CrossEntropySoftLabel(T* loss, T* softmaxwrt, const T* softmax, } } } - WarpReduceSum(sum); + phi::WarpReduceSum(sum); __syncthreads(); __shared__ T sumshare[kWarpPerBatch][kBatchPerBlock][kBatchSize]; @@ -674,7 +674,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src, : static_cast(valmax); } } - WarpReduceMax(max_value); + phi::WarpReduceMax(max_value); // compute sum AccT sum[kBatchSize]{0.0}; @@ -694,7 +694,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src, } } } - WarpReduceSum(sum); + phi::WarpReduceSum(sum); // log_softmax and loss AccT sumloss[kBatchSize]{0.0}; @@ -737,7 +737,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src, } // loss - WarpReduceSum(sumloss); + phi::WarpReduceSum(sumloss); for (int i = 0; i < kBatchSize; i++) { if (i >= local_batches) break; @@ -950,11 +950,12 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { Tensor* loss = context.Output("Loss"); const int rank = softmax->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); + const int axis = + phi::funcs::CanonicalAxis(context.Attr("axis"), rank); const int axis_dim = softmax->dims()[axis]; - const int n = SizeToAxis(axis, softmax->dims()); - const int d = SizeFromAxis(axis, softmax->dims()); + const int n = phi::funcs::SizeToAxis(axis, softmax->dims()); + const int d = phi::funcs::SizeFromAxis(axis, softmax->dims()); auto* softmax_out_data = softmax_out->template mutable_data(context.GetPlace()); @@ -1035,11 +1036,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { Tensor* loss = context.Output("Loss"); const int rank = logits->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); + const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logits->dims()[axis]; - const int64_t n = SizeToAxis(axis, logits->dims()); - const int64_t d = SizeFromAxis(axis, logits->dims()); + const int64_t n = phi::funcs::SizeToAxis(axis, logits->dims()); + const int64_t d = phi::funcs::SizeFromAxis(axis, logits->dims()); auto* softmax_data = softmax->template mutable_data(context.GetPlace()); auto* loss_data = loss->template mutable_data(context.GetPlace()); @@ -1118,11 +1119,11 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { T* logit_grad_data = logit_grad->template data(); const int rank = logit_grad->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); + const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logit_grad->dims()[axis]; - const int64_t n = SizeToAxis(axis, logit_grad->dims()); - const int64_t d = SizeFromAxis(axis, logit_grad->dims()); + const int64_t n = phi::funcs::SizeToAxis(axis, logit_grad->dims()); + const int64_t d = phi::funcs::SizeFromAxis(axis, logit_grad->dims()); const int64_t remain = d / axis_dim; #ifdef __HIPCC__ diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h index a7f88dd0ec..4b875cbf58 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/softmax.h" -#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" namespace paddle { namespace operators { @@ -84,7 +84,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { Tensor* softmax_out = context.Output("Softmax"); Tensor* loss = context.Output("Loss"); const int rank = softmax->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); + const int axis = + phi::funcs::CanonicalAxis(context.Attr("axis"), rank); int axis_dim = softmax->dims()[axis]; PADDLE_ENFORCE_GT( @@ -97,7 +98,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { softmax_out->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - const int n = SizeToAxis(axis, softmax->dims()); + const int n = phi::funcs::SizeToAxis(axis, softmax->dims()); PADDLE_ENFORCE_GT( n, 0, platform::errors::InvalidArgument( @@ -105,7 +106,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { "SizeToAxis of softmax is %d.", n)); - const int d = SizeFromAxis(axis, softmax->dims()); + const int d = phi::funcs::SizeFromAxis(axis, softmax->dims()); Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d; softmax_2d.ShareDataWith(*softmax).Resize({n, d}); @@ -133,7 +134,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { Tensor* loss = context.Output("Loss"); const int rank = logits->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); + const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logits->dims()[axis]; PADDLE_ENFORCE_GT( axis_dim, 0, @@ -145,14 +146,14 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { softmax->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - const int n = SizeToAxis(axis, logits->dims()); + const int n = phi::funcs::SizeToAxis(axis, logits->dims()); PADDLE_ENFORCE_GT( n, 0, platform::errors::InvalidArgument( "The size of axis should be larger than 0, but received " "SizeToAxis of logits is %d.", n)); - const int d = SizeFromAxis(axis, logits->dims()); + const int d = phi::funcs::SizeFromAxis(axis, logits->dims()); Tensor logits_2d, softmax_2d, labels_2d, loss_2d; logits_2d.ShareDataWith(*logits).Resize({n, d}); softmax_2d.ShareDataWith(*softmax).Resize({n, d}); @@ -192,7 +193,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { auto ignore_index = context.Attr("ignore_index"); const int rank = logit_grad->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); + const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logit_grad->dims()[axis]; PADDLE_ENFORCE_GT( axis_dim, 0, @@ -201,14 +202,14 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { "axis dimention is %d.", axis_dim)); - const int n = SizeToAxis(axis, logit_grad->dims()); + const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims()); PADDLE_ENFORCE_GT( n, 0, platform::errors::InvalidArgument( "The size of axis should be larger than 0, but received " "SizeToAxis of logit_grad is %d.", n)); - const int d = SizeFromAxis(axis, logit_grad->dims()); + const int d = phi::funcs::SizeFromAxis(axis, logit_grad->dims()); Tensor logit_grad_2d, labels_2d, out_grad_2d; logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d}); labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n}); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc index a5576ab5af..1f1fbea090 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" + #include #include #include "paddle/fluid/operators/math/cross_entropy.h" -#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { @@ -40,15 +41,16 @@ class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel { "the npu kernel of softmax_with_cross_entropy.")); const int rank = logits->dims().size(); - const int axis = CanonicalAxis(ctx.Attr("axis"), rank); - const int n = SizeToAxis(axis, logits->dims()); - const int d = SizeFromAxis(axis, logits->dims()); + const int axis = phi::funcs::CanonicalAxis(ctx.Attr("axis"), rank); + const int n = phi::funcs::SizeToAxis(axis, logits->dims()); + const int d = phi::funcs::SizeFromAxis(axis, logits->dims()); PADDLE_ENFORCE_EQ( labels->numel(), n, platform::errors::Unimplemented( - "The size of labels should be equal to SizeToAxis of logits," - "but got size of labels is %d and SizeToAxis is %d.", + "The size of labels should be equal to phi::funcs::SizeToAxis of " + "logits," + "but got size of labels is %d and phi::funcs::SizeToAxis is %d.", labels->numel(), n)); loss->mutable_data(ctx.GetPlace()); @@ -97,9 +99,9 @@ class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel { logits_grad->mutable_data(ctx.GetPlace()); const int rank = logits_grad->dims().size(); - const int axis = CanonicalAxis(ctx.Attr("axis"), rank); - const int n = SizeToAxis(axis, logits_grad->dims()); - const int d = SizeFromAxis(axis, logits_grad->dims()); + const int axis = phi::funcs::CanonicalAxis(ctx.Attr("axis"), rank); + const int n = phi::funcs::SizeToAxis(axis, logits_grad->dims()); + const int d = phi::funcs::SizeFromAxis(axis, logits_grad->dims()); Tensor logits_grad_2d, loss_grad_1d, backprop_2d; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc index 650e488c5e..d9149b85c6 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc @@ -38,13 +38,13 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { Tensor* softmax = context.Output("Softmax"); Tensor* loss = context.Output("Loss"); const int rank = logits->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); + const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); PADDLE_ENFORCE_EQ(axis, rank - 1, platform::errors::InvalidArgument( "axis should == rank - 1")); softmax->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - const int n = SizeToAxis(axis, logits->dims()); - const int d = SizeFromAxis(axis, logits->dims()); + const int n = phi::funcs::SizeToAxis(axis, logits->dims()); + const int d = phi::funcs::SizeFromAxis(axis, logits->dims()); std::vector logits_dims = phi::vectorize(logits->dims()); const bool soft_label = context.Attr("soft_label"); @@ -122,11 +122,11 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel { auto ignore_index = context.Attr("ignore_index"); const int rank = logit_grad->dims().size(); - const int axis = CanonicalAxis(context.Attr("axis"), rank); + const int axis = phi::funcs::CanonicalAxis(context.Attr("axis"), rank); PADDLE_ENFORCE_EQ(axis, rank - 1, platform::errors::InvalidArgument( "axis should == rank - 1")); - const int n = SizeToAxis(axis, logit_grad->dims()); - const int d = SizeFromAxis(axis, logit_grad->dims()); + const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims()); + const int d = phi::funcs::SizeFromAxis(axis, logit_grad->dims()); auto& dev_ctx = context.template device_context(); diff --git a/paddle/fluid/operators/test_common_infer_shape_functions.cc b/paddle/fluid/operators/test_common_infer_shape_functions.cc index f04ba72a1e..a7c7e33f58 100644 --- a/paddle/fluid/operators/test_common_infer_shape_functions.cc +++ b/paddle/fluid/operators/test_common_infer_shape_functions.cc @@ -22,7 +22,7 @@ limitations under the License. */ USE_OP(relu); USE_OP_ITSELF(elementwise_add); -USE_OP(softmax); +USE_OP_ITSELF(softmax); namespace paddle { namespace operators { diff --git a/paddle/phi/backends/gpu/gpu_context.h b/paddle/phi/backends/gpu/gpu_context.h index 5fa80d3a57..603ce0817c 100644 --- a/paddle/phi/backends/gpu/gpu_context.h +++ b/paddle/phi/backends/gpu/gpu_context.h @@ -220,4 +220,11 @@ class GPUContext : public DeviceContext { std::unique_ptr impl_; }; +// Note: In order to register the kernel of CUDNN, GPUDNNContext is required. +// Currently, CUDNN kernel directly uses GPUContext. But if the kernel function +// has the same name, this will lead to duplicate instantiations of GPU kernel +// and GPUDNN kernel function, so if we using GPUDNNContext = GPUContext, we +// must use different function name for cudnn kernel +using GPUDNNContext = GPUContext; + } // namespace phi diff --git a/paddle/phi/common/backend.h b/paddle/phi/common/backend.h index 1d3e4369c6..4b7bf65be3 100644 --- a/paddle/phi/common/backend.h +++ b/paddle/phi/common/backend.h @@ -50,7 +50,7 @@ enum class Backend : uint8_t { // the third library backend MKLDNN, - CUDNN, + GPUDNN, // cuDNN and hipDNN // end of backend types NUM_BACKENDS, @@ -112,8 +112,8 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) { case Backend::MKLDNN: os << "MKLDNN"; break; - case Backend::CUDNN: - os << "CUDNN"; + case Backend::GPUDNN: + os << "GPUDNN"; break; default: { size_t device_type_id_ = static_cast(backend) - @@ -145,8 +145,8 @@ inline Backend StringToBackend(const char* backend_cstr) { return Backend::NPU; } else if (s == std::string("MKLDNN")) { return Backend::MKLDNN; - } else if (s == std::string("CUDNN")) { - return Backend::CUDNN; + } else if (s == std::string("GPUDNN")) { + return Backend::GPUDNN; } else { return static_cast(static_cast(Backend::NUM_BACKENDS) + phi::GetOrRegisterGlobalDeviceTypeId(s)); diff --git a/paddle/phi/common/float16.h b/paddle/phi/common/float16.h index 1cdcdef2c1..6ed9c88d70 100644 --- a/paddle/phi/common/float16.h +++ b/paddle/phi/common/float16.h @@ -988,6 +988,18 @@ inline std::ostream& operator<<(std::ostream& os, const float16& a) { return os; } +template +class MPTypeTrait { + public: + using Type = T; +}; + +template <> +class MPTypeTrait { + public: + using Type = float; +}; + } // namespace dtype } // namespace phi diff --git a/paddle/phi/core/compat/convert_utils.cc b/paddle/phi/core/compat/convert_utils.cc index a5b7b869b9..f7dab1d34c 100644 --- a/paddle/phi/core/compat/convert_utils.cc +++ b/paddle/phi/core/compat/convert_utils.cc @@ -58,7 +58,7 @@ phi::Place TransToPtenPlace(const Backend& backend, bool set_device_id) { return phi::CPUPlace(); #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - case phi::Backend::CUDNN: + case phi::Backend::GPUDNN: return phi::GPUPlace( set_device_id ? phi::backends::gpu::GetCurrentDeviceId() : 0); #endif diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 4a79f191c2..f27adf1de1 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -15,8 +15,15 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function i set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) -# auto build kernel targets by cmake -register_kernels(DEPS ${COMMON_KERNEL_DEPS}) +# NOTE: Some kernels depend on some targets that are not commonly used. +# These targets are not suitable for common dependencies. +# In this case, you need to manually generate them here. +set(MANUAL_BUILD_KERNELS softmax_kernel softmax_grad_kernel) +kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) +kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) + +# auto parse and build kernel targets by cmake +register_kernels(EXCLUDES ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS}) # phi sparse kernels add_subdirectory(sparse) diff --git a/paddle/phi/kernels/cpu/softmax_grad_kernel.cc b/paddle/phi/kernels/cpu/softmax_grad_kernel.cc new file mode 100644 index 0000000000..ef90f9c676 --- /dev/null +++ b/paddle/phi/kernels/cpu/softmax_grad_kernel.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/softmax_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/softmax_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + softmax_grad, CPU, ALL_LAYOUT, phi::SoftmaxGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/softmax_kernel.cc b/paddle/phi/kernels/cpu/softmax_kernel.cc new file mode 100644 index 0000000000..537b432668 --- /dev/null +++ b/paddle/phi/kernels/cpu/softmax_kernel.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/softmax_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/softmax_kernel_impl.h" + +PD_REGISTER_KERNEL( + softmax, CPU, ALL_LAYOUT, phi::SoftmaxRawKernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/axis_utils.h b/paddle/phi/kernels/funcs/axis_utils.h new file mode 100644 index 0000000000..02a8947188 --- /dev/null +++ b/paddle/phi/kernels/funcs/axis_utils.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/ddim.h" + +namespace phi { +namespace funcs { + +static inline int CanonicalAxis(const int axis, const int rank) { + if (axis < 0) { + return axis + rank; + } + return axis; +} + +static inline int SizeToAxis(const int axis, DDim dims) { + int size = 1; + for (int i = 0; i < axis; i++) { + size *= dims[i]; + } + return size; +} + +static inline int SizeFromAxis(const int axis, DDim dims) { + int size = 1; + for (int i = axis; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +static inline int SizeOutAxis(const int axis, DDim dims) { + int size = 1; + for (int i = axis + 1; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/concat_funcs.h b/paddle/phi/kernels/funcs/concat_funcs.h index 32237e2cc2..70e3545b98 100644 --- a/paddle/phi/kernels/funcs/concat_funcs.h +++ b/paddle/phi/kernels/funcs/concat_funcs.h @@ -92,4 +92,4 @@ static inline phi::DDim ComputeAndCheckShape( } } // namespace funcs -} // namespace phi +} // namespace phi diff --git a/paddle/phi/kernels/funcs/eigen/elementwise.cu b/paddle/phi/kernels/funcs/eigen/elementwise.cu index 96d2ddba03..3855ba8ccf 100644 --- a/paddle/phi/kernels/funcs/eigen/elementwise.cu +++ b/paddle/phi/kernels/funcs/eigen/elementwise.cu @@ -55,5 +55,5 @@ struct EigenSub { template struct EigenSub; -} // namespace fucns +} // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/softmax_grad_kernel.cu b/paddle/phi/kernels/gpu/softmax_grad_kernel.cu new file mode 100644 index 0000000000..aa496d3cd3 --- /dev/null +++ b/paddle/phi/kernels/gpu/softmax_grad_kernel.cu @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/softmax_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/softmax_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(softmax_grad, + GPU, + ALL_LAYOUT, + phi::SoftmaxGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/softmax_kernel.cu b/paddle/phi/kernels/gpu/softmax_kernel.cu new file mode 100644 index 0000000000..32efb9b776 --- /dev/null +++ b/paddle/phi/kernels/gpu/softmax_kernel.cu @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/softmax_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/softmax_kernel_impl.h" + +PD_REGISTER_KERNEL(softmax, + GPU, + ALL_LAYOUT, + phi::SoftmaxRawKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h similarity index 63% rename from paddle/fluid/operators/softmax_cudnn_op.cu.h rename to paddle/phi/kernels/gpudnn/softmax_gpudnn.h index dc5166f4f9..45798b88bb 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -14,18 +14,20 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" -#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/primitive/kernel_primitives.h" + +// See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -namespace paddle { -namespace operators { +namespace phi { -using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; -using DataLayout = platform::DataLayout; -using Tensor = framework::Tensor; +using ScopedTensorDescriptor = paddle::platform::ScopedTensorDescriptor; +using GPUDNNDataLayout = paddle::platform::DataLayout; // Vectorization trait 4 * sizeof(T) template @@ -41,7 +43,7 @@ class VecT4 { using Type = int4; }; template <> -class VecT4 { +class VecT4 { public: using Type = int2; }; @@ -60,7 +62,7 @@ class VecT2 { using Type = int2; }; template <> -class VecT2 { +class VecT2 { public: using Type = int; }; @@ -77,7 +79,8 @@ __device__ __forceinline__ void WarpReduceSum(T* sum) { for (int offset = WarpSize / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < BatchSize; ++i) { - T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); + T sum_val = + paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); sum[i] = sum[i] + sum_val; } } @@ -89,14 +92,13 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) { for (int offset = WarpSize / 2; offset > 0; offset /= 2) { #pragma unroll for (int i = 0; i < BatchSize; ++i) { - T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); + T max_val = + paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); sum[i] = max(sum[i], max_val); } } } -namespace kps = paddle::operators::kernel_primitives; - template struct ReduceMaxFunctor { inline Ty initial() { return -std::numeric_limits::infinity(); } @@ -248,10 +250,15 @@ One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle api to compute max (sum) in one warp. */ -template -__global__ void WarpSoftmaxForward(T* softmax, const T* src, - const int batch_size, const int stride, +__global__ void WarpSoftmaxForward(T* softmax, + const T* src, + const int batch_size, + const int stride, const int element_count) { constexpr int kDimCeil = 1 << Log2Elements; constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; @@ -302,9 +309,13 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, } // compute max - kps::Reduce, - kMode::kLocalMode>(&max[0], &srcdata[0][0][0], - ReduceMaxFunctor(), true); + kps::Reduce, + kMode::kLocalMode>( + &max[0], &srcdata[0][0][0], ReduceMaxFunctor(), true); WarpReduceMax(max); // compute sum @@ -313,9 +324,13 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, kps::ElementwiseUnary>( &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor(max[i])); } - kps::Reduce, - kMode::kLocalMode>(&sum[0], &srcdata[0][0][0], - kps::AddFunctor(), true); + kps::Reduce, + kMode::kLocalMode>( + &sum[0], &srcdata[0][0][0], kps::AddFunctor(), true); WarpReduceSum(sum); // write data to global memory @@ -340,10 +355,16 @@ One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle api to compute max (sum) in one warp. */ -template -__global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, - int batch_size, int stride, +__global__ void WarpSoftmaxBackward(T* dst, + const T* grad, + const T* src, + int batch_size, + int stride, int element_count) { constexpr int kVSize = sizeof(VecT) / sizeof(T); constexpr int kDimCeil = 1 << Log2Elements; @@ -403,7 +424,11 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, AccT* srcptr = reinterpret_cast(&src_tmp[0][0][0]); kps::ElementwiseBinary>( &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor()); - kps::Reduce, + kps::Reduce, kps::details::ReduceMode::kLocalMode>( &sum[0], &sum_tmp[0][0][0], kps::AddFunctor(), true); WarpReduceSum(sum); @@ -429,7 +454,10 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, #define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \ case Log2Elements: \ - WarpSoftmaxForward<<>>( \ dst, src, batch_size, stride, element_count); \ break; @@ -438,12 +466,16 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, Wrapper of softmax formward with template instantiation on size of input. */ template -void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads, - const platform::CUDADeviceContext& dev_ctx, - T* dst, const T* src, const int batch_size, - const int stride, const int element_count, +void SwitchWarpSoftmaxForward(const int blocks, + const dim3 threads, + const GPUContext& dev_ctx, + T* dst, + const T* src, + const int batch_size, + const int stride, + const int element_count, int Log2Elements) { - using AccT = typename details::MPTypeTrait::Type; + using AccT = typename phi::dtype::MPTypeTrait::Type; switch (Log2Elements) { SOFTMAX_WARP_FORWARD_CASE(0, AccT); SOFTMAX_WARP_FORWARD_CASE(1, AccT); @@ -462,7 +494,10 @@ void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads, #define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \ case Log2Elements: \ - WarpSoftmaxBackward<<>>( \ dst, grad, src, batch_size, stride, element_count); \ break; @@ -471,12 +506,17 @@ void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads, Wrapper of softmax backward with template instantiation on size of input. */ template -void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads, - const platform::CUDADeviceContext& dev_ctx, - T* dst, const T* grad, const T* src, - const int batch_size, const int stride, - const int element_count, int Log2Elements) { - using AccT = typename details::MPTypeTrait::Type; +void SwitchWarpSoftmaxBackward(const int blocks, + const dim3 threads, + const GPUContext& dev_ctx, + T* dst, + const T* grad, + const T* src, + const int batch_size, + const int stride, + const int element_count, + int Log2Elements) { + using AccT = typename phi::dtype::MPTypeTrait::Type; switch (Log2Elements) { SOFTMAX_WARP_BACKWARD_CASE(0, AccT); SOFTMAX_WARP_BACKWARD_CASE(1, AccT); @@ -501,12 +541,12 @@ void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads, * Better performence when axis != -1 */ -static void GetGridDim(int high_dim, int mid_dim, int low_dim, - const dim3& block, dim3* grid) { - int device_id = paddle::platform::GetCurrentDeviceId(); - int max_mp = paddle::platform::GetGPUMultiProcessors(device_id); +static void GetGridDim( + int high_dim, int mid_dim, int low_dim, const dim3& block, dim3* grid) { + int device_id = phi::backends::gpu::GetCurrentDeviceId(); + int max_mp = phi::backends::gpu::GetGPUMultiProcessors(device_id); int max_threads_per_mp = - paddle::platform::GetGPUMaxThreadsPerMultiProcessor(device_id); + phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor(device_id); int max_threads = max_threads_per_mp * max_mp; int num_threads = block.x * block.y; int max_num_blocks = max_threads / num_threads; @@ -532,16 +572,17 @@ static void GetBlockDim(int mid_dim, int low_dim, dim3* block) { block->x = std::min(block_x, static_cast(max_num_threads / block->y)); } -static void GetLaunchConfig(int high_dim, int mid_dim, int low_dim, dim3* grid, - dim3* block) { +static void GetLaunchConfig( + int high_dim, int mid_dim, int low_dim, dim3* grid, dim3* block) { GetBlockDim(mid_dim, low_dim, block); GetGridDim(high_dim, mid_dim, low_dim, *block, grid); } -template class Functor> -__global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim, - int mid_dim, int low_dim) { +__global__ void NormalSoftmaxForward( + T* output, const T* input, int high_dim, int mid_dim, int low_dim) { using kMode = kps::details::ReduceMode; const int high_stride = mid_dim * low_dim; const int mid_stride = low_dim; @@ -584,11 +625,15 @@ __global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim, } } -template class Functor> -__global__ void NormalSoftmaxBackward(T* input_grad, const T* output_grad, - const T* output, int high_dim, - int mid_dim, int low_dim) { +__global__ void NormalSoftmaxBackward(T* input_grad, + const T* output_grad, + const T* output, + int high_dim, + int mid_dim, + int low_dim) { using kMode = kps::details::ReduceMode; const int high_stride = mid_dim * low_dim; const int mid_stride = low_dim; @@ -622,58 +667,79 @@ __global__ void NormalSoftmaxBackward(T* input_grad, const T* output_grad, } template -void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx, - T* output_data, const T* input_data, - int high_dim, int mid_dim, int low_dim) { - using AccT = typename details::MPTypeTrait::Type; +void LaunchNormalSoftmaxForward(const GPUContext& dev_ctx, + T* output_data, + const T* input_data, + int high_dim, + int mid_dim, + int low_dim) { + using AccT = typename phi::dtype::MPTypeTrait::Type; dim3 grid, block; GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block); if (LogMode) { NormalSoftmaxForward< - T, AccT, + T, + AccT, LogSoftmaxForwardFunctor><<>>( output_data, input_data, high_dim, mid_dim, low_dim); } else { NormalSoftmaxForward< - T, AccT, SoftmaxForwardFunctor><<>>( + T, + AccT, + SoftmaxForwardFunctor><<>>( output_data, input_data, high_dim, mid_dim, low_dim); } } template -void LaunchNormalSoftmaxBackward(const platform::CUDADeviceContext& dev_ctx, - T* input_grad_data, const T* output_grad_data, - const T* output_data, int high_dim, - int mid_dim, int low_dim) { - using AccT = typename details::MPTypeTrait::Type; +void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx, + T* input_grad_data, + const T* output_grad_data, + const T* output_data, + int high_dim, + int mid_dim, + int low_dim) { + using AccT = typename phi::dtype::MPTypeTrait::Type; dim3 grid, block; GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block); if (LogMode) { NormalSoftmaxBackward< - T, AccT, + T, + AccT, LogSoftmaxBackwardFunctor><<>>( - input_grad_data, output_grad_data, output_data, high_dim, mid_dim, + input_grad_data, + output_grad_data, + output_data, + high_dim, + mid_dim, low_dim); } else { NormalSoftmaxBackward< - T, AccT, SoftmaxBackwardFunctor><<>>( - input_grad_data, output_grad_data, output_data, high_dim, mid_dim, + T, + AccT, + SoftmaxBackwardFunctor><<>>( + input_grad_data, + output_grad_data, + output_data, + high_dim, + mid_dim, low_dim); } } template -void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, - const Tensor& x, const int input_axis, - Tensor* out) { +void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, + const DenseTensor& x, + const int input_axis, + DenseTensor* out) { auto* out_data = out->data(); auto dims = x.dims(); const int rank = dims.size(); - const int axis = CanonicalAxis(input_axis, rank); + const int axis = phi::funcs::CanonicalAxis(input_axis, rank); const int dim = dims[axis]; - const int N = SizeToAxis(axis, dims); - const int D = SizeOutAxis(axis, dims); + const int N = phi::funcs::SizeToAxis(axis, dims); + const int D = phi::funcs::SizeOutAxis(axis, dims); constexpr int max_dim = 512; constexpr int warps_per_block = 4; @@ -697,25 +763,43 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, using T2 = typename VecT2::Type; if (dim % 4 == 0) { - SwitchWarpSoftmaxForward(blocks, threads, dev_ctx, - out_data, x.data(), N, dim, - dim, kDimLog2); + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + kDimLog2); } else if (dim % 2 == 0) { - SwitchWarpSoftmaxForward(blocks, threads, dev_ctx, - out_data, x.data(), N, dim, - dim, kDimLog2); + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + kDimLog2); } else { - SwitchWarpSoftmaxForward(blocks, threads, dev_ctx, - out_data, x.data(), N, dim, - dim, kDimLog2); + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + kDimLog2); } } else if (D > 1) { - LaunchNormalSoftmaxForward(dev_ctx, out_data, x.data(), N, - dim, D); + LaunchNormalSoftmaxForward( + dev_ctx, out_data, x.data(), N, dim, D); } else { ScopedTensorDescriptor desc; std::vector tensor_dims = {N, dim, D, 1}; - DataLayout layout = DataLayout::kNCHW; + GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; #ifdef PADDLE_WITH_HIP miopenTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); #else @@ -728,46 +812,74 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE : MIOPEN_SOFTMAX_MODE_CHANNEL; if (LogMode) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( - handle, platform::CudnnDataType::kOne(), desc_, x.data(), - platform::CudnnDataType::kZero(), desc_, out_data, - MIOPEN_SOFTMAX_LOG, mode)); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::miopenSoftmaxForward_V2( + handle, + paddle::platform::CudnnDataType::kOne(), + desc_, + x.data(), + paddle::platform::CudnnDataType::kZero(), + desc_, + out_data, + MIOPEN_SOFTMAX_LOG, + mode)); } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( - handle, platform::CudnnDataType::kOne(), desc_, x.data(), - platform::CudnnDataType::kZero(), desc_, out_data, - MIOPEN_SOFTMAX_ACCURATE, mode)); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::miopenSoftmaxForward_V2( + handle, + paddle::platform::CudnnDataType::kOne(), + desc_, + x.data(), + paddle::platform::CudnnDataType::kZero(), + desc_, + out_data, + MIOPEN_SOFTMAX_ACCURATE, + mode)); } #else auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL; if (LogMode) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward( - handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), - desc_, x.data(), platform::CudnnDataType::kZero(), desc_, + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cudnnSoftmaxForward( + handle, + CUDNN_SOFTMAX_LOG, + mode, + paddle::platform::CudnnDataType::kOne(), + desc_, + x.data(), + paddle::platform::CudnnDataType::kZero(), + desc_, out_data)); } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward( - handle, CUDNN_SOFTMAX_ACCURATE, mode, - platform::CudnnDataType::kOne(), desc_, x.data(), - platform::CudnnDataType::kZero(), desc_, out_data)); + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cudnnSoftmaxForward( + handle, + CUDNN_SOFTMAX_ACCURATE, + mode, + paddle::platform::CudnnDataType::kOne(), + desc_, + x.data(), + paddle::platform::CudnnDataType::kZero(), + desc_, + out_data)); } #endif } } template -void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, - const Tensor& out, const Tensor& dout, - const int input_axis, Tensor* dx) { +void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx, + const DenseTensor& out, + const DenseTensor& dout, + const int input_axis, + DenseTensor* dx) { auto* dx_data = dx->data(); auto dims = out.dims(); const int rank = dims.size(); - const int axis = CanonicalAxis(input_axis, rank); + const int axis = phi::funcs::CanonicalAxis(input_axis, rank); const int dim = dims[axis]; - const int N = SizeToAxis(axis, dims); - const int D = SizeOutAxis(axis, dims); + const int N = phi::funcs::SizeToAxis(axis, dims); + const int D = phi::funcs::SizeOutAxis(axis, dims); constexpr int max_dim = 512; constexpr int warps_per_block = 4; @@ -788,25 +900,46 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, using T4 = typename VecT4::Type; using T2 = typename VecT2::Type; if (dim % 4 == 0) { - SwitchWarpSoftmaxBackward( - blocks, threads, dev_ctx, dx_data, dout.data(), out.data(), N, - dim, dim, kDimLog2); + SwitchWarpSoftmaxBackward(blocks, + threads, + dev_ctx, + dx_data, + dout.data(), + out.data(), + N, + dim, + dim, + kDimLog2); } else if (dim % 2 == 0) { - SwitchWarpSoftmaxBackward( - blocks, threads, dev_ctx, dx_data, dout.data(), out.data(), N, - dim, dim, kDimLog2); + SwitchWarpSoftmaxBackward(blocks, + threads, + dev_ctx, + dx_data, + dout.data(), + out.data(), + N, + dim, + dim, + kDimLog2); } else { - SwitchWarpSoftmaxBackward( - blocks, threads, dev_ctx, dx_data, dout.data(), out.data(), N, - dim, dim, kDimLog2); + SwitchWarpSoftmaxBackward(blocks, + threads, + dev_ctx, + dx_data, + dout.data(), + out.data(), + N, + dim, + dim, + kDimLog2); } } else if (D > 1) { - LaunchNormalSoftmaxBackward(dev_ctx, dx_data, dout.data(), - out.data(), N, dim, D); + LaunchNormalSoftmaxBackward( + dev_ctx, dx_data, dout.data(), out.data(), N, dim, D); } else { ScopedTensorDescriptor desc; std::vector tensor_dims = {N, dim, D, 1}; - DataLayout layout = DataLayout::kNCHW; + GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; #ifdef PADDLE_WITH_HIP miopenTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); #else @@ -819,33 +952,68 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE : MIOPEN_SOFTMAX_MODE_CHANNEL; if (LogMode) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( - handle, platform::CudnnDataType::kOne(), desc_, out.data(), - desc_, dout.data(), platform::CudnnDataType::kZero(), desc_, - dx_data, MIOPEN_SOFTMAX_LOG, mode)); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::miopenSoftmaxBackward_V2( + handle, + paddle::platform::CudnnDataType::kOne(), + desc_, + out.data(), + desc_, + dout.data(), + paddle::platform::CudnnDataType::kZero(), + desc_, + dx_data, + MIOPEN_SOFTMAX_LOG, + mode)); } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( - handle, platform::CudnnDataType::kOne(), desc_, out.data(), - desc_, dout.data(), platform::CudnnDataType::kZero(), desc_, - dx_data, MIOPEN_SOFTMAX_ACCURATE, mode)); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::miopenSoftmaxBackward_V2( + handle, + paddle::platform::CudnnDataType::kOne(), + desc_, + out.data(), + desc_, + dout.data(), + paddle::platform::CudnnDataType::kZero(), + desc_, + dx_data, + MIOPEN_SOFTMAX_ACCURATE, + mode)); } #else auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL; if (LogMode) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxBackward( - handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), - desc_, out.data(), desc_, dout.data(), - platform::CudnnDataType::kZero(), desc_, dx_data)); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cudnnSoftmaxBackward( + handle, + CUDNN_SOFTMAX_LOG, + mode, + paddle::platform::CudnnDataType::kOne(), + desc_, + out.data(), + desc_, + dout.data(), + paddle::platform::CudnnDataType::kZero(), + desc_, + dx_data)); } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxBackward( - handle, CUDNN_SOFTMAX_ACCURATE, mode, - platform::CudnnDataType::kOne(), desc_, out.data(), desc_, - dout.data(), platform::CudnnDataType::kZero(), desc_, dx_data)); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cudnnSoftmaxBackward( + handle, + CUDNN_SOFTMAX_ACCURATE, + mode, + paddle::platform::CudnnDataType::kOne(), + desc_, + out.data(), + desc_, + dout.data(), + paddle::platform::CudnnDataType::kZero(), + desc_, + dx_data)); } #endif } } -} // namespace operators -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/kernels/gpudnn/softmax_grad_kernel_gpudnn.cu b/paddle/phi/kernels/gpudnn/softmax_grad_kernel_gpudnn.cu new file mode 100644 index 0000000000..56e5fef6e3 --- /dev/null +++ b/paddle/phi/kernels/gpudnn/softmax_grad_kernel_gpudnn.cu @@ -0,0 +1,50 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/softmax_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" + +namespace phi { + +template +void SoftmaxGradGPUDNNKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad) { + dev_ctx.template Alloc(x_grad); + SoftmaxBackwardCUDAKernelDriver(dev_ctx, out, out_grad, axis, x_grad); +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(softmax_grad, + GPUDNN, + ALL_LAYOUT, + phi::SoftmaxGradGPUDNNKernel, + float, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(softmax_grad, + GPUDNN, + ALL_LAYOUT, + phi::SoftmaxGradGPUDNNKernel, + float, + double, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/gpudnn/softmax_kernel_gpudnn.cu b/paddle/phi/kernels/gpudnn/softmax_kernel_gpudnn.cu new file mode 100644 index 0000000000..427d1729a1 --- /dev/null +++ b/paddle/phi/kernels/gpudnn/softmax_kernel_gpudnn.cu @@ -0,0 +1,49 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/softmax_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" + +namespace phi { + +template +void SoftmaxRawGPUDNNKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + DenseTensor* out) { + dev_ctx.template Alloc(out); + SoftmaxForwardCUDAKernelDriver(dev_ctx, x, axis, out); +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(softmax, + GPUDNN, + ALL_LAYOUT, + phi::SoftmaxRawGPUDNNKernel, + float, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(softmax, + GPUDNN, + ALL_LAYOUT, + phi::SoftmaxRawGPUDNNKernel, + float, + double, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h b/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h new file mode 100644 index 0000000000..915bf16a92 --- /dev/null +++ b/paddle/phi/kernels/impl/softmax_grad_kernel_impl.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/softmax_grad_kernel.h" + +#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" + +namespace phi { + +template +void SoftmaxGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad) { + const int rank = x_grad->dims().size(); + const int calc_axis = phi::funcs::CanonicalAxis(axis, rank); + int axis_dim = x_grad->dims()[calc_axis]; + + // allocate memory on device. + dev_ctx.template Alloc(x_grad); + if (x_grad->numel() == 0) { + return; + } + + const int n = phi::funcs::SizeToAxis(calc_axis, x_grad->dims()); + const int d = phi::funcs::SizeFromAxis(calc_axis, x_grad->dims()); + DenseTensor dX_2d, Out_2d, dOut_2d; + dX_2d.ShareDataWith(*x_grad).Resize({n, d}); + Out_2d.ShareDataWith(out).Resize({n, d}); + dOut_2d.ShareDataWith(out_grad).Resize({n, d}); + + paddle::operators::math::SoftmaxGradFunctor()( + dev_ctx, axis_dim, &Out_2d, &dOut_2d, &dX_2d); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/softmax_kernel_impl.h b/paddle/phi/kernels/impl/softmax_kernel_impl.h new file mode 100644 index 0000000000..6552f6ed58 --- /dev/null +++ b/paddle/phi/kernels/impl/softmax_kernel_impl.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/softmax_kernel.h" + +#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" + +namespace phi { + +template +void SoftmaxRawKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + DenseTensor* out) { + const int rank = x.dims().size(); + const int calc_axis = phi::funcs::CanonicalAxis(axis, rank); + int axis_dim = x.dims()[calc_axis]; + + // allocate memory on device. + dev_ctx.template Alloc(out); + if (out->numel() == 0) { + return; + } + + const int n = phi::funcs::SizeToAxis(calc_axis, x.dims()); + const int d = phi::funcs::SizeFromAxis(calc_axis, x.dims()); + DenseTensor X_2d, Out_2d; + X_2d.ShareDataWith(x).Resize({n, d}); + Out_2d.ShareDataWith(*out).Resize({n, d}); + paddle::operators::math::SoftmaxFunctor()( + dev_ctx, axis_dim, &X_2d, &Out_2d); +} + +} // namespace phi diff --git a/paddle/phi/kernels/softmax_grad_kernel.h b/paddle/phi/kernels/softmax_grad_kernel.h new file mode 100644 index 0000000000..4ecf65c1f1 --- /dev/null +++ b/paddle/phi/kernels/softmax_grad_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/cast_kernel.h" + +namespace phi { + +template +void SoftmaxGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/softmax_kernel.h b/paddle/phi/kernels/softmax_kernel.h new file mode 100644 index 0000000000..ca69d65277 --- /dev/null +++ b/paddle/phi/kernels/softmax_kernel.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/cast_kernel.h" + +namespace phi { + +template +void SoftmaxRawKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + DenseTensor* out); + +template +void SoftmaxKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + DataType dtype, + DenseTensor* out) { + auto cast_x = phi::Cast(dev_ctx, x, dtype); + phi::SoftmaxRawKernel(dev_ctx, axis, out); +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/softmax_sig.cc b/paddle/phi/ops/compat/softmax_sig.cc new file mode 100644 index 0000000000..65a915b51d --- /dev/null +++ b/paddle/phi/ops/compat/softmax_sig.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SoftmaxOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("softmax", {"X"}, {"axis"}, {"Out"}); +} + +KernelSignature SoftmaxGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("softmax_grad", + {"Out", GradVarName("Out")}, + {"axis"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(softmax, phi::SoftmaxOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(softmax_grad, phi::SoftmaxGradOpArgumentMapping); diff --git a/paddle/phi/tests/common/test_backend.cc b/paddle/phi/tests/common/test_backend.cc index d74a35c9ea..fa4ffc84bf 100644 --- a/paddle/phi/tests/common/test_backend.cc +++ b/paddle/phi/tests/common/test_backend.cc @@ -41,8 +41,8 @@ TEST(Backend, OStream) { oss << phi::Backend::MKLDNN; EXPECT_EQ(oss.str(), "MKLDNN"); oss.str(""); - oss << phi::Backend::CUDNN; - EXPECT_EQ(oss.str(), "CUDNN"); + oss << phi::Backend::GPUDNN; + EXPECT_EQ(oss.str(), "GPUDNN"); oss.str(""); try { oss << phi::Backend::NUM_BACKENDS; @@ -60,7 +60,7 @@ TEST(Backend, StringToBackend) { EXPECT_EQ(phi::Backend::XPU, pexp::StringToBackend("XPU")); EXPECT_EQ(phi::Backend::NPU, pexp::StringToBackend("NPU")); EXPECT_EQ(phi::Backend::MKLDNN, pexp::StringToBackend("MKLDNN")); - EXPECT_EQ(phi::Backend::CUDNN, pexp::StringToBackend("CUDNN")); + EXPECT_EQ(phi::Backend::GPUDNN, pexp::StringToBackend("GPUDNN")); EXPECT_EQ(static_cast( static_cast(phi::Backend::NUM_BACKENDS) + 1), pexp::StringToBackend("CustomBackend")); -- GitLab