Unverified commit 8895379a, authored by Chen Weihang, committed by GitHub

[Phi] Support cudnn kernel moving & move softmax kernels (#39547)

* support cudnn kernel moving

* polish cmake rules

* add unittest for coverage

* remove orig kernel

* remove softmax cudnn kernel

* fix softmax test failed

* fix npu func error

* resolve conflict

* rename gpu dnn kernels

* fix name rule error

* fix compile error

* update fp16 namespace
Parent fed6de40
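For orientation before the diff: after this change, a cudnn-backed kernel lives under paddle/phi/kernels/gpudnn/ and is registered against the new GPUDNN backend. A minimal sketch of such a registration, assuming phi's kernel registry macros and a distinct GPUDNN function name (the dtype list and exact signature here are illustrative, not lifted from this commit):

    // Sketch only: a phi softmax kernel registered for the GPUDNN backend.
    // SoftmaxGPUDNNKernel is a distinct name because GPUDNNContext aliases
    // GPUContext (see the note added to gpu_context.h at the end of this diff).
    namespace phi {
    template <typename T, typename Context>
    void SoftmaxGPUDNNKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             int axis,
                             DenseTensor* out) {
      dev_ctx.template Alloc<T>(out);
      SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, x, axis, out);
    }
    }  // namespace phi

    PD_REGISTER_KERNEL(softmax, GPUDNN, ALL_LAYOUT,
                       phi::SoftmaxGPUDNNKernel,
                       float, double, phi::dtype::float16) {}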
@@ -81,6 +81,8 @@ function(kernel_declare TARGET_LIST)
       file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, GPU, ALL_LAYOUT);\n")
     elseif (${kernel_path} MATCHES "./xpu\/")
       file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n")
+    elseif (${kernel_path} MATCHES "./gpudnn\/")
+      file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, GPUDNN, ALL_LAYOUT);\n")
     else ()
       # deal with device independent kernel, now we use CPU temporary
       file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
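For illustration: given a kernel source at gpudnn/softmax_gpudnn.cu that registers a kernel named softmax, the new branch appends the declaration

    PD_DECLARE_KERNEL(softmax, GPUDNN, ALL_LAYOUT);

to the generated declarations file, mirroring the existing GPU and XPU branches.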
@@ -94,6 +96,7 @@ function(kernel_library TARGET)
   set(cpu_srcs)
   set(gpu_srcs)
   set(xpu_srcs)
+  set(gpudnn_srcs)
   set(selected_rows_srcs)
   # parse and save the deps kernel targets
   set(all_srcs)
@@ -101,6 +104,8 @@ function(kernel_library TARGET)
   set(oneValueArgs SUB_DIR)
   set(multiValueArgs SRCS DEPS)

+  set(target_build_flag 1)
+
   cmake_parse_arguments(kernel_library "${options}" "${oneValueArgs}"
                         "${multiValueArgs}" ${ARGN})
@@ -123,6 +128,9 @@ function(kernel_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc)
       list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc)
     endif()
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu)
+      list(APPEND gpudnn_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu)
+    endif()
   endif()
   if (WITH_XPU)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/xpu/${TARGET}.cc)
@@ -141,6 +149,7 @@ function(kernel_library TARGET)
   list(APPEND all_srcs ${cpu_srcs})
   list(APPEND all_srcs ${gpu_srcs})
   list(APPEND all_srcs ${xpu_srcs})
+  list(APPEND all_srcs ${gpudnn_srcs})
   foreach(src ${all_srcs})
     file(READ ${src} target_content)
     string(REGEX MATCHALL "#include \"paddle\/phi\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content})
@@ -166,21 +175,22 @@ function(kernel_library TARGET)
   list(LENGTH cpu_srcs cpu_srcs_len)
   list(LENGTH gpu_srcs gpu_srcs_len)
   list(LENGTH xpu_srcs xpu_srcs_len)
+  list(LENGTH gpudnn_srcs gpudnn_srcs_len)
   list(LENGTH selected_rows_srcs selected_rows_srcs_len)

   # Build Target according to different src organization
   if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
-      ${xpu_srcs_len} GREATER 0) AND (${common_srcs_len} GREATER 0 OR
-      ${selected_rows_srcs_len} GREATER 0))
+      ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) AND
+      (${common_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0))
     # If the common_srcs/selected_rows_srcs depends on specific device srcs, build target using this rule.
     if (WITH_GPU)
-      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-        nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
+        nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
         nv_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
       endif()
     elseif (WITH_ROCM)
-      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-        hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
+        hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
         hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
       endif()
     else()
@@ -190,14 +200,14 @@ function(kernel_library TARGET)
       endif()
     endif()
   # If there are only specific device srcs, build target using this rule.
-  elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
+  elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
     if (WITH_GPU)
-      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-        nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
+        nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
       endif()
     elseif (WITH_ROCM)
-      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
-        hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
+      if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
+        hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
       endif()
     else()
       if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
@@ -234,35 +244,40 @@ function(kernel_library TARGET)
       cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
     endif()
   else()
-    message(FATAL_ERROR "Cannot find any implementation for ${TARGET}")
+    set(target_build_flag 0)
   endif()

-  if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
-      ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
-      ${selected_rows_srcs_len} GREATER 0)
-    # append target into PHI_KERNELS property
-    get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
-    set(phi_kernels ${phi_kernels} ${TARGET})
-    set_property(GLOBAL PROPERTY PHI_KERNELS ${phi_kernels})
-  endif()
+  if (${target_build_flag} EQUAL 1)
+    if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
+        ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
+        ${gpudnn_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0)
+      # append target into PHI_KERNELS property
+      get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
+      set(phi_kernels ${phi_kernels} ${TARGET})
+      set_property(GLOBAL PROPERTY PHI_KERNELS ${phi_kernels})
+    endif()

-  # parse kernel name and auto generate kernel declaration
-  # here, we don't need to check WITH_XXX, because if not WITH_XXX, the
-  # xxx_srcs_len will be equal to 0
-  if (${common_srcs_len} GREATER 0)
-    kernel_declare(${common_srcs})
-  endif()
-  if (${cpu_srcs_len} GREATER 0)
-    kernel_declare(${cpu_srcs})
-  endif()
-  if (${gpu_srcs_len} GREATER 0)
-    kernel_declare(${gpu_srcs})
-  endif()
-  if (${xpu_srcs_len} GREATER 0)
-    kernel_declare(${xpu_srcs})
-  endif()
-  if (${selected_rows_srcs_len} GREATER 0)
-    kernel_declare(${selected_rows_srcs})
-  endif()
+    # parse kernel name and auto generate kernel declaration
+    # here, we don't need to check WITH_XXX, because if not WITH_XXX, the
+    # xxx_srcs_len will be equal to 0
+    if (${common_srcs_len} GREATER 0)
+      kernel_declare(${common_srcs})
+    endif()
+    if (${cpu_srcs_len} GREATER 0)
+      kernel_declare(${cpu_srcs})
+    endif()
+    if (${gpu_srcs_len} GREATER 0)
+      kernel_declare(${gpu_srcs})
+    endif()
+    if (${xpu_srcs_len} GREATER 0)
+      kernel_declare(${xpu_srcs})
+    endif()
+    if (${gpudnn_srcs_len} GREATER 0)
+      kernel_declare(${gpudnn_srcs})
+    endif()
+    if (${selected_rows_srcs_len} GREATER 0)
+      kernel_declare(${selected_rows_srcs})
+    endif()
+  endif()
 endfunction()
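One behavioral change worth spelling out: kernel_library() no longer hard-fails when a target ends up with no buildable sources; target_build_flag simply skips registration. A sketch of the situation this covers (the target name and DEPS are hypothetical):

    # Hypothetical target whose only implementation lives under xpu/, in a
    # build configured with WITH_XPU=OFF:
    #   before: message(FATAL_ERROR "Cannot find any implementation for foo_kernel")
    #   after:  target_build_flag stays 0 and the target is skipped quietly.
    kernel_library(foo_kernel DEPS dense_tensor)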
...
@@ -20,7 +20,7 @@
 #include "paddle/fluid/framework/ir/pass_tester_helper.h"
 #include "paddle/fluid/framework/op_registry.h"

-USE_OP(softmax);
+USE_OP_ITSELF(softmax);
 USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
 USE_OP_ITSELF(elementwise_add);
 USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
...
@@ -67,7 +67,7 @@ OpKernelType TransPtenKernelKeyToOpKernelType(
   LibraryType library_type = LibraryType::kPlain;
   if (kernel_key.backend() == phi::Backend::MKLDNN) {
     library_type = LibraryType::kMKLDNN;
-  } else if (kernel_key.backend() == phi::Backend::CUDNN) {
+  } else if (kernel_key.backend() == phi::Backend::GPUDNN) {
     library_type = LibraryType::kCUDNN;
   } else {
     // do nothing
@@ -82,7 +82,7 @@ phi::KernelKey TransOpKernelTypeToPtenKernelKey(
   if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
     backend = phi::Backend::MKLDNN;
   } else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
-    backend = phi::Backend::CUDNN;
+    backend = phi::Backend::GPUDNN;
   } else {
     // do nothing
   }
...
@@ -42,7 +42,7 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) {
 #endif
 #ifdef PADDLE_WITH_CUDA
-  phi::KernelKey kernel_key_cudnn(phi::Backend::CUDNN, phi::DataLayout::NCHW,
+  phi::KernelKey kernel_key_cudnn(phi::Backend::GPUDNN, phi::DataLayout::NCHW,
                                   phi::DataType::FLOAT32);
   op_kernel_type =
       paddle::framework::TransPtenKernelKeyToOpKernelType(kernel_key_cudnn);
@@ -53,3 +53,38 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) {
             paddle::framework::LibraryType::kCUDNN);
 #endif
 }
+
+TEST(PtenUtils, TransOpKernelTypeToPtenKernelKey) {
+  paddle::framework::OpKernelType op_kernel_type(
+      paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(),
+      paddle::framework::DataLayout::kNCHW);
+  auto kernel_key =
+      paddle::framework::TransOpKernelTypeToPtenKernelKey(op_kernel_type);
+  ASSERT_EQ(kernel_key.dtype(), phi::DataType::FLOAT32);
+  ASSERT_EQ(kernel_key.layout(), phi::DataLayout::NCHW);
+  ASSERT_EQ(kernel_key.backend(), phi::Backend::CPU);
+
+#ifdef PADDLE_WITH_MKLDNN
+  paddle::framework::OpKernelType op_kernel_type_mkldnn(
+      paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(),
+      paddle::framework::DataLayout::kMKLDNN,
+      paddle::framework::LibraryType::kMKLDNN);
+  auto kernel_key_mkldnn = paddle::framework::TransOpKernelTypeToPtenKernelKey(
+      op_kernel_type_mkldnn);
+  ASSERT_EQ(kernel_key_mkldnn.dtype(), phi::DataType::FLOAT32);
+  ASSERT_EQ(kernel_key_mkldnn.layout(), phi::DataLayout::MKLDNN);
+  ASSERT_EQ(kernel_key_mkldnn.backend(), phi::Backend::MKLDNN);
+#endif
+
+#ifdef PADDLE_WITH_CUDA
+  paddle::framework::OpKernelType op_kernel_type_cudnn(
+      paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(),
+      paddle::framework::DataLayout::kNCHW,
+      paddle::framework::LibraryType::kCUDNN);
+  auto kernel_key_cudnn =
+      paddle::framework::TransOpKernelTypeToPtenKernelKey(op_kernel_type_cudnn);
+  ASSERT_EQ(kernel_key_cudnn.dtype(), phi::DataType::FLOAT32);
+  ASSERT_EQ(kernel_key_cudnn.layout(), phi::DataLayout::NCHW);
+  ASSERT_EQ(kernel_key_cudnn.backend(), phi::Backend::GPUDNN);
+#endif
+}
...
@@ -88,5 +88,5 @@ class SoftMaxOpConverter : public OpConverter {
 }  // namespace inference
 }  // namespace paddle

-USE_OP(softmax);
+USE_OP_ITSELF(softmax);
 REGISTER_TRT_OP_CONVERTER(softmax, SoftMaxOpConverter);
@@ -45,4 +45,4 @@ TEST(SoftMaxOpConverter, main) {
 }  // namespace inference
 }  // namespace paddle

-USE_OP(softmax);
+USE_OP_ITSELF(softmax);
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #include "paddle/fluid/string/string_helper.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"

 namespace paddle {
 namespace operators {
@@ -98,8 +99,8 @@ class CSoftmaxWithCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
     const auto& labels_dims = labels->dims();

     const int axis = logits_dims.size() - 1;
-    const int N = SizeToAxis(axis, logits_dims);
-    const int D = SizeFromAxis(axis, logits_dims);
+    const int N = phi::funcs::SizeToAxis(axis, logits_dims);
+    const int D = phi::funcs::SizeFromAxis(axis, logits_dims);

     Tensor logits_2d, softmax_2d, loss_2d;
     logits_2d.ShareDataWith(*logits).Resize({N, D});
@@ -220,8 +221,8 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     }
     const auto sofrmax_dims = softmax->dims();

     const int axis = sofrmax_dims.size() - 1;
-    const int N = SizeToAxis(axis, sofrmax_dims);
-    const int D = SizeFromAxis(axis, sofrmax_dims);
+    const int N = phi::funcs::SizeToAxis(axis, sofrmax_dims);
+    const int D = phi::funcs::SizeFromAxis(axis, sofrmax_dims);

     Tensor logit_grad_2d;
     logit_grad_2d.ShareDataWith(*logit_grad).Resize({N, D});
...
@@ -23,7 +23,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/softmax.h"
-#include "paddle/fluid/operators/softmax_op.h"

 namespace paddle {
 namespace operators {
...
@@ -14,8 +14,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/dropout_impl.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/operators/softmax_cudnn_op.cu.h"
 #include "paddle/fluid/operators/transpose_op.cu.h"
+#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

 namespace paddle {
 namespace operators {
@@ -123,11 +123,11 @@ class FMHARef {
                                    T, T>(
           dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
-      SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
-                                        softmax_axis, softmax_out_tensor);
+      phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
+                                             softmax_axis, softmax_out_tensor);
     } else {
-      SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out_tensor, softmax_axis,
-                                        softmax_out_tensor);
+      phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out_tensor,
+                                             softmax_axis, softmax_out_tensor);
     }
     transB = CblasNoTrans;
@@ -251,9 +251,9 @@ class FMHARef {
     }
     if (src_mask_tensor != nullptr) {
-      SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
-                                         *softmax_out_grad_tensor, softmax_axis,
-                                         src_mask_out_grad_tensor);
+      phi::SoftmaxBackwardCUDAKernelDriver<T>(
+          dev_ctx_, softmax_out_tensor, *softmax_out_grad_tensor, softmax_axis,
+          src_mask_out_grad_tensor);
       // recall LaunchElementwiseCudaKernel fw:  src_mask_out = qk_out +
       // src_mask
@@ -272,9 +272,9 @@ class FMHARef {
       }
     } else {
-      SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
-                                         *softmax_out_grad_tensor, softmax_axis,
-                                         qk_out_grad_tensor);
+      phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
+                                              *softmax_out_grad_tensor,
+                                              softmax_axis, qk_out_grad_tensor);
     }
     T* qk_out_grad_data = qk_out_grad_tensor->data<T>();
...
@@ -26,6 +26,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/string/string_helper.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -246,8 +247,8 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
     const auto& labels_dims = labels->dims();

     const int axis = logits_dims.size() - 1;
-    const int N = SizeToAxis(axis, logits_dims);
-    const int D = SizeFromAxis(axis, logits_dims);
+    const int N = phi::funcs::SizeToAxis(axis, logits_dims);
+    const int D = phi::funcs::SizeFromAxis(axis, logits_dims);

     int blocks = NumBlocks(N);
     int threads = kNumCUDAThreads;
@@ -401,8 +402,8 @@ class MarginCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     const auto sofrmax_dims = softmax->dims();

     const int axis = sofrmax_dims.size() - 1;
-    const int N = SizeToAxis(axis, sofrmax_dims);
-    const int D = SizeFromAxis(axis, sofrmax_dims);
+    const int N = phi::funcs::SizeToAxis(axis, sofrmax_dims);
+    const int D = phi::funcs::SizeFromAxis(axis, sofrmax_dims);

     if (return_softmax) {
       framework::TensorCopy(*softmax, context.GetPlace(),
...
@@ -22,7 +22,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/softmax.h"
-#include "paddle/fluid/operators/softmax_op.h"

 namespace paddle {
 namespace operators {
...
@@ -14,6 +14,7 @@ limitations under the License. */

 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"

 namespace paddle {
 namespace operators {
@@ -26,6 +27,13 @@ template class SoftmaxFunctor<platform::CPUDeviceContext, double, false>;
 template class SoftmaxGradFunctor<platform::CPUDeviceContext, float>;
 template class SoftmaxGradFunctor<platform::CPUDeviceContext, double>;

+template class SoftmaxFunctor<phi::CPUContext, float, true>;
+template class SoftmaxFunctor<phi::CPUContext, float, false>;
+template class SoftmaxFunctor<phi::CPUContext, double, true>;
+template class SoftmaxFunctor<phi::CPUContext, double, false>;
+template class SoftmaxGradFunctor<phi::CPUContext, float>;
+template class SoftmaxGradFunctor<phi::CPUContext, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

 namespace paddle {
@@ -139,6 +140,16 @@ template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext,
                                   platform::float16>;

+template class SoftmaxFunctor<phi::GPUContext, platform::float16, false>;
+template class SoftmaxFunctor<phi::GPUContext, platform::float16, true>;
+template class SoftmaxFunctor<phi::GPUContext, float, false>;
+template class SoftmaxFunctor<phi::GPUContext, double, false>;
+template class SoftmaxFunctor<phi::GPUContext, float, true>;
+template class SoftmaxFunctor<phi::GPUContext, double, true>;
+template class SoftmaxGradFunctor<phi::GPUContext, float>;
+template class SoftmaxGradFunctor<phi::GPUContext, double>;
+template class SoftmaxGradFunctor<phi::GPUContext, platform::float16>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/softmax_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"

 namespace paddle {
 namespace operators {
@@ -70,7 +71,8 @@ class SoftmaxMKLDNNHandler
                           out_grad->dims(), in_x_grad->dims()));

     auto dims = out_grad->dims();  // input and output share the same shape
-    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
+    const int axis =
+        phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), dims.size());

     auto softmax_tz = phi::vectorize<int64_t>(dims);
     auto data_softmax_md = MKLDNNMemDesc(
@@ -96,7 +98,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     Tensor* output = ctx.Output<Tensor>("Out");

     bool is_inplaced = input->IsSharedBufferWith(*output);

-    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
+    const int axis =
+        phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());

     SoftmaxMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), input,
                                     output, axis);
...
@@ -31,7 +31,7 @@ USE_OP(elementwise_mul);
 USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN);
 USE_OP(relu);
 USE_OP_DEVICE_KERNEL(relu, MKLDNN);
-USE_OP(softmax);
+USE_OP_ITSELF(softmax);
 USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
 USE_OP(conv2d);
 USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
...
@@ -29,7 +29,7 @@ USE_OP_ITSELF(elementwise_add);
 USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
 USE_OP(relu);
 USE_OP_DEVICE_KERNEL(relu, MKLDNN);
-USE_OP(softmax);
+USE_OP_ITSELF(softmax);
 USE_OP_DEVICE_KERNEL(softmax, MKLDNN);

 namespace paddle {
...
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/softmax_cudnn_op.cu.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T, bool LogMode = false>
-class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<Tensor>("X");
-    auto* out = ctx.Output<Tensor>("Out");
-    out->mutable_data<T>(ctx.GetPlace());
-
-    int input_axis = ctx.Attr<int>("axis");
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, *x, input_axis, out);
-  }
-};
-
-template <typename T, bool LogMode = false>
-class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* out = ctx.Input<Tensor>("Out");
-    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    dx->mutable_data<T>(ctx.GetPlace());
-
-    int input_axis = ctx.Attr<int>("axis");
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx, *out, *dout, input_axis, dx);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-#ifdef PADDLE_WITH_HIP
-// MIOPEN does not support double
-REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
-                   ops::SoftmaxCUDNNKernel<float>,
-                   ops::SoftmaxCUDNNKernel<plat::float16>);
-REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
-                   ops::SoftmaxGradCUDNNKernel<float>,
-                   ops::SoftmaxGradCUDNNKernel<plat::float16>);
-#else
-REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
-                   ops::SoftmaxCUDNNKernel<float>,
-                   ops::SoftmaxCUDNNKernel<double>,
-                   ops::SoftmaxCUDNNKernel<plat::float16>);
-REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
-                   ops::SoftmaxGradCUDNNKernel<float>,
-                   ops::SoftmaxGradCUDNNKernel<double>,
-                   ops::SoftmaxGradCUDNNKernel<plat::float16>);
-#endif
@@ -12,12 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/softmax_op.h"
-
 #include <memory>
 #include <string>
 #include <unordered_map>

+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"

 #ifdef PADDLE_WITH_MKLDNN
@@ -251,10 +250,3 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
                   ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
                   ops::SoftmaxInplaceInferer);
 REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
-REGISTER_OP_CPU_KERNEL(
-    softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    softmax_grad,
-    ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
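The removed REGISTER_OP_CPU_KERNEL block is superseded by the phi softmax kernels this PR moves; a hedged sketch of what the phi-side CPU registration looks like (the kernel function name and dtype list are assumptions here, not part of this diff):

    // Assumed phi-side counterpart of the removed fluid CPU registration:
    PD_REGISTER_KERNEL(softmax, CPU, ALL_LAYOUT, phi::SoftmaxKernel, float, double) {}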
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/softmax.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-using DDim = framework::DDim;
-
-static inline int CanonicalAxis(const int axis, const int rank) {
-  if (axis < 0) {
-    return axis + rank;
-  }
-  return axis;
-}
-
-static inline int SizeToAxis(const int axis, DDim dims) {
-  int size = 1;
-  for (int i = 0; i < axis; i++) {
-    size *= dims[i];
-  }
-  return size;
-}
-
-static inline int SizeFromAxis(const int axis, DDim dims) {
-  int size = 1;
-  for (int i = axis; i < dims.size(); i++) {
-    size *= dims[i];
-  }
-  return size;
-}
-
-static inline int SizeOutAxis(const int axis, DDim dims) {
-  int size = 1;
-  for (int i = axis + 1; i < dims.size(); i++) {
-    size *= dims[i];
-  }
-  return size;
-}
-
-template <typename DeviceContext, typename T>
-class SoftmaxKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* X = context.Input<Tensor>("X");
-    auto* Out = context.Output<Tensor>("Out");
-    const int rank = X->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
-    int axis_dim = X->dims()[axis];
-
-    // allocate memory on device.
-    Out->mutable_data<T>(context.GetPlace());
-    if (Out->numel() == 0) {
-      return;
-    }
-
-    const int n = SizeToAxis(axis, X->dims());
-    const int d = SizeFromAxis(axis, X->dims());
-    Tensor X_2d, Out_2d;
-    X_2d.ShareDataWith(*X).Resize({n, d});
-    Out_2d.ShareDataWith(*Out).Resize({n, d});
-
-    math::SoftmaxFunctor<DeviceContext, T, false>()(
-        context.template device_context<DeviceContext>(), axis_dim, &X_2d,
-        &Out_2d);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class SoftmaxGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* Out = context.Input<Tensor>("Out");
-    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
-    const int rank = dX->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
-    int axis_dim = dX->dims()[axis];
-
-    // allocate memory on device.
-    dX->mutable_data<T>(context.GetPlace());
-    if (dX->numel() == 0) {
-      return;
-    }
-
-    const int n = SizeToAxis(axis, dX->dims());
-    const int d = SizeFromAxis(axis, dX->dims());
-    Tensor dX_2d, Out_2d, dOut_2d;
-    dX_2d.ShareDataWith(*dX).Resize({n, d});
-    Out_2d.ShareDataWith(*Out).Resize({n, d});
-    dOut_2d.ShareDataWith(*dOut).Resize({n, d});
-
-    math::SoftmaxGradFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), axis_dim, &Out_2d,
-        &dOut_2d, &dX_2d);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
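The axis helpers deleted above are not dropped; the rest of this diff switches callers to their phi relocations in paddle/phi/kernels/funcs/axis_utils.h. A minimal usage sketch under that assumption, with the values worked out from the definitions shown above:

    // Assuming the relocated helpers keep the semantics of the deleted ones:
    phi::DDim dims = phi::make_ddim({2, 3, 4});
    int axis = phi::funcs::CanonicalAxis(-2, 3);    // negative axis -> 1
    int n = phi::funcs::SizeToAxis(axis, dims);     // 2          (dims before axis)
    int d = phi::funcs::SizeFromAxis(axis, dims);   // 3 * 4 = 12 (dims from axis on)
    int r = phi::funcs::SizeOutAxis(axis, dims);    // 4          (dims after axis)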
@@ -12,8 +12,9 @@ limitations under the License. */

 #include <memory>
 #include <string>

-#include "paddle/fluid/operators/softmax_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"

 namespace paddle {
 namespace operators {
@@ -51,7 +52,7 @@ class SoftmaxGradNPUKernel : public framework::OpKernel<T> {
     auto dims = dX->dims();
     const int rank = dims.size();
-    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
+    const int axis = phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), rank);
     int64_t first_dim = 1;
     int64_t sec_dim = 1;
     for (int i = 0; i < axis; i++) {
...
@@ -29,7 +29,7 @@ limitations under the License. */
 namespace f = paddle::framework;
 namespace p = paddle::platform;

-USE_OP(softmax);
+USE_OP_ITSELF(softmax);
 USE_OP_DEVICE_KERNEL(softmax, NPU);

 template <typename T>
...
@@ -11,8 +11,8 @@ limitations under the License. */

 #ifdef PADDLE_WITH_XPU

-#include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"

 namespace paddle {
 namespace operators {
@@ -29,7 +29,7 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
     auto* x = context.Input<Tensor>("X");
     auto* out = context.Output<Tensor>("Out");
     const int rank = x->dims().size();
-    int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);

     // allocate memory on device.
     out->mutable_data<T>(context.GetPlace());
@@ -88,7 +88,7 @@ class SoftmaxGradXPUKernel : public framework::OpKernel<T> {
     auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
     const int rank = dx->dims().size();
-    int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);

     // allocate memory on device.
     dx->mutable_data<T>(context.GetPlace());
...
@@ -153,7 +153,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
                           "Attr(axis) value should be in range [-R, R-1], "
                           "R is the rank of Input(Logits)."));

-    axis = CanonicalAxis(axis, logits_rank);
+    axis = phi::funcs::CanonicalAxis(axis, logits_rank);
     for (int i = 0; i < logits_rank; i++) {
       if (i != axis) {
         if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
@@ -250,7 +250,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
                           "Attr(axis) value should be in range [-R, R-1], "
                           "R is the rank of Input(Logits)."));

-    axis = CanonicalAxis(axis, softmax_rank);
+    axis = phi::funcs::CanonicalAxis(axis, softmax_rank);
     for (int i = 0; i < softmax_rank; i++) {
       if (i != axis) {
         if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
...
@@ -17,12 +17,12 @@ namespace cub = hipcub;
 #endif
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
-#include "paddle/fluid/operators/softmax_cudnn_op.cu.h"
 #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

 namespace paddle {
 namespace operators {
@@ -236,7 +236,7 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
       max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
     }
   }
-  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
+  phi::WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);

   // compute sum: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} }
   AccT sum[kBatchSize];
@@ -276,7 +276,7 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
       }
     }
   }
-  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
+  phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

   // write data
 #pragma unroll
@@ -566,7 +566,7 @@ __global__ void CrossEntropySoftLabel(T* loss, T* softmaxwrt, const T* softmax,
       }
     }
   }
-  WarpReduceSum<T, kBatchSize, kWarpSize>(sum);
+  phi::WarpReduceSum<T, kBatchSize, kWarpSize>(sum);
   __syncthreads();

   __shared__ T sumshare[kWarpPerBatch][kBatchPerBlock][kBatchSize];
@@ -674,7 +674,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src,
                                 : static_cast<AccT>(valmax);
     }
   }
-  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
+  phi::WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);

   // compute sum
   AccT sum[kBatchSize]{0.0};
@@ -694,7 +694,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src,
       }
     }
   }
-  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
+  phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

   // log_softmax and loss
   AccT sumloss[kBatchSize]{0.0};
@@ -737,7 +737,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src,
   }

   // loss
-  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sumloss);
+  phi::WarpReduceSum<AccT, kBatchSize, kWarpSize>(sumloss);

   for (int i = 0; i < kBatchSize; i++) {
     if (i >= local_batches) break;
@@ -950,11 +950,12 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     Tensor* loss = context.Output<Tensor>("Loss");

     const int rank = softmax->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    const int axis =
+        phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
     const int axis_dim = softmax->dims()[axis];

-    const int n = SizeToAxis(axis, softmax->dims());
-    const int d = SizeFromAxis(axis, softmax->dims());
+    const int n = phi::funcs::SizeToAxis(axis, softmax->dims());
+    const int d = phi::funcs::SizeFromAxis(axis, softmax->dims());

     auto* softmax_out_data =
         softmax_out->template mutable_data<T>(context.GetPlace());
@@ -1035,11 +1036,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     Tensor* loss = context.Output<Tensor>("Loss");

     const int rank = logits->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
     int axis_dim = logits->dims()[axis];

-    const int64_t n = SizeToAxis(axis, logits->dims());
-    const int64_t d = SizeFromAxis(axis, logits->dims());
+    const int64_t n = phi::funcs::SizeToAxis(axis, logits->dims());
+    const int64_t d = phi::funcs::SizeFromAxis(axis, logits->dims());

     auto* softmax_data = softmax->template mutable_data<T>(context.GetPlace());
     auto* loss_data = loss->template mutable_data<T>(context.GetPlace());
@@ -1118,11 +1119,11 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     T* logit_grad_data = logit_grad->template data<T>();

     const int rank = logit_grad->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
     int axis_dim = logit_grad->dims()[axis];

-    const int64_t n = SizeToAxis(axis, logit_grad->dims());
-    const int64_t d = SizeFromAxis(axis, logit_grad->dims());
+    const int64_t n = phi::funcs::SizeToAxis(axis, logit_grad->dims());
+    const int64_t d = phi::funcs::SizeFromAxis(axis, logit_grad->dims());

     const int64_t remain = d / axis_dim;

 #ifdef __HIPCC__
...
@@ -18,7 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/softmax.h"
-#include "paddle/fluid/operators/softmax_op.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"

 namespace paddle {
 namespace operators {
@@ -84,7 +84,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     Tensor* softmax_out = context.Output<Tensor>("Softmax");
     Tensor* loss = context.Output<Tensor>("Loss");
     const int rank = softmax->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    const int axis =
+        phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
     int axis_dim = softmax->dims()[axis];
     PADDLE_ENFORCE_GT(
@@ -97,7 +98,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     softmax_out->mutable_data<T>(context.GetPlace());
     loss->mutable_data<T>(context.GetPlace());

-    const int n = SizeToAxis(axis, softmax->dims());
+    const int n = phi::funcs::SizeToAxis(axis, softmax->dims());
     PADDLE_ENFORCE_GT(
         n, 0, platform::errors::InvalidArgument(
                   "The size of axis should be larger than 0, but received "
@@ -105,7 +106,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
                   "SizeToAxis of softmax is %d.",
                   n));

-    const int d = SizeFromAxis(axis, softmax->dims());
+    const int d = phi::funcs::SizeFromAxis(axis, softmax->dims());

     Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
     softmax_2d.ShareDataWith(*softmax).Resize({n, d});
@@ -133,7 +134,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     Tensor* loss = context.Output<Tensor>("Loss");

     const int rank = logits->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
     int axis_dim = logits->dims()[axis];
     PADDLE_ENFORCE_GT(
         axis_dim, 0,
@@ -145,14 +146,14 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     softmax->mutable_data<T>(context.GetPlace());
     loss->mutable_data<T>(context.GetPlace());

-    const int n = SizeToAxis(axis, logits->dims());
+    const int n = phi::funcs::SizeToAxis(axis, logits->dims());
     PADDLE_ENFORCE_GT(
         n, 0, platform::errors::InvalidArgument(
                   "The size of axis should be larger than 0, but received "
                   "SizeToAxis of logits is %d.",
                   n));

-    const int d = SizeFromAxis(axis, logits->dims());
+    const int d = phi::funcs::SizeFromAxis(axis, logits->dims());

     Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
     logits_2d.ShareDataWith(*logits).Resize({n, d});
     softmax_2d.ShareDataWith(*softmax).Resize({n, d});
@@ -192,7 +193,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
     auto ignore_index = context.Attr<int>("ignore_index");

     const int rank = logit_grad->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
     int axis_dim = logit_grad->dims()[axis];
     PADDLE_ENFORCE_GT(
         axis_dim, 0,
@@ -201,14 +202,14 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
                   "axis dimention is %d.",
                   axis_dim));

-    const int n = SizeToAxis(axis, logit_grad->dims());
+    const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims());
     PADDLE_ENFORCE_GT(
         n, 0, platform::errors::InvalidArgument(
                   "The size of axis should be larger than 0, but received "
                   "SizeToAxis of logit_grad is %d.",
                   n));

-    const int d = SizeFromAxis(axis, logit_grad->dims());
+    const int d = phi::funcs::SizeFromAxis(axis, logit_grad->dims());

     Tensor logit_grad_2d, labels_2d, out_grad_2d;
     logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
     labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
...
@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/math/softmax.h"
+#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"

 #include <memory>
 #include <string>

 #include "paddle/fluid/operators/math/cross_entropy.h"
-#include "paddle/fluid/operators/softmax_op.h"
+#include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"

 namespace paddle {
@@ -40,15 +41,16 @@ class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel<T> {
                           "the npu kernel of softmax_with_cross_entropy."));

     const int rank = logits->dims().size();
-    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
-    const int n = SizeToAxis(axis, logits->dims());
-    const int d = SizeFromAxis(axis, logits->dims());
+    const int axis = phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), rank);
+    const int n = phi::funcs::SizeToAxis(axis, logits->dims());
+    const int d = phi::funcs::SizeFromAxis(axis, logits->dims());

     PADDLE_ENFORCE_EQ(
         labels->numel(), n,
         platform::errors::Unimplemented(
-            "The size of labels should be equal to SizeToAxis of logits,"
-            "but got size of labels is %d and SizeToAxis is %d.",
+            "The size of labels should be equal to phi::funcs::SizeToAxis of "
+            "logits,"
+            "but got size of labels is %d and phi::funcs::SizeToAxis is %d.",
             labels->numel(), n));

     loss->mutable_data<T>(ctx.GetPlace());
@@ -97,9 +99,9 @@ class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel<T> {
     logits_grad->mutable_data<T>(ctx.GetPlace());

     const int rank = logits_grad->dims().size();
-    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
-    const int n = SizeToAxis(axis, logits_grad->dims());
-    const int d = SizeFromAxis(axis, logits_grad->dims());
+    const int axis = phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), rank);
+    const int n = phi::funcs::SizeToAxis(axis, logits_grad->dims());
+    const int d = phi::funcs::SizeFromAxis(axis, logits_grad->dims());

     Tensor logits_grad_2d, loss_grad_1d, backprop_2d;
...
...@@ -38,13 +38,13 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> { ...@@ -38,13 +38,13 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
Tensor* softmax = context.Output<Tensor>("Softmax"); Tensor* softmax = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss"); Tensor* loss = context.Output<Tensor>("Loss");
const int rank = logits->dims().size(); const int rank = logits->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank); const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
PADDLE_ENFORCE_EQ(axis, rank - 1, platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(axis, rank - 1, platform::errors::InvalidArgument(
"axis should == rank - 1")); "axis should == rank - 1"));
softmax->mutable_data<T>(context.GetPlace()); softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace()); loss->mutable_data<T>(context.GetPlace());
const int n = SizeToAxis(axis, logits->dims()); const int n = phi::funcs::SizeToAxis(axis, logits->dims());
const int d = SizeFromAxis(axis, logits->dims()); const int d = phi::funcs::SizeFromAxis(axis, logits->dims());
std::vector<int> logits_dims = phi::vectorize<int>(logits->dims()); std::vector<int> logits_dims = phi::vectorize<int>(logits->dims());
const bool soft_label = context.Attr<bool>("soft_label"); const bool soft_label = context.Attr<bool>("soft_label");
...@@ -122,11 +122,11 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> { ...@@ -122,11 +122,11 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
auto ignore_index = context.Attr<int>("ignore_index"); auto ignore_index = context.Attr<int>("ignore_index");
const int rank = logit_grad->dims().size(); const int rank = logit_grad->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank); const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
PADDLE_ENFORCE_EQ(axis, rank - 1, platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(axis, rank - 1, platform::errors::InvalidArgument(
"axis should == rank - 1")); "axis should == rank - 1"));
const int n = SizeToAxis(axis, logit_grad->dims()); const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims());
const int d = SizeFromAxis(axis, logit_grad->dims()); const int d = phi::funcs::SizeFromAxis(axis, logit_grad->dims());
auto& dev_ctx = auto& dev_ctx =
context.template device_context<platform::XPUDeviceContext>(); context.template device_context<platform::XPUDeviceContext>();
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
USE_OP(relu); USE_OP(relu);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP(softmax); USE_OP_ITSELF(softmax);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -220,4 +220,11 @@ class GPUContext : public DeviceContext { ...@@ -220,4 +220,11 @@ class GPUContext : public DeviceContext {
std::unique_ptr<Impl> impl_; std::unique_ptr<Impl> impl_;
}; };
// Note: In order to register cuDNN kernels, a GPUDNNContext is required.
// Currently, the cuDNN kernels directly use GPUContext. But if a kernel
// function has the same name, this will lead to duplicate instantiations of
// the GPU kernel and the GPUDNN kernel function, so while we use
// GPUDNNContext = GPUContext, the cuDNN kernels must use different function
// names.
using GPUDNNContext = GPUContext;
} // namespace phi } // namespace phi
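A standalone C++ sketch of the collision the note above describes (the struct and kernel names here are stand-ins, not the real phi types): an alias does not create a new type, so a kernel template instantiated for GPUDNNContext is the same symbol as the one for GPUContext, which is why the moved cuDNN kernels carry distinct names such as SoftmaxRawGPUDNNKernel.

#include <iostream>

struct GPUContext {};              // stand-in for phi::GPUContext
using GPUDNNContext = GPUContext;  // mirrors the alias above

template <typename T, typename Context>
void SoftmaxKernel(const Context&) {
  std::cout << "SoftmaxKernel instantiation" << std::endl;
}

int main() {
  GPUContext gpu;
  GPUDNNContext gpudnn;
  // Both calls resolve to SoftmaxKernel<float, GPUContext>; a second kernel
  // function with the same name registered for GPUDNN would duplicate it.
  SoftmaxKernel<float>(gpu);
  SoftmaxKernel<float>(gpudnn);
  return 0;
}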
...@@ -50,7 +50,7 @@ enum class Backend : uint8_t { ...@@ -50,7 +50,7 @@ enum class Backend : uint8_t {
// the third library backend // the third library backend
MKLDNN, MKLDNN,
CUDNN, GPUDNN, // cuDNN and hipDNN
// end of backend types // end of backend types
NUM_BACKENDS, NUM_BACKENDS,
...@@ -112,8 +112,8 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) { ...@@ -112,8 +112,8 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) {
case Backend::MKLDNN: case Backend::MKLDNN:
os << "MKLDNN"; os << "MKLDNN";
break; break;
case Backend::CUDNN: case Backend::GPUDNN:
os << "CUDNN"; os << "GPUDNN";
break; break;
default: { default: {
size_t device_type_id_ = static_cast<size_t>(backend) - size_t device_type_id_ = static_cast<size_t>(backend) -
...@@ -145,8 +145,8 @@ inline Backend StringToBackend(const char* backend_cstr) { ...@@ -145,8 +145,8 @@ inline Backend StringToBackend(const char* backend_cstr) {
return Backend::NPU; return Backend::NPU;
} else if (s == std::string("MKLDNN")) { } else if (s == std::string("MKLDNN")) {
return Backend::MKLDNN; return Backend::MKLDNN;
} else if (s == std::string("CUDNN")) { } else if (s == std::string("GPUDNN")) {
return Backend::CUDNN; return Backend::GPUDNN;
} else { } else {
return static_cast<Backend>(static_cast<size_t>(Backend::NUM_BACKENDS) + return static_cast<Backend>(static_cast<size_t>(Backend::NUM_BACKENDS) +
phi::GetOrRegisterGlobalDeviceTypeId(s)); phi::GetOrRegisterGlobalDeviceTypeId(s));
......
...@@ -988,6 +988,18 @@ inline std::ostream& operator<<(std::ostream& os, const float16& a) { ...@@ -988,6 +988,18 @@ inline std::ostream& operator<<(std::ostream& os, const float16& a) {
return os; return os;
} }
template <typename T>
class MPTypeTrait {
public:
using Type = T;
};
template <>
class MPTypeTrait<float16> {
public:
using Type = float;
};
} // namespace dtype } // namespace dtype
} // namespace phi } // namespace phi
......
...@@ -58,7 +58,7 @@ phi::Place TransToPtenPlace(const Backend& backend, bool set_device_id) { ...@@ -58,7 +58,7 @@ phi::Place TransToPtenPlace(const Backend& backend, bool set_device_id) {
return phi::CPUPlace(); return phi::CPUPlace();
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
case phi::Backend::CUDNN: case phi::Backend::GPUDNN:
return phi::GPUPlace( return phi::GPUPlace(
set_device_id ? phi::backends::gpu::GetCurrentDeviceId() : 0); set_device_id ? phi::backends::gpu::GetCurrentDeviceId() : 0);
#endif #endif
......
...@@ -15,8 +15,15 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function i ...@@ -15,8 +15,15 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function i
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
# auto build kernel targets by cmake # NOTE: Some kernels depend on some targets that are not commonly used.
register_kernels(DEPS ${COMMON_KERNEL_DEPS}) # These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS softmax_kernel softmax_grad_kernel)
kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
# auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS})
# phi sparse kernels # phi sparse kernels
add_subdirectory(sparse) add_subdirectory(sparse)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/softmax_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/softmax_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
softmax_grad, CPU, ALL_LAYOUT, phi::SoftmaxGradKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/softmax_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/softmax_kernel_impl.h"
PD_REGISTER_KERNEL(
softmax, CPU, ALL_LAYOUT, phi::SoftmaxRawKernel, float, double) {}
...@@ -12,16 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,16 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/softmax_op.h" #pragma once
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/ddim.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace phi {
REGISTER_OP_CUDA_KERNEL( namespace funcs {
softmax, ops::SoftmaxKernel<plat::CUDADeviceContext, float>,
ops::SoftmaxKernel<plat::CUDADeviceContext, double>, static inline int CanonicalAxis(const int axis, const int rank) {
ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>); if (axis < 0) {
REGISTER_OP_CUDA_KERNEL( return axis + rank;
softmax_grad, ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>, }
ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>, return axis;
ops::SoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>); }
static inline int SizeToAxis(const int axis, DDim dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline int SizeFromAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
static inline int SizeOutAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
} // namespace funcs
} // namespace phi
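A worked example of the four helpers (re-implemented here over std::vector<int> as a DDim stand-in): for dims {2, 3, 4, 5} and axis -2, the softmax dimension is dims[2] = 4, N = 2 * 3 rows precede it, and D = 5 elements follow it.

#include <cassert>
#include <vector>

static int CanonicalAxis(int axis, int rank) {
  return axis < 0 ? axis + rank : axis;
}
static int SizeToAxis(int axis, const std::vector<int>& dims) {
  int size = 1;
  for (int i = 0; i < axis; ++i) size *= dims[i];
  return size;
}
static int SizeFromAxis(int axis, const std::vector<int>& dims) {
  int size = 1;
  for (int i = axis; i < static_cast<int>(dims.size()); ++i) size *= dims[i];
  return size;
}
static int SizeOutAxis(int axis, const std::vector<int>& dims) {
  int size = 1;
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) size *= dims[i];
  return size;
}

int main() {
  std::vector<int> dims = {2, 3, 4, 5};
  int axis = CanonicalAxis(-2, static_cast<int>(dims.size()));  // -2 -> 2
  assert(axis == 2);
  assert(SizeToAxis(axis, dims) == 6);     // N: dims before the axis
  assert(SizeFromAxis(axis, dims) == 20);  // d: dims from the axis on
  assert(SizeOutAxis(axis, dims) == 5);    // D: dims after the axis
  return 0;
}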
...@@ -92,4 +92,4 @@ static inline phi::DDim ComputeAndCheckShape( ...@@ -92,4 +92,4 @@ static inline phi::DDim ComputeAndCheckShape(
} }
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -55,5 +55,5 @@ struct EigenSub<Eigen::GpuDevice, T> { ...@@ -55,5 +55,5 @@ struct EigenSub<Eigen::GpuDevice, T> {
template struct EigenSub<Eigen::GpuDevice, float>; template struct EigenSub<Eigen::GpuDevice, float>;
} // namespace fucns } // namespace funcs
} // namespace phi } // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/softmax_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/softmax_grad_kernel_impl.h"
PD_REGISTER_KERNEL(softmax_grad,
GPU,
ALL_LAYOUT,
phi::SoftmaxGradKernel,
float,
double,
phi::dtype::float16) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/softmax_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/softmax_kernel_impl.h"
PD_REGISTER_KERNEL(softmax,
GPU,
ALL_LAYOUT,
phi::SoftmaxRawKernel,
float,
double,
phi::dtype::float16) {}
...@@ -14,18 +14,20 @@ limitations under the License. */ ...@@ -14,18 +14,20 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/phi/common/float16.h"
#include "paddle/fluid/operators/softmax_op.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace paddle { namespace phi {
namespace operators {
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using ScopedTensorDescriptor = paddle::platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout; using GPUDNNDataLayout = paddle::platform::DataLayout;
using Tensor = framework::Tensor;
// Vectorization trait 4 * sizeof(T) // Vectorization trait 4 * sizeof(T)
template <typename T> template <typename T>
...@@ -41,7 +43,7 @@ class VecT4<float> { ...@@ -41,7 +43,7 @@ class VecT4<float> {
using Type = int4; using Type = int4;
}; };
template <> template <>
class VecT4<platform::float16> { class VecT4<phi::dtype::float16> {
public: public:
using Type = int2; using Type = int2;
}; };
...@@ -60,7 +62,7 @@ class VecT2<float> { ...@@ -60,7 +62,7 @@ class VecT2<float> {
using Type = int2; using Type = int2;
}; };
template <> template <>
class VecT2<platform::float16> { class VecT2<phi::dtype::float16> {
public: public:
using Type = int; using Type = int;
}; };
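The vector traits exist so a single load moves 4 * sizeof(T) (or 2 * sizeof(T)) bytes; the size arithmetic can be checked with stand-in types (int4_t, int2_t, and float16_t below are placeholders for the CUDA vector types and phi::dtype::float16):

#include <cstdint>

struct int4_t { std::int32_t x, y, z, w; };  // 16 bytes, like CUDA int4
struct int2_t { std::int32_t x, y; };        // 8 bytes, like CUDA int2
struct float16_t { std::uint16_t bits; };    // 2 bytes, like float16

static_assert(sizeof(int4_t) == 4 * sizeof(float),
              "VecT4<float>::Type carries 4 floats");
static_assert(sizeof(int2_t) == 4 * sizeof(float16_t),
              "VecT4<float16>::Type carries 4 halves");
static_assert(sizeof(int2_t) == 2 * sizeof(float),
              "VecT2<float>::Type carries 2 floats");
static_assert(sizeof(std::int32_t) == 2 * sizeof(float16_t),
              "VecT2<float16>::Type carries 2 halves");

int main() { return 0; }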
...@@ -77,7 +79,8 @@ __device__ __forceinline__ void WarpReduceSum(T* sum) { ...@@ -77,7 +79,8 @@ __device__ __forceinline__ void WarpReduceSum(T* sum) {
for (int offset = WarpSize / 2; offset > 0; offset /= 2) { for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll #pragma unroll
for (int i = 0; i < BatchSize; ++i) { for (int i = 0; i < BatchSize; ++i) {
T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); T sum_val =
paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
sum[i] = sum[i] + sum_val; sum[i] = sum[i] + sum_val;
} }
} }
...@@ -89,14 +92,13 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) { ...@@ -89,14 +92,13 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) {
for (int offset = WarpSize / 2; offset > 0; offset /= 2) { for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll #pragma unroll
for (int i = 0; i < BatchSize; ++i) { for (int i = 0; i < BatchSize; ++i) {
T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); T max_val =
paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
sum[i] = max(sum[i], max_val); sum[i] = max(sum[i], max_val);
} }
} }
} }
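A host-side simulation of the XOR butterfly behind WarpReduceSum (plain C++, not the real CudaShuffleXorSync intrinsic): at each step every lane combines with lane ^ offset, and after log2(32) steps every lane holds the full warp sum.

#include <cassert>

int main() {
  const int kWarpSize = 32;
  float lanes[kWarpSize];
  for (int i = 0; i < kWarpSize; ++i) lanes[i] = 1.0f;  // each lane holds 1

  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    float shuffled[kWarpSize];
    for (int lane = 0; lane < kWarpSize; ++lane)
      shuffled[lane] = lanes[lane ^ offset];  // shuffle-xor analogue
    for (int lane = 0; lane < kWarpSize; ++lane)
      lanes[lane] += shuffled[lane];
  }
  for (int lane = 0; lane < kWarpSize; ++lane)
    assert(lanes[lane] == 32.0f);  // every lane ends with the warp total
  return 0;
}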
namespace kps = paddle::operators::kernel_primitives;
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct ReduceMaxFunctor { struct ReduceMaxFunctor {
inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); } inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }
...@@ -248,10 +250,15 @@ One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). ...@@ -248,10 +250,15 @@ One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp. api to compute max (sum) in one warp.
*/ */
template <typename T, typename VecT, typename AccT, int Log2Elements, template <typename T,
typename VecT,
typename AccT,
int Log2Elements,
bool LogMode = false> bool LogMode = false>
__global__ void WarpSoftmaxForward(T* softmax, const T* src, __global__ void WarpSoftmaxForward(T* softmax,
const int batch_size, const int stride, const T* src,
const int batch_size,
const int stride,
const int element_count) { const int element_count) {
constexpr int kDimCeil = 1 << Log2Elements; constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
...@@ -302,9 +309,13 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, ...@@ -302,9 +309,13 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
} }
// compute max // compute max
kps::Reduce<AccT, kVItem, kBatchSize, 1, ReduceMaxFunctor<AccT>, kps::Reduce<AccT,
kMode::kLocalMode>(&max[0], &srcdata[0][0][0], kVItem,
ReduceMaxFunctor<AccT>(), true); kBatchSize,
1,
ReduceMaxFunctor<AccT>,
kMode::kLocalMode>(
&max[0], &srcdata[0][0][0], ReduceMaxFunctor<AccT>(), true);
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max); WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);
// compute sum // compute sum
...@@ -313,9 +324,13 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, ...@@ -313,9 +324,13 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>( kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
&srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i])); &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
} }
kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>, kps::Reduce<AccT,
kMode::kLocalMode>(&sum[0], &srcdata[0][0][0], kVItem,
kps::AddFunctor<AccT>(), true); kBatchSize,
1,
kps::AddFunctor<AccT>,
kMode::kLocalMode>(
&sum[0], &srcdata[0][0][0], kps::AddFunctor<AccT>(), true);
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum); WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// write data to global memory // write data to global memory
...@@ -340,10 +355,16 @@ One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). ...@@ -340,10 +355,16 @@ One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp. api to compute max (sum) in one warp.
*/ */
template <typename T, typename VecT, typename AccT, int Log2Elements, template <typename T,
typename VecT,
typename AccT,
int Log2Elements,
bool LogMode = false> bool LogMode = false>
__global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, __global__ void WarpSoftmaxBackward(T* dst,
int batch_size, int stride, const T* grad,
const T* src,
int batch_size,
int stride,
int element_count) { int element_count) {
constexpr int kVSize = sizeof(VecT) / sizeof(T); constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kDimCeil = 1 << Log2Elements; constexpr int kDimCeil = 1 << Log2Elements;
...@@ -403,7 +424,11 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, ...@@ -403,7 +424,11 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]); AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>( kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
&sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>()); &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>, kps::Reduce<AccT,
kVItem,
kBatchSize,
1,
kps::AddFunctor<AccT>,
kps::details::ReduceMode::kLocalMode>( kps::details::ReduceMode::kLocalMode>(
&sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true); &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum); WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
...@@ -429,7 +454,10 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, ...@@ -429,7 +454,10 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \ #define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \
case Log2Elements: \ case Log2Elements: \
WarpSoftmaxForward<T, VecT, AccT, Log2Elements, \ WarpSoftmaxForward<T, \
VecT, \
AccT, \
Log2Elements, \
LogMode><<<blocks, threads, 0, dev_ctx.stream()>>>( \ LogMode><<<blocks, threads, 0, dev_ctx.stream()>>>( \
dst, src, batch_size, stride, element_count); \ dst, src, batch_size, stride, element_count); \
break; break;
...@@ -438,12 +466,16 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, ...@@ -438,12 +466,16 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
Wrapper of softmax forward with template instantiation on size of input. Wrapper of softmax forward with template instantiation on size of input.
*/ */
template <typename T, typename VecT, bool LogMode> template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads, void SwitchWarpSoftmaxForward(const int blocks,
const platform::CUDADeviceContext& dev_ctx, const dim3 threads,
T* dst, const T* src, const int batch_size, const GPUContext& dev_ctx,
const int stride, const int element_count, T* dst,
const T* src,
const int batch_size,
const int stride,
const int element_count,
int Log2Elements) { int Log2Elements) {
using AccT = typename details::MPTypeTrait<T>::Type; using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
switch (Log2Elements) { switch (Log2Elements) {
SOFTMAX_WARP_FORWARD_CASE(0, AccT); SOFTMAX_WARP_FORWARD_CASE(0, AccT);
SOFTMAX_WARP_FORWARD_CASE(1, AccT); SOFTMAX_WARP_FORWARD_CASE(1, AccT);
...@@ -462,7 +494,10 @@ void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads, ...@@ -462,7 +494,10 @@ void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads,
#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \ #define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \
case Log2Elements: \ case Log2Elements: \
WarpSoftmaxBackward<T, VecT, AccT, Log2Elements, \ WarpSoftmaxBackward<T, \
VecT, \
AccT, \
Log2Elements, \
LogMode><<<blocks, threads, 0, dev_ctx.stream()>>>( \ LogMode><<<blocks, threads, 0, dev_ctx.stream()>>>( \
dst, grad, src, batch_size, stride, element_count); \ dst, grad, src, batch_size, stride, element_count); \
break; break;
...@@ -471,12 +506,17 @@ void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads, ...@@ -471,12 +506,17 @@ void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads,
Wrapper of softmax backward with template instantiation on size of input. Wrapper of softmax backward with template instantiation on size of input.
*/ */
template <typename T, typename VecT, bool LogMode> template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads, void SwitchWarpSoftmaxBackward(const int blocks,
const platform::CUDADeviceContext& dev_ctx, const dim3 threads,
T* dst, const T* grad, const T* src, const GPUContext& dev_ctx,
const int batch_size, const int stride, T* dst,
const int element_count, int Log2Elements) { const T* grad,
using AccT = typename details::MPTypeTrait<T>::Type; const T* src,
const int batch_size,
const int stride,
const int element_count,
int Log2Elements) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
switch (Log2Elements) { switch (Log2Elements) {
SOFTMAX_WARP_BACKWARD_CASE(0, AccT); SOFTMAX_WARP_BACKWARD_CASE(0, AccT);
SOFTMAX_WARP_BACKWARD_CASE(1, AccT); SOFTMAX_WARP_BACKWARD_CASE(1, AccT);
...@@ -501,12 +541,12 @@ void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads, ...@@ -501,12 +541,12 @@ void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads,
* Better performance when axis != -1 * Better performance when axis != -1
*/ */
static void GetGridDim(int high_dim, int mid_dim, int low_dim, static void GetGridDim(
const dim3& block, dim3* grid) { int high_dim, int mid_dim, int low_dim, const dim3& block, dim3* grid) {
int device_id = paddle::platform::GetCurrentDeviceId(); int device_id = phi::backends::gpu::GetCurrentDeviceId();
int max_mp = paddle::platform::GetGPUMultiProcessors(device_id); int max_mp = phi::backends::gpu::GetGPUMultiProcessors(device_id);
int max_threads_per_mp = int max_threads_per_mp =
paddle::platform::GetGPUMaxThreadsPerMultiProcessor(device_id); phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor(device_id);
int max_threads = max_threads_per_mp * max_mp; int max_threads = max_threads_per_mp * max_mp;
int num_threads = block.x * block.y; int num_threads = block.x * block.y;
int max_num_blocks = max_threads / num_threads; int max_num_blocks = max_threads / num_threads;
...@@ -532,16 +572,17 @@ static void GetBlockDim(int mid_dim, int low_dim, dim3* block) { ...@@ -532,16 +572,17 @@ static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
block->x = std::min(block_x, static_cast<int>(max_num_threads / block->y)); block->x = std::min(block_x, static_cast<int>(max_num_threads / block->y));
} }
static void GetLaunchConfig(int high_dim, int mid_dim, int low_dim, dim3* grid, static void GetLaunchConfig(
dim3* block) { int high_dim, int mid_dim, int low_dim, dim3* grid, dim3* block) {
GetBlockDim(mid_dim, low_dim, block); GetBlockDim(mid_dim, low_dim, block);
GetGridDim(high_dim, mid_dim, low_dim, *block, grid); GetGridDim(high_dim, mid_dim, low_dim, *block, grid);
} }
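The grid cap in GetGridDim follows an occupancy-style bound: never request more blocks than the device can keep resident. A sketch with assumed device numbers (80 SMs and 2048 resident threads per SM are both hypothetical):

#include <cstdio>

int main() {
  int max_mp = 80;                // multiprocessors (assumed)
  int max_threads_per_mp = 2048;  // resident threads per SM (assumed)
  int max_threads = max_mp * max_threads_per_mp;
  int num_threads = 32 * 8;       // block.x * block.y for one block
  int max_num_blocks = max_threads / num_threads;  // the cap GetGridDim uses
  std::printf("grid is capped at %d blocks\n", max_num_blocks);
  return 0;
}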
template <typename T, typename AccT, template <typename T,
typename AccT,
template <typename, typename> class Functor> template <typename, typename> class Functor>
__global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim, __global__ void NormalSoftmaxForward(
int mid_dim, int low_dim) { T* output, const T* input, int high_dim, int mid_dim, int low_dim) {
using kMode = kps::details::ReduceMode; using kMode = kps::details::ReduceMode;
const int high_stride = mid_dim * low_dim; const int high_stride = mid_dim * low_dim;
const int mid_stride = low_dim; const int mid_stride = low_dim;
...@@ -584,11 +625,15 @@ __global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim, ...@@ -584,11 +625,15 @@ __global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim,
} }
} }
template <typename T, typename AccT, template <typename T,
typename AccT,
template <typename, typename> class Functor> template <typename, typename> class Functor>
__global__ void NormalSoftmaxBackward(T* input_grad, const T* output_grad, __global__ void NormalSoftmaxBackward(T* input_grad,
const T* output, int high_dim, const T* output_grad,
int mid_dim, int low_dim) { const T* output,
int high_dim,
int mid_dim,
int low_dim) {
using kMode = kps::details::ReduceMode; using kMode = kps::details::ReduceMode;
const int high_stride = mid_dim * low_dim; const int high_stride = mid_dim * low_dim;
const int mid_stride = low_dim; const int mid_stride = low_dim;
...@@ -622,58 +667,79 @@ __global__ void NormalSoftmaxBackward(T* input_grad, const T* output_grad, ...@@ -622,58 +667,79 @@ __global__ void NormalSoftmaxBackward(T* input_grad, const T* output_grad,
} }
template <typename T, bool LogMode = false> template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx, void LaunchNormalSoftmaxForward(const GPUContext& dev_ctx,
T* output_data, const T* input_data, T* output_data,
int high_dim, int mid_dim, int low_dim) { const T* input_data,
using AccT = typename details::MPTypeTrait<T>::Type; int high_dim,
int mid_dim,
int low_dim) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
dim3 grid, block; dim3 grid, block;
GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block); GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
if (LogMode) { if (LogMode) {
NormalSoftmaxForward< NormalSoftmaxForward<
T, AccT, T,
AccT,
LogSoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>( LogSoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
output_data, input_data, high_dim, mid_dim, low_dim); output_data, input_data, high_dim, mid_dim, low_dim);
} else { } else {
NormalSoftmaxForward< NormalSoftmaxForward<
T, AccT, SoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>( T,
AccT,
SoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
output_data, input_data, high_dim, mid_dim, low_dim); output_data, input_data, high_dim, mid_dim, low_dim);
} }
} }
template <typename T, bool LogMode = false> template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxBackward(const platform::CUDADeviceContext& dev_ctx, void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
T* input_grad_data, const T* output_grad_data, T* input_grad_data,
const T* output_data, int high_dim, const T* output_grad_data,
int mid_dim, int low_dim) { const T* output_data,
using AccT = typename details::MPTypeTrait<T>::Type; int high_dim,
int mid_dim,
int low_dim) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
dim3 grid, block; dim3 grid, block;
GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block); GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
if (LogMode) { if (LogMode) {
NormalSoftmaxBackward< NormalSoftmaxBackward<
T, AccT, T,
AccT,
LogSoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>( LogSoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
input_grad_data, output_grad_data, output_data, high_dim, mid_dim, input_grad_data,
output_grad_data,
output_data,
high_dim,
mid_dim,
low_dim); low_dim);
} else { } else {
NormalSoftmaxBackward< NormalSoftmaxBackward<
T, AccT, SoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>( T,
input_grad_data, output_grad_data, output_data, high_dim, mid_dim, AccT,
SoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
input_grad_data,
output_grad_data,
output_data,
high_dim,
mid_dim,
low_dim); low_dim);
} }
} }
template <typename T, bool LogMode = false> template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
const Tensor& x, const int input_axis, const DenseTensor& x,
Tensor* out) { const int input_axis,
DenseTensor* out) {
auto* out_data = out->data<T>(); auto* out_data = out->data<T>();
auto dims = x.dims(); auto dims = x.dims();
const int rank = dims.size(); const int rank = dims.size();
const int axis = CanonicalAxis(input_axis, rank); const int axis = phi::funcs::CanonicalAxis(input_axis, rank);
const int dim = dims[axis]; const int dim = dims[axis];
const int N = SizeToAxis(axis, dims); const int N = phi::funcs::SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims); const int D = phi::funcs::SizeOutAxis(axis, dims);
constexpr int max_dim = 512; constexpr int max_dim = 512;
constexpr int warps_per_block = 4; constexpr int warps_per_block = 4;
...@@ -697,25 +763,43 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, ...@@ -697,25 +763,43 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
using T2 = typename VecT2<T>::Type; using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) { if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, dev_ctx, SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
out_data, x.data<T>(), N, dim, threads,
dim, kDimLog2); dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
kDimLog2);
} else if (dim % 2 == 0) { } else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks, threads, dev_ctx, SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
out_data, x.data<T>(), N, dim, threads,
dim, kDimLog2); dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
kDimLog2);
} else { } else {
SwitchWarpSoftmaxForward<T, T, LogMode>(blocks, threads, dev_ctx, SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
out_data, x.data<T>(), N, dim, threads,
dim, kDimLog2); dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
kDimLog2);
} }
} else if (D > 1) { } else if (D > 1) {
LaunchNormalSoftmaxForward<T, LogMode>(dev_ctx, out_data, x.data<T>(), N, LaunchNormalSoftmaxForward<T, LogMode>(
dim, D); dev_ctx, out_data, x.data<T>(), N, dim, D);
} else { } else {
ScopedTensorDescriptor desc; ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1}; std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW; GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims); miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else #else
...@@ -728,46 +812,74 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, ...@@ -728,46 +812,74 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL; : MIOPEN_SOFTMAX_MODE_CHANNEL;
if (LogMode) { if (LogMode) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( PADDLE_ENFORCE_GPU_SUCCESS(
handle, platform::CudnnDataType<T>::kOne(), desc_, x.data<T>(), paddle::platform::dynload::miopenSoftmaxForward_V2(
platform::CudnnDataType<T>::kZero(), desc_, out_data, handle,
MIOPEN_SOFTMAX_LOG, mode)); paddle::platform::CudnnDataType<T>::kOne(),
desc_,
x.data<T>(),
paddle::platform::CudnnDataType<T>::kZero(),
desc_,
out_data,
MIOPEN_SOFTMAX_LOG,
mode));
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( PADDLE_ENFORCE_GPU_SUCCESS(
handle, platform::CudnnDataType<T>::kOne(), desc_, x.data<T>(), paddle::platform::dynload::miopenSoftmaxForward_V2(
platform::CudnnDataType<T>::kZero(), desc_, out_data, handle,
MIOPEN_SOFTMAX_ACCURATE, mode)); paddle::platform::CudnnDataType<T>::kOne(),
desc_,
x.data<T>(),
paddle::platform::CudnnDataType<T>::kZero(),
desc_,
out_data,
MIOPEN_SOFTMAX_ACCURATE,
mode));
} }
#else #else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL; : CUDNN_SOFTMAX_MODE_CHANNEL;
if (LogMode) { if (LogMode) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward( PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(), handle,
desc_, x.data<T>(), platform::CudnnDataType<T>::kZero(), desc_, CUDNN_SOFTMAX_LOG,
mode,
paddle::platform::CudnnDataType<T>::kOne(),
desc_,
x.data<T>(),
paddle::platform::CudnnDataType<T>::kZero(),
desc_,
out_data)); out_data));
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward( PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_ACCURATE, mode, handle,
platform::CudnnDataType<T>::kOne(), desc_, x.data<T>(), CUDNN_SOFTMAX_ACCURATE,
platform::CudnnDataType<T>::kZero(), desc_, out_data)); mode,
paddle::platform::CudnnDataType<T>::kOne(),
desc_,
x.data<T>(),
paddle::platform::CudnnDataType<T>::kZero(),
desc_,
out_data));
} }
#endif #endif
} }
} }
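The driver above picks one of three paths; a simplified host-side selector of that choice follows (the exact warp-path condition is partly collapsed in this diff, so the D == 1 && dim <= 512 test below is an assumption based on the max_dim constant):

#include <cstdio>

const char* PickSoftmaxPath(int dim, int D, int max_dim = 512) {
  if (D == 1 && dim <= max_dim) return "warp softmax (vectorized)";
  if (D > 1) return "normal softmax (axis != -1)";
  return "cudnn/miopen softmax";
}

int main() {
  std::printf("%s\n", PickSoftmaxPath(128, 1));   // small last dim
  std::printf("%s\n", PickSoftmaxPath(2048, 4));  // inner softmax axis
  std::printf("%s\n", PickSoftmaxPath(2048, 1));  // large last dim
  return 0;
}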
template <typename T, bool LogMode = false> template <typename T, bool LogMode = false>
void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
const Tensor& out, const Tensor& dout, const DenseTensor& out,
const int input_axis, Tensor* dx) { const DenseTensor& dout,
const int input_axis,
DenseTensor* dx) {
auto* dx_data = dx->data<T>(); auto* dx_data = dx->data<T>();
auto dims = out.dims(); auto dims = out.dims();
const int rank = dims.size(); const int rank = dims.size();
const int axis = CanonicalAxis(input_axis, rank); const int axis = phi::funcs::CanonicalAxis(input_axis, rank);
const int dim = dims[axis]; const int dim = dims[axis];
const int N = SizeToAxis(axis, dims); const int N = phi::funcs::SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims); const int D = phi::funcs::SizeOutAxis(axis, dims);
constexpr int max_dim = 512; constexpr int max_dim = 512;
constexpr int warps_per_block = 4; constexpr int warps_per_block = 4;
...@@ -788,25 +900,46 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, ...@@ -788,25 +900,46 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
using T4 = typename VecT4<T>::Type; using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type; using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) { if (dim % 4 == 0) {
SwitchWarpSoftmaxBackward<T, T4, LogMode>( SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, threads,
dim, dim, kDimLog2); dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
kDimLog2);
} else if (dim % 2 == 0) { } else if (dim % 2 == 0) {
SwitchWarpSoftmaxBackward<T, T2, LogMode>( SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, threads,
dim, dim, kDimLog2); dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
kDimLog2);
} else { } else {
SwitchWarpSoftmaxBackward<T, T, LogMode>( SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, threads,
dim, dim, kDimLog2); dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
kDimLog2);
} }
} else if (D > 1) { } else if (D > 1) {
LaunchNormalSoftmaxBackward<T, LogMode>(dev_ctx, dx_data, dout.data<T>(), LaunchNormalSoftmaxBackward<T, LogMode>(
out.data<T>(), N, dim, D); dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
} else { } else {
ScopedTensorDescriptor desc; ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1}; std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW; GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims); miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else #else
...@@ -819,33 +952,68 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, ...@@ -819,33 +952,68 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL; : MIOPEN_SOFTMAX_MODE_CHANNEL;
if (LogMode) { if (LogMode) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( PADDLE_ENFORCE_GPU_SUCCESS(
handle, platform::CudnnDataType<T>::kOne(), desc_, out.data<T>(), paddle::platform::dynload::miopenSoftmaxBackward_V2(
desc_, dout.data<T>(), platform::CudnnDataType<T>::kZero(), desc_, handle,
dx_data, MIOPEN_SOFTMAX_LOG, mode)); paddle::platform::CudnnDataType<T>::kOne(),
desc_,
out.data<T>(),
desc_,
dout.data<T>(),
paddle::platform::CudnnDataType<T>::kZero(),
desc_,
dx_data,
MIOPEN_SOFTMAX_LOG,
mode));
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( PADDLE_ENFORCE_GPU_SUCCESS(
handle, platform::CudnnDataType<T>::kOne(), desc_, out.data<T>(), paddle::platform::dynload::miopenSoftmaxBackward_V2(
desc_, dout.data<T>(), platform::CudnnDataType<T>::kZero(), desc_, handle,
dx_data, MIOPEN_SOFTMAX_ACCURATE, mode)); paddle::platform::CudnnDataType<T>::kOne(),
desc_,
out.data<T>(),
desc_,
dout.data<T>(),
paddle::platform::CudnnDataType<T>::kZero(),
desc_,
dx_data,
MIOPEN_SOFTMAX_ACCURATE,
mode));
} }
#else #else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL; : CUDNN_SOFTMAX_MODE_CHANNEL;
if (LogMode) { if (LogMode) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxBackward( PADDLE_ENFORCE_GPU_SUCCESS(
handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(), paddle::platform::dynload::cudnnSoftmaxBackward(
desc_, out.data<T>(), desc_, dout.data<T>(), handle,
platform::CudnnDataType<T>::kZero(), desc_, dx_data)); CUDNN_SOFTMAX_LOG,
mode,
paddle::platform::CudnnDataType<T>::kOne(),
desc_,
out.data<T>(),
desc_,
dout.data<T>(),
paddle::platform::CudnnDataType<T>::kZero(),
desc_,
dx_data));
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxBackward( PADDLE_ENFORCE_GPU_SUCCESS(
handle, CUDNN_SOFTMAX_ACCURATE, mode, paddle::platform::dynload::cudnnSoftmaxBackward(
platform::CudnnDataType<T>::kOne(), desc_, out.data<T>(), desc_, handle,
dout.data<T>(), platform::CudnnDataType<T>::kZero(), desc_, dx_data)); CUDNN_SOFTMAX_ACCURATE,
mode,
paddle::platform::CudnnDataType<T>::kOne(),
desc_,
out.data<T>(),
desc_,
dout.data<T>(),
paddle::platform::CudnnDataType<T>::kZero(),
desc_,
dx_data));
} }
#endif #endif
} }
} }
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/softmax_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace phi {
template <typename T, typename Context>
void SoftmaxGradGPUDNNKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx, out, out_grad, axis, x_grad);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(softmax_grad,
GPUDNN,
ALL_LAYOUT,
phi::SoftmaxGradGPUDNNKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(softmax_grad,
GPUDNN,
ALL_LAYOUT,
phi::SoftmaxGradGPUDNNKernel,
float,
double,
phi::dtype::float16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/softmax_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace phi {
template <typename T, typename Context>
void SoftmaxRawGPUDNNKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, x, axis, out);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(softmax,
GPUDNN,
ALL_LAYOUT,
phi::SoftmaxRawGPUDNNKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(softmax,
GPUDNN,
ALL_LAYOUT,
phi::SoftmaxRawGPUDNNKernel,
float,
double,
phi::dtype::float16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/softmax_grad_kernel.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace phi {
template <typename T, typename Context>
void SoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad) {
const int rank = x_grad->dims().size();
const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = x_grad->dims()[calc_axis];
// allocate memory on device.
dev_ctx.template Alloc<T>(x_grad);
if (x_grad->numel() == 0) {
return;
}
const int n = phi::funcs::SizeToAxis(calc_axis, x_grad->dims());
const int d = phi::funcs::SizeFromAxis(calc_axis, x_grad->dims());
DenseTensor dX_2d, Out_2d, dOut_2d;
dX_2d.ShareDataWith(*x_grad).Resize({n, d});
Out_2d.ShareDataWith(out).Resize({n, d});
dOut_2d.ShareDataWith(out_grad).Resize({n, d});
paddle::operators::math::SoftmaxGradFunctor<Context, T>()(
dev_ctx, axis_dim, &Out_2d, &dOut_2d, &dX_2d);
}
} // namespace phi
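On each row of the flattened {n, d} view the grad functor computes the standard softmax backward, dx = (dout - sum(dout * out)) * out; a quick numeric check of that identity (plain C++, hypothetical helper, not the phi functor):

#include <cassert>
#include <cmath>
#include <vector>

std::vector<float> SoftmaxGradRow(const std::vector<float>& out,
                                  const std::vector<float>& dout) {
  float dot = 0.0f;
  for (size_t i = 0; i < out.size(); ++i) dot += dout[i] * out[i];
  std::vector<float> dx(out.size());
  for (size_t i = 0; i < out.size(); ++i) dx[i] = (dout[i] - dot) * out[i];
  return dx;
}

int main() {
  std::vector<float> out = {0.7f, 0.2f, 0.1f};  // a softmax output row
  std::vector<float> dout = {1.0f, 0.0f, 0.0f};
  std::vector<float> dx = SoftmaxGradRow(out, dout);
  // dx = {0.21, -0.14, -0.07}: softmax gradient rows always sum to zero.
  assert(std::fabs(dx[0] + dx[1] + dx[2]) < 1e-6f);
  return 0;
}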
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/softmax_kernel.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace phi {
template <typename T, typename Context>
void SoftmaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out) {
const int rank = x.dims().size();
const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = x.dims()[calc_axis];
// allocate memory on device.
dev_ctx.template Alloc<T>(out);
if (out->numel() == 0) {
return;
}
const int n = phi::funcs::SizeToAxis(calc_axis, x.dims());
const int d = phi::funcs::SizeFromAxis(calc_axis, x.dims());
DenseTensor X_2d, Out_2d;
X_2d.ShareDataWith(x).Resize({n, d});
Out_2d.ShareDataWith(*out).Resize({n, d});
paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
dev_ctx, axis_dim, &X_2d, &Out_2d);
}
} // namespace phi
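The raw kernel's functor applies a max-shifted exp-normalize to each {n, d} row; the same computation in miniature (hypothetical helper, not the phi functor):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

std::vector<float> SoftmaxRow(const std::vector<float>& x) {
  float mx = *std::max_element(x.begin(), x.end());  // shift for stability
  std::vector<float> y(x.size());
  float sum = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - mx);
    sum += y[i];
  }
  for (size_t i = 0; i < y.size(); ++i) y[i] /= sum;
  return y;
}

int main() {
  std::vector<float> y = SoftmaxRow({1.0f, 2.0f, 3.0f});
  assert(std::fabs(y[0] + y[1] + y[2] - 1.0f) < 1e-6f);  // normalized
  assert(y[2] > y[1] && y[1] > y[0]);  // monotone in the logits
  return 0;
}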
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace phi {
template <typename T, typename Context>
void SoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace phi {
template <typename T, typename Context>
void SoftmaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void SoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DataType dtype,
DenseTensor* out) {
auto cast_x = phi::Cast<T, Context>(dev_ctx, x, dtype);
phi::SoftmaxRawKernel<T, Context>(dev_ctx, cast_x, axis, out);
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SoftmaxOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("softmax", {"X"}, {"axis"}, {"Out"});
}
KernelSignature SoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("softmax_grad",
{"Out", GradVarName("Out")},
{"axis"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(softmax, phi::SoftmaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softmax_grad, phi::SoftmaxGradOpArgumentMapping);
...@@ -41,8 +41,8 @@ TEST(Backend, OStream) { ...@@ -41,8 +41,8 @@ TEST(Backend, OStream) {
oss << phi::Backend::MKLDNN; oss << phi::Backend::MKLDNN;
EXPECT_EQ(oss.str(), "MKLDNN"); EXPECT_EQ(oss.str(), "MKLDNN");
oss.str(""); oss.str("");
oss << phi::Backend::CUDNN; oss << phi::Backend::GPUDNN;
EXPECT_EQ(oss.str(), "CUDNN"); EXPECT_EQ(oss.str(), "GPUDNN");
oss.str(""); oss.str("");
try { try {
oss << phi::Backend::NUM_BACKENDS; oss << phi::Backend::NUM_BACKENDS;
...@@ -60,7 +60,7 @@ TEST(Backend, StringToBackend) { ...@@ -60,7 +60,7 @@ TEST(Backend, StringToBackend) {
EXPECT_EQ(phi::Backend::XPU, pexp::StringToBackend("XPU")); EXPECT_EQ(phi::Backend::XPU, pexp::StringToBackend("XPU"));
EXPECT_EQ(phi::Backend::NPU, pexp::StringToBackend("NPU")); EXPECT_EQ(phi::Backend::NPU, pexp::StringToBackend("NPU"));
EXPECT_EQ(phi::Backend::MKLDNN, pexp::StringToBackend("MKLDNN")); EXPECT_EQ(phi::Backend::MKLDNN, pexp::StringToBackend("MKLDNN"));
EXPECT_EQ(phi::Backend::CUDNN, pexp::StringToBackend("CUDNN")); EXPECT_EQ(phi::Backend::GPUDNN, pexp::StringToBackend("GPUDNN"));
EXPECT_EQ(static_cast<phi::Backend>( EXPECT_EQ(static_cast<phi::Backend>(
static_cast<size_t>(phi::Backend::NUM_BACKENDS) + 1), static_cast<size_t>(phi::Backend::NUM_BACKENDS) + 1),
pexp::StringToBackend("CustomBackend")); pexp::StringToBackend("CustomBackend"));
......