From 5106fbf5c21aa3b738d8cf288b1a439b853b9323 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Fri, 4 Jun 2021 11:07:59 +0800 Subject: [PATCH] Revert "optimize softmax with cross entropy hard label (#32290)" This reverts commit 7be6191bee6c6f3c1af8b93f989d8fa242844a6b. --- .../softmax_with_cross_entropy_op.cu | 796 +++++++----------- 1 file changed, 309 insertions(+), 487 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 8fe456edea..4aec4c1742 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -15,481 +15,44 @@ limitations under the License. */ #include namespace cub = hipcub; #endif -#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/operators/softmax_impl.cuh" #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #include "paddle/fluid/platform/for_range.h" -#ifdef PADDLE_WITH_HIP -#include "paddle/fluid/platform/miopen_helper.h" -#else -#include "paddle/fluid/platform/cudnn_helper.h" -#endif namespace paddle { namespace operators { -using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; -using DataLayout = platform::DataLayout; using Tensor = framework::Tensor; -// Wrapper of log function. Use log(float32) for float16 -template -static __device__ __forceinline__ T Log(T x) { - using AccT = typename details::MPTypeTrait::Type; - AccT logx = std::log(static_cast(x)); - return math::TolerableValue()(static_cast(logx)); -} - -// Wrapper of exp function. Use exp(float32) for float16 +namespace { template -static __device__ __forceinline__ T Exp(T x) { - using AccT = typename details::MPTypeTrait::Type; - AccT expx = std::exp(static_cast(x)); - return math::TolerableValue()(static_cast(expx)); -} - -// log2(value) -static inline int Log2Ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -enum class SoftmaxMode { kSoftmax, kLogSoftmax, kCrossEntropy }; - -/* - Hard label cross entropy. -*/ -template -__global__ void CrossEntropyHardLabel(T* loss, const T* softmax, - const int64_t* labels, const int n, - const int dim, const int d, - const int ignore_idx) { - int64_t ids = blockIdx.x * blockDim.x + threadIdx.x; - int64_t idx_n = ids / d; - int64_t idx_d = ids % d; - - // thread ids compute loss[ids] using softmax[idx] - if (ids < n * d) { - int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d; - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (labels[ids] == ignore_idx) { - loss[ids] = static_cast(0.0); - } else { - loss[ids] = -Log(softmax[idx]); - } - } else { - // IgnoreIndex is false - loss[ids] = -Log(softmax[idx]); - } - } -} - -/* - Hard label cross entropy with exp. - Input: log softmax - Output: loss and exp(input) -*/ -template -__global__ void CrossEntropyExpHardLabel(T* loss, T* softmax, - const int64_t* labels, const int n, - const int dim, const int d, - const int ignore_idx) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - int64_t idx_n = idx / (d * dim); - int64_t idx_dim = (idx / d) % dim; - int64_t idx_d = idx % d; - int64_t ids = idx_n * d + idx_d; - - if (idx < n * dim * d) { - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (idx_dim == labels[ids]) { - if (labels[ids] == ignore_idx) { - loss[ids] = static_cast(0.0); - } else { - loss[ids] = -softmax[idx]; - } - } - } else { - // IgnoreIndex is false - if (labels[ids] >= 0 && labels[ids] < dim) { - if (labels[ids] == idx_dim) { - loss[ids] = -softmax[idx]; - } - } else { - loss[ids] = static_cast(0.0); - } - } - softmax[idx] = Exp(softmax[idx]); - } -} - -/* - Core function of softmax with cross entropy forward - - softmax, SoftmaxMode=kSoftmax - - log softmax, SoftmaxMode=kLogSoftmax - - softmax with cross entropy hard label, SoftmaxMode=kCrossEntropy - The computation includes - - Compute max value: maxvalue_{i} = max_j src_{i,j} - - Compute sum of exp: s_{i} = sum_{j}{e^{src_{i,j} - maxvalue_{i}}} - - Compute: softmax_{i,j} = e^{src_{i,j} - maxvalue_{i}} / s_{i} - - Compute: logsoftmax_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i}) - - Compute: loss_{i} = -logsoftmax[i,label[i]] (Hard label) - This computation results from following formula: - softmax_{i,j} = e^{src_{i,j}} / sum_{j}{e^{src_{i,j}}} - = e^{src_{i,j} - maxvalue_{i}} - / sum_{j}{e^{src_{i,j} - maxvalue_{i}}} - = e^{src_{i,j} - maxvalue_{i}} / s_{i} - logsoftmax_{i,j} = log(softmax_{i,j}) - = src_{i,j} - maxvalue_{i} - log(s_{i}) - One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). - For reduction max (sum), firstly compute max (sum) to one warp, then use - shuffle api to compute max (sum) in one warp. -*/ -template -__global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src, - const int64_t* label, const int batch_size, - const int stride, const int element_count, - const int ignore_index) { - constexpr int kDimCeil = 1 << Log2Elements; - constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; - constexpr int kVSize = sizeof(VecT) / sizeof(T); - constexpr int kIterations = kDimCeil / kWarpSize; - constexpr int kIterationsV = - (kIterations >= kVSize) ? (kIterations / kVSize) : 1; - constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1; - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; - - // max index to read - int idx_max_v[kBatchSize]; -#pragma unroll - for (int i = 0; i < kBatchSize; i++) { - int idx_max = ((i + first_batch) < batch_size) ? element_count : 0; - idx_max_v[i] = idx_max / kVSize; - } - - // read data from global memory - AccT srcdata[kBatchSize][kIterationsV][kVSize]; - -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { -// read data to srcdata: - KVSize==1, - KVSize>1 -#pragma unroll - for (int it = 0; it < kIterationsV; ++it) { - int src_idx = threadIdx.x + it * kWarpSize; - if (kVSize == 1) { - if (src_idx < idx_max_v[i]) { - srcdata[i][it][0] = - static_cast(src[(first_batch + i) * stride + src_idx]); - } else { - srcdata[i][it][0] = -std::numeric_limits::infinity(); - } - } else { - const VecT* src_v = - reinterpret_cast(&src[(first_batch + i) * stride]); - if (src_idx < idx_max_v[i]) { - VecT srctmp = src_v[src_idx]; - const T* srcinptr = reinterpret_cast(&srctmp); -#pragma unroll - for (int s = 0; s < kVSize; s++) { - srcdata[i][it][s] = static_cast(srcinptr[s]); - } - } else { -#pragma unroll - for (int s = 0; s < kVSize; s++) { - srcdata[i][it][s] = -std::numeric_limits::infinity(); - } - } - } - } - } - - // compute max value: maxvalue_{i} = max_j src_{i,j} - AccT max_value[kBatchSize]; -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - // it = 0 - AccT valmax = srcdata[i][0][0]; -#pragma unroll - for (int s = 1; s < kVSize; ++s) { - valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s]; - } - max_value[i] = valmax; - -// it = 1, 2, ... -#pragma unroll - for (int it = 1; it < kIterationsV; ++it) { - AccT valmax = srcdata[i][it][0]; -#pragma unroll - for (int s = 1; s < kVSize; ++s) { - valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s]; - } - max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax; - } - } - WarpReduceMax(max_value); - - // compute sum: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } - AccT sum[kBatchSize]; -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - // it = 0 - if (mode == SoftmaxMode::kLogSoftmax || - mode == SoftmaxMode::kCrossEntropy) { - sum[i] = std::exp(srcdata[i][0][0] - max_value[i]); - } else { - srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]); - sum[i] = srcdata[i][0][0]; - } -#pragma unroll - for (int s = 1; s < kVSize; ++s) { - if (mode == SoftmaxMode::kLogSoftmax || - mode == SoftmaxMode::kCrossEntropy) { - sum[i] += std::exp(srcdata[i][0][s] - max_value[i]); - } else { - srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]); - sum[i] += srcdata[i][0][s]; - } - } - -// it = 1, 2, ... -#pragma unroll - for (int it = 1; it < kIterationsV; ++it) { -#pragma unroll - for (int s = 0; s < kVSize; ++s) { - if (mode == SoftmaxMode::kLogSoftmax || - mode == SoftmaxMode::kCrossEntropy) { - sum[i] += std::exp(srcdata[i][it][s] - max_value[i]); - } else { - srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]); - sum[i] += srcdata[i][it][s]; - } - } - } - } - WarpReduceSum(sum); - -// write data -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - if (mode == SoftmaxMode::kLogSoftmax || - mode == SoftmaxMode::kCrossEntropy) { - sum[i] = std::log(sum[i]); - } - -#pragma unroll - for (int it = 0; it < kIterationsV; ++it) { - int idx = threadIdx.x + it * kWarpSize; - if (kVSize == 1) { // kVSize==1 - if (idx < idx_max_v[i]) { - if (mode == SoftmaxMode::kLogSoftmax) { // log softmax - softmax[(first_batch + i) * stride + idx] = - srcdata[i][it][0] - max_value[i] - sum[i]; - // softmax with cross entropy hard label - } else if (mode == SoftmaxMode::kCrossEntropy) { - AccT logsoftmax = srcdata[i][it][0] - max_value[i] - sum[i]; - // softmax - softmax[(first_batch + i) * stride + idx] = std::exp(logsoftmax); - // label - int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize; - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (label[first_batch + i] == loss_idx) { - if (label[first_batch + i] != ignore_index) { - loss[first_batch + i] = -logsoftmax; - } else { - loss[first_batch + i] = static_cast(0.0); - } - } - } else { - // IgnoreIndex is false - if (label[first_batch + i] >= 0 && - label[first_batch + i] < element_count) { - if (label[first_batch + i] == loss_idx) { - loss[first_batch + i] = -logsoftmax; - } - } else { - loss[first_batch + i] = static_cast(0.0); - } - } - } else { // softmax - softmax[(first_batch + i) * stride + idx] = - srcdata[i][it][0] / sum[i]; - } - } else { - break; - } - } else { // KVSize>1 - VecT* softmax_v = - reinterpret_cast(&softmax[(first_batch + i) * stride]); - VecT tmpdata; - T* tmpptr = reinterpret_cast(&tmpdata); -#pragma unroll - for (int s = 0; s < kVSize; ++s) { - if (mode == SoftmaxMode::kLogSoftmax) { // log softmax - tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i]; - // softmax with cross entropy hard label - } else if (mode == SoftmaxMode::kCrossEntropy) { - AccT logsoftmax = srcdata[i][it][s] - max_value[i] - sum[i]; - // softmax - tmpptr[s] = std::exp(logsoftmax); - // label - int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s; - if (IgnoreIndex == true) { - // IgnoreIndex is true - if (label[first_batch + i] == loss_idx && - label[first_batch + i] != ignore_index) { - loss[first_batch + i] = -logsoftmax; - } - } else { - // IgnoreIndex is false - if (label[first_batch + i] >= 0 && - label[first_batch + i] < element_count) { - if (label[first_batch + i] == loss_idx) { - loss[first_batch + i] = -logsoftmax; - } - } else { - loss[first_batch + i] = static_cast(0.0); - } - } - } else { // softmax - tmpptr[s] = srcdata[i][it][s] / sum[i]; - } - } - if (idx < idx_max_v[i]) { - softmax_v[idx] = tmpdata; - } else { - break; - } - } +__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, + const int64_t n, const int64_t d, + const int64_t remain, const int ignore_index) { + CUDA_KERNEL_LOOP_TYPE(index, n * remain, int64_t) { + int64_t idx_n = index / remain; + int64_t idx_remain = index % remain; + int64_t tmp = labels[index]; + if (ignore_index != tmp) { + int64_t idx = idx_n * d + tmp * remain + idx_remain; + logit_grad[idx] -= static_cast(1.); } } } -#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, VecT, AccT) \ - case Log2Elements: \ - WarpSoftmaxForward<<>>( \ - loss, softmax, src, label, batch_size, stride, element_count, \ - ignore_index); \ - break; - -/* - Wrapper of softmax with cross entropy forward hard label. -*/ -template -void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src, - const int64_t* label, const int batch_size, - const int stride, const int element_count, - const int ignore_index, gpuStream_t stream) { - using AccT = typename details::MPTypeTrait::Type; - - // use 128 threads per block to maximimize gpu utilization - const int Log2Elements = static_cast(Log2Ceil(element_count)); - const int kDimCeil = 1 << Log2Elements; - int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; - int batches_per_warp = (kDimCeil <= 128) ? 2 : 1; - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / kWarpSize); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_size + batches_per_block - 1) / batches_per_block; - dim3 threads(kWarpSize, warps_per_block, 1); - - switch (Log2Elements) { - SOFTMAX_WARP_FORWARD_CASE(0, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(1, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(2, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(3, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(4, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(5, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(6, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(7, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(8, T, AccT); - SOFTMAX_WARP_FORWARD_CASE(9, T, AccT); - default: - break; - } -} - -/* - Wrapper of softmax with cross entropy hard label. - - SwitchWarpSoftmaxForward for small size - - cudnn function for large size -*/ -template -static void SoftmaxWithCrossEntropyHardLabel( - const platform::CUDADeviceContext& ctx, int rank, int axis, - const T* logits_data, const int64_t* labels_data, T* loss_data, - T* softmax_data, int N, int dim, int D, const int ignore_index) { - auto stream = ctx.stream(); - constexpr int max_dim = 320; - if (D == 1 && dim <= max_dim) { // small size - const SoftmaxMode mode = SoftmaxMode::kCrossEntropy; - SwitchWarpSoftmaxForward( - loss_data, softmax_data, logits_data, labels_data, N, dim, dim, - ignore_index, stream); - } else { - ScopedTensorDescriptor desc; - std::vector tensor_dims = {N, dim, D, 1}; - DataLayout layout = DataLayout::kNCHW; -#ifdef PADDLE_WITH_HIP - miopenTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); -#else - cudnnTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); -#endif - - auto handle = ctx.cudnn_handle(); - -#ifdef PADDLE_WITH_HIP - auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE - : MIOPEN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( - handle, platform::CudnnDataType::kOne(), descp, logits_data, - platform::CudnnDataType::kZero(), descp, softmax_data, - MIOPEN_SOFTMAX_LOG, mode)); -#else - auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE - : CUDNN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( - handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), - descp, logits_data, platform::CudnnDataType::kZero(), descp, - softmax_data)); -#endif - int threads = 128; - int blocks = (N * dim * D + threads - 1) / threads; - // compute cross entropy, input is log softmax - CrossEntropyExpHardLabel<<>>( - loss_data, softmax_data, labels_data, N, dim, D, ignore_index); - } -} - -/* - Wrapper of softmax with cross entropy grad hard label. -*/ template -__global__ void SoftmaxWithCrossEntropyGradHardLabel( - T* logits_grad, const T* loss_grad, const int64_t* labels, const int64_t n, - const int64_t dim, const int64_t d, const int ignore_index) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - int64_t idx_n = idx / (d * dim); - int64_t idx_dim = (idx / d) % dim; - int64_t idx_d = idx % d; - int64_t ids = idx_n * d + idx_d; - - if (idx < n * dim * d) { - if (labels[ids] == ignore_index) { - logits_grad[idx] = static_cast(0.0); - } else if (labels[ids] == idx_dim) { - logits_grad[idx] = - (logits_grad[idx] - static_cast(1.0)) * loss_grad[ids]; +__global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num, + const int64_t d, const int64_t remain, + const int64_t* labels, const int ignore_index) { + CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) { + int64_t idx_n = index / d; + int64_t idx_remain = index % remain; + int64_t idx_lbl = idx_n * remain + idx_remain; + if (labels[idx_lbl] == ignore_index) { + logit_grad[index] = static_cast(0.); } else { - logits_grad[idx] *= loss_grad[ids]; + logit_grad[index] *= loss_grad[idx_lbl]; } } } @@ -560,6 +123,8 @@ __global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad, } } +} // namespace + static __device__ __forceinline__ platform::float16 exp_on_device( platform::float16 x) { return ::Eigen::numext::exp(x); @@ -831,6 +396,278 @@ static __global__ void RowReductionForCrossEntropy(const T* logits_data, if (threadIdx.x == 0) loss_data[blockIdx.x] = loss; } +template +struct HardLabelCrossEntropyFunctor { + public: + HardLabelCrossEntropyFunctor(const int64_t* labels, T* loss, + const T* logits_data, int d, int axis_dim) + : labels_(labels), + loss_(loss), + logits_data_(logits_data), + d_(d), + axis_dim_(axis_dim) {} + + __device__ void operator()(int idx) const { + // logits view as [n, axis_dim, remain], where d = axis_dim * remain + int remain = d_ / axis_dim_; + int idx_n = idx / d_; + int idx_axis = (idx % d_) / remain; + int idx_remain = idx % remain; + // labels, loss view as [n, remain] + int idx_lbl = idx_n * remain + idx_remain; + // It also would ignore labels not in range(class_num). + if (idx_axis != labels_[idx_lbl]) { + } else { + loss_[idx_lbl] = -log_on_device(logits_data_[idx]); + } + } + + private: + const int64_t* labels_; + T* loss_; + const T* logits_data_; + int d_; + int axis_dim_; +}; + +template +struct HardLabelCrossEntropyFunctorWithIgnoreIdx { + public: + HardLabelCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, T* loss, + const T* logits_data, int d, + int axis_dim, int ignore_idx) + : labels_(labels), + loss_(loss), + logits_data_(logits_data), + d_(d), + axis_dim_(axis_dim), + ignore_idx_(ignore_idx) {} + + __device__ void operator()(int idx) const { + // logits view as [n, axis_dim, remain], where d = axis_dim * remain + int remain = d_ / axis_dim_; + int idx_n = idx / d_; + int idx_axis = (idx % d_) / remain; + int idx_remain = idx % remain; + // labels, loss view as [n, remain] + int idx_lbl = idx_n * remain + idx_remain; + + if (idx_axis == labels_[idx_lbl] && idx_axis != ignore_idx_) { + loss_[idx_lbl] = -log_on_device(logits_data_[idx]); + } + } + + private: + const int64_t* labels_; + T* loss_; + const T* logits_data_; + int d_; + int axis_dim_; + int ignore_idx_; +}; + +template +static void HardLabelCrossEntropy(const platform::CUDADeviceContext& ctx, + const T* logits_data, + const int64_t* labels_data, T* loss_data, + int n, int d, int axis_dim, int ignore_idx) { + constexpr int kMaxBlockDim = 512; + int block_dim = axis_dim >= kMaxBlockDim + ? kMaxBlockDim + : (1 << static_cast(std::log2(axis_dim))); + int grid_dim = n * d / axis_dim; + auto stream = ctx.stream(); + +#define CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ + case BlockDim: { \ + platform::ForRange for_range(ctx, n* d); \ + if (ignore_idx >= 0 && ignore_idx < axis_dim) { \ + for_range(HardLabelCrossEntropyFunctorWithIgnoreIdx( \ + labels_data, loss_data, logits_data, d, axis_dim, ignore_idx)); \ + } else { \ + for_range(HardLabelCrossEntropyFunctor(labels_data, loss_data, \ + logits_data, d, axis_dim)); \ + } \ + } break + + switch (block_dim) { + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(512); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(256); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(128); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(64); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(32); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(16); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(8); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(4); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(2); + default: + PADDLE_THROW(platform::errors::Unavailable( + "Block Dimension must be 2^n in softmax_with_cross_entropy_op.")); + break; + } +#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL +} + +template +struct HardLabelSoftmaxWithCrossEntropyFunctor { + public: + HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss, + T* log_softmax, int64_t d, + int axis_dim, int ignore_idx) + : labels_(labels), + loss_(loss), + log_softmax_(log_softmax), + d_(d), + axis_dim_(axis_dim), + ignore_idx_(ignore_idx) {} + + __device__ void operator()(int64_t idx) const { + // logits view as [n, axis_dim, remain], where d = axis_dim * remain + int64_t remain = d_ / axis_dim_; + int64_t idx_n = idx / d_; + int64_t idx_axis = (idx % d_) / remain; + int64_t idx_remain = idx % remain; + // labels, loss view as [n, remain] + int64_t idx_lbl = idx_n * remain + idx_remain; + PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_ || + labels_[idx_lbl] == ignore_idx_, + "The value of label[%ld] expected >= 0 and < %ld, or == %d," + "but got %ld. Please check input value.", + idx_lbl, d_, ignore_idx_, labels_[idx_lbl]); + // It also would ignore labels not in range(class_num). + if (idx_axis != labels_[idx_lbl]) { + log_softmax_[idx] = exp_on_device(log_softmax_[idx]); + } else { + auto softmax = log_softmax_[idx]; + log_softmax_[idx] = exp_on_device(softmax); + loss_[idx_lbl] = -softmax; + } + } + + private: + const int64_t* labels_; + T* loss_; + T* log_softmax_; + int64_t d_; + int axis_dim_; + int ignore_idx_; +}; + +template +struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { + public: + HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, + T* loss, T* log_softmax, + int64_t d, int axis_dim, + int ignore_idx) + : labels_(labels), + loss_(loss), + log_softmax_(log_softmax), + d_(d), + axis_dim_(axis_dim), + ignore_idx_(ignore_idx) {} + + __device__ void operator()(int64_t idx) const { + // logits view as [n, axis_dim, remain], where d = axis_dim * remain + int64_t remain = d_ / axis_dim_; + int64_t idx_n = idx / d_; + int64_t idx_axis = (idx % d_) / remain; + int64_t idx_remain = idx % remain; + // labels, loss view as [n, remain] + int64_t idx_lbl = idx_n * remain + idx_remain; + if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) { + log_softmax_[idx] = exp_on_device(log_softmax_[idx]); + } else { + auto softmax = log_softmax_[idx]; + log_softmax_[idx] = exp_on_device(softmax); + loss_[idx_lbl] = -softmax; + } + } + + private: + const int64_t* labels_; + T* loss_; + T* log_softmax_; + int64_t d_; + int axis_dim_; + int ignore_idx_; +}; + +template +static void HardLabelSoftmaxWithCrossEntropy( + const platform::CUDADeviceContext& ctx, const T* logits_data, + const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n, + int64_t d, int axis_dim, int ignore_idx) { +#ifdef __HIPCC__ + // HIP platform will have loss nan if dim size > 256 + constexpr int kMaxBlockDim = 256; +#else + constexpr int kMaxBlockDim = 512; +#endif + int64_t block_dim = axis_dim >= kMaxBlockDim + ? kMaxBlockDim + : (1 << static_cast(std::log2(axis_dim))); + int64_t grid_dim = n * d / axis_dim; + auto stream = ctx.stream(); + +#ifdef __HIPCC__ +#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ + case BlockDim: { \ + hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax), \ + dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \ + loss_data, d, axis_dim); \ + hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum), \ + dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \ + loss_data, softmax_data, d, axis_dim); \ + hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForDiff), \ + dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \ + loss_data, softmax_data, d, axis_dim); \ + platform::ForRange for_range(ctx, n* d); \ + if (ignore_idx >= 0 && ignore_idx < axis_dim) { \ + for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx( \ + labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ + } else { \ + for_range(HardLabelSoftmaxWithCrossEntropyFunctor( \ + labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ + } \ + } break +#else +#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ + case BlockDim: { \ + RowReductionForMax<<>>( \ + logits_data, loss_data, d, axis_dim); \ + RowReductionForDiffMaxSum<<>>( \ + logits_data, loss_data, softmax_data, d, axis_dim); \ + platform::ForRange for_range(ctx, n* d); \ + if (ignore_idx >= 0 && ignore_idx < axis_dim) { \ + for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx( \ + labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ + } else { \ + for_range(HardLabelSoftmaxWithCrossEntropyFunctor( \ + labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ + } \ + } break +#endif + + switch (block_dim) { + CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512); + CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256); + CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128); + CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64); + CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32); + CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16); + CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); + CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); + CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); + default: + PADDLE_THROW(platform::errors::Unavailable( + "Block Dimension must be 2^n in softmax_with_cross_entropy_op.")); + break; + } +#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL +} + template static void SoftmaxWithCrossEntropyFusedKernel( const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data, @@ -946,7 +783,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { const int rank = softmax->dims().size(); const int axis = CanonicalAxis(context.Attr("axis"), rank); - const int axis_dim = softmax->dims()[axis]; + int axis_dim = softmax->dims()[axis]; const int n = SizeToAxis(axis, softmax->dims()); const int d = SizeFromAxis(axis, softmax->dims()); @@ -989,19 +826,9 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { } else { // HardLabel auto* logits_data = softmax->data(); auto* labels_data = labels->data(); - int threads = 128; - int blocks = (n * d / axis_dim + threads - 1) / threads; - if (ignore_index >= 0 && ignore_index < axis_dim) { - CrossEntropyHardLabel<<< - blocks, threads, 0, context.cuda_device_context().stream()>>>( - loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim, - ignore_index); - } else { - CrossEntropyHardLabel<<< - blocks, threads, 0, context.cuda_device_context().stream()>>>( - loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim, - ignore_index); - } + HardLabelCrossEntropy(context.cuda_device_context(), logits_data, + labels_data, loss_data, n, d, axis_dim, + ignore_index); } // cause of input is softmax @@ -1059,17 +886,9 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { } else { auto* logits_data = logits->data(); auto* labels_data = labels->data(); - if (ignore_index >= 0 && ignore_index < axis_dim) { - SoftmaxWithCrossEntropyHardLabel( - context.cuda_device_context(), rank, axis, logits_data, - labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim, - ignore_index); - } else { - SoftmaxWithCrossEntropyHardLabel( - context.cuda_device_context(), rank, axis, logits_data, - labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim, - ignore_index); - } + HardLabelSoftmaxWithCrossEntropy( + context.cuda_device_context(), logits_data, labels_data, loss_data, + softmax_data, n, d, axis_dim, ignore_index); } } } @@ -1140,11 +959,14 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { SoftCrossEntropyGradientKernel<<>>( logit_grad_data, loss_grad_data, label_data, n, d, remain); } else { + int64_t grid = (n * remain + block - 1) / block; const int64_t* label_data = labels->data(); - int grid = (n * d + block - 1) / block; - SoftmaxWithCrossEntropyGradHardLabel<<>>( - logit_grad_data, loss_grad_data, label_data, n, d / remain, remain, - ignore_index); + CrossEntropyGrad<<>>( + logit_grad_data, label_data, n, d, remain, ignore_index); + int64_t num = n * d; + grid = (num + block - 1) / block; + Scale<<>>(logit_grad_data, loss_grad_data, num, + d, remain, label_data, ignore_index); } } }; -- GitLab