From 7be6191bee6c6f3c1af8b93f989d8fa242844a6b Mon Sep 17 00:00:00 2001 From: Feng Xing <79969986+xingfeng01@users.noreply.github.com> Date: Fri, 21 May 2021 14:01:40 +0800 Subject: [PATCH] optimize softmax with cross entropy hard label (#32290) * optimize softmax with cross entropy hard label * label ignore_index cleaning --- .../softmax_with_cross_entropy_op.cu | 796 +++++++++++------- 1 file changed, 487 insertions(+), 309 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 4aec4c1742..8fe456edea 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -15,44 +15,481 @@ limitations under the License. */ #include namespace cub = hipcub; #endif +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/softmax_impl.cuh" #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #include "paddle/fluid/platform/for_range.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#else +#include "paddle/fluid/platform/cudnn_helper.h" +#endif namespace paddle { namespace operators { +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using DataLayout = platform::DataLayout; using Tensor = framework::Tensor; -namespace { +// Wrapper of log function. Use log(float32) for float16 template -__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, - const int64_t n, const int64_t d, - const int64_t remain, const int ignore_index) { - CUDA_KERNEL_LOOP_TYPE(index, n * remain, int64_t) { - int64_t idx_n = index / remain; - int64_t idx_remain = index % remain; - int64_t tmp = labels[index]; - if (ignore_index != tmp) { - int64_t idx = idx_n * d + tmp * remain + idx_remain; - logit_grad[idx] -= static_cast(1.); +static __device__ __forceinline__ T Log(T x) { + using AccT = typename details::MPTypeTrait::Type; + AccT logx = std::log(static_cast(x)); + return math::TolerableValue()(static_cast(logx)); +} + +// Wrapper of exp function. Use exp(float32) for float16 +template +static __device__ __forceinline__ T Exp(T x) { + using AccT = typename details::MPTypeTrait::Type; + AccT expx = std::exp(static_cast(x)); + return math::TolerableValue()(static_cast(expx)); +} + +// log2(value) +static inline int Log2Ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +enum class SoftmaxMode { kSoftmax, kLogSoftmax, kCrossEntropy }; + +/* + Hard label cross entropy. +*/ +template +__global__ void CrossEntropyHardLabel(T* loss, const T* softmax, + const int64_t* labels, const int n, + const int dim, const int d, + const int ignore_idx) { + int64_t ids = blockIdx.x * blockDim.x + threadIdx.x; + int64_t idx_n = ids / d; + int64_t idx_d = ids % d; + + // thread ids compute loss[ids] using softmax[idx] + if (ids < n * d) { + int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d; + if (IgnoreIndex == true) { + // IgnoreIndex is true + if (labels[ids] == ignore_idx) { + loss[ids] = static_cast(0.0); + } else { + loss[ids] = -Log(softmax[idx]); + } + } else { + // IgnoreIndex is false + loss[ids] = -Log(softmax[idx]); } } } +/* + Hard label cross entropy with exp. + Input: log softmax + Output: loss and exp(input) +*/ +template +__global__ void CrossEntropyExpHardLabel(T* loss, T* softmax, + const int64_t* labels, const int n, + const int dim, const int d, + const int ignore_idx) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t idx_n = idx / (d * dim); + int64_t idx_dim = (idx / d) % dim; + int64_t idx_d = idx % d; + int64_t ids = idx_n * d + idx_d; + + if (idx < n * dim * d) { + if (IgnoreIndex == true) { + // IgnoreIndex is true + if (idx_dim == labels[ids]) { + if (labels[ids] == ignore_idx) { + loss[ids] = static_cast(0.0); + } else { + loss[ids] = -softmax[idx]; + } + } + } else { + // IgnoreIndex is false + if (labels[ids] >= 0 && labels[ids] < dim) { + if (labels[ids] == idx_dim) { + loss[ids] = -softmax[idx]; + } + } else { + loss[ids] = static_cast(0.0); + } + } + softmax[idx] = Exp(softmax[idx]); + } +} + +/* + Core function of softmax with cross entropy forward + - softmax, SoftmaxMode=kSoftmax + - log softmax, SoftmaxMode=kLogSoftmax + - softmax with cross entropy hard label, SoftmaxMode=kCrossEntropy + The computation includes + - Compute max value: maxvalue_{i} = max_j src_{i,j} + - Compute sum of exp: s_{i} = sum_{j}{e^{src_{i,j} - maxvalue_{i}}} + - Compute: softmax_{i,j} = e^{src_{i,j} - maxvalue_{i}} / s_{i} + - Compute: logsoftmax_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i}) + - Compute: loss_{i} = -logsoftmax[i,label[i]] (Hard label) + This computation results from following formula: + softmax_{i,j} = e^{src_{i,j}} / sum_{j}{e^{src_{i,j}}} + = e^{src_{i,j} - maxvalue_{i}} + / sum_{j}{e^{src_{i,j} - maxvalue_{i}}} + = e^{src_{i,j} - maxvalue_{i}} / s_{i} + logsoftmax_{i,j} = log(softmax_{i,j}) + = src_{i,j} - maxvalue_{i} - log(s_{i}) + One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). + For reduction max (sum), firstly compute max (sum) to one warp, then use + shuffle api to compute max (sum) in one warp. +*/ +template +__global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src, + const int64_t* label, const int batch_size, + const int stride, const int element_count, + const int ignore_index) { + constexpr int kDimCeil = 1 << Log2Elements; + constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + constexpr int kVSize = sizeof(VecT) / sizeof(T); + constexpr int kIterations = kDimCeil / kWarpSize; + constexpr int kIterationsV = + (kIterations >= kVSize) ? (kIterations / kVSize) : 1; + constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; + + // max index to read + int idx_max_v[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; i++) { + int idx_max = ((i + first_batch) < batch_size) ? element_count : 0; + idx_max_v[i] = idx_max / kVSize; + } + + // read data from global memory + AccT srcdata[kBatchSize][kIterationsV][kVSize]; + +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { +// read data to srcdata: - KVSize==1, - KVSize>1 +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + int src_idx = threadIdx.x + it * kWarpSize; + if (kVSize == 1) { + if (src_idx < idx_max_v[i]) { + srcdata[i][it][0] = + static_cast(src[(first_batch + i) * stride + src_idx]); + } else { + srcdata[i][it][0] = -std::numeric_limits::infinity(); + } + } else { + const VecT* src_v = + reinterpret_cast(&src[(first_batch + i) * stride]); + if (src_idx < idx_max_v[i]) { + VecT srctmp = src_v[src_idx]; + const T* srcinptr = reinterpret_cast(&srctmp); +#pragma unroll + for (int s = 0; s < kVSize; s++) { + srcdata[i][it][s] = static_cast(srcinptr[s]); + } + } else { +#pragma unroll + for (int s = 0; s < kVSize; s++) { + srcdata[i][it][s] = -std::numeric_limits::infinity(); + } + } + } + } + } + + // compute max value: maxvalue_{i} = max_j src_{i,j} + AccT max_value[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + // it = 0 + AccT valmax = srcdata[i][0][0]; +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s]; + } + max_value[i] = valmax; + +// it = 1, 2, ... +#pragma unroll + for (int it = 1; it < kIterationsV; ++it) { + AccT valmax = srcdata[i][it][0]; +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s]; + } + max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax; + } + } + WarpReduceMax(max_value); + + // compute sum: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } + AccT sum[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + // it = 0 + if (mode == SoftmaxMode::kLogSoftmax || + mode == SoftmaxMode::kCrossEntropy) { + sum[i] = std::exp(srcdata[i][0][0] - max_value[i]); + } else { + srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]); + sum[i] = srcdata[i][0][0]; + } +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + if (mode == SoftmaxMode::kLogSoftmax || + mode == SoftmaxMode::kCrossEntropy) { + sum[i] += std::exp(srcdata[i][0][s] - max_value[i]); + } else { + srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]); + sum[i] += srcdata[i][0][s]; + } + } + +// it = 1, 2, ... +#pragma unroll + for (int it = 1; it < kIterationsV; ++it) { +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (mode == SoftmaxMode::kLogSoftmax || + mode == SoftmaxMode::kCrossEntropy) { + sum[i] += std::exp(srcdata[i][it][s] - max_value[i]); + } else { + srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]); + sum[i] += srcdata[i][it][s]; + } + } + } + } + WarpReduceSum(sum); + +// write data +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + if (mode == SoftmaxMode::kLogSoftmax || + mode == SoftmaxMode::kCrossEntropy) { + sum[i] = std::log(sum[i]); + } + +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + int idx = threadIdx.x + it * kWarpSize; + if (kVSize == 1) { // kVSize==1 + if (idx < idx_max_v[i]) { + if (mode == SoftmaxMode::kLogSoftmax) { // log softmax + softmax[(first_batch + i) * stride + idx] = + srcdata[i][it][0] - max_value[i] - sum[i]; + // softmax with cross entropy hard label + } else if (mode == SoftmaxMode::kCrossEntropy) { + AccT logsoftmax = srcdata[i][it][0] - max_value[i] - sum[i]; + // softmax + softmax[(first_batch + i) * stride + idx] = std::exp(logsoftmax); + // label + int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize; + if (IgnoreIndex == true) { + // IgnoreIndex is true + if (label[first_batch + i] == loss_idx) { + if (label[first_batch + i] != ignore_index) { + loss[first_batch + i] = -logsoftmax; + } else { + loss[first_batch + i] = static_cast(0.0); + } + } + } else { + // IgnoreIndex is false + if (label[first_batch + i] >= 0 && + label[first_batch + i] < element_count) { + if (label[first_batch + i] == loss_idx) { + loss[first_batch + i] = -logsoftmax; + } + } else { + loss[first_batch + i] = static_cast(0.0); + } + } + } else { // softmax + softmax[(first_batch + i) * stride + idx] = + srcdata[i][it][0] / sum[i]; + } + } else { + break; + } + } else { // KVSize>1 + VecT* softmax_v = + reinterpret_cast(&softmax[(first_batch + i) * stride]); + VecT tmpdata; + T* tmpptr = reinterpret_cast(&tmpdata); +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (mode == SoftmaxMode::kLogSoftmax) { // log softmax + tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i]; + // softmax with cross entropy hard label + } else if (mode == SoftmaxMode::kCrossEntropy) { + AccT logsoftmax = srcdata[i][it][s] - max_value[i] - sum[i]; + // softmax + tmpptr[s] = std::exp(logsoftmax); + // label + int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s; + if (IgnoreIndex == true) { + // IgnoreIndex is true + if (label[first_batch + i] == loss_idx && + label[first_batch + i] != ignore_index) { + loss[first_batch + i] = -logsoftmax; + } + } else { + // IgnoreIndex is false + if (label[first_batch + i] >= 0 && + label[first_batch + i] < element_count) { + if (label[first_batch + i] == loss_idx) { + loss[first_batch + i] = -logsoftmax; + } + } else { + loss[first_batch + i] = static_cast(0.0); + } + } + } else { // softmax + tmpptr[s] = srcdata[i][it][s] / sum[i]; + } + } + if (idx < idx_max_v[i]) { + softmax_v[idx] = tmpdata; + } else { + break; + } + } + } + } +} + +#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, VecT, AccT) \ + case Log2Elements: \ + WarpSoftmaxForward<<>>( \ + loss, softmax, src, label, batch_size, stride, element_count, \ + ignore_index); \ + break; + +/* + Wrapper of softmax with cross entropy forward hard label. +*/ +template +void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src, + const int64_t* label, const int batch_size, + const int stride, const int element_count, + const int ignore_index, gpuStream_t stream) { + using AccT = typename details::MPTypeTrait::Type; + + // use 128 threads per block to maximimize gpu utilization + const int Log2Elements = static_cast(Log2Ceil(element_count)); + const int kDimCeil = 1 << Log2Elements; + int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + int batches_per_warp = (kDimCeil <= 128) ? 2 : 1; + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / kWarpSize); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_size + batches_per_block - 1) / batches_per_block; + dim3 threads(kWarpSize, warps_per_block, 1); + + switch (Log2Elements) { + SOFTMAX_WARP_FORWARD_CASE(0, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(1, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(2, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(3, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(4, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(5, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(6, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(7, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(8, T, AccT); + SOFTMAX_WARP_FORWARD_CASE(9, T, AccT); + default: + break; + } +} + +/* + Wrapper of softmax with cross entropy hard label. + - SwitchWarpSoftmaxForward for small size + - cudnn function for large size +*/ +template +static void SoftmaxWithCrossEntropyHardLabel( + const platform::CUDADeviceContext& ctx, int rank, int axis, + const T* logits_data, const int64_t* labels_data, T* loss_data, + T* softmax_data, int N, int dim, int D, const int ignore_index) { + auto stream = ctx.stream(); + constexpr int max_dim = 320; + if (D == 1 && dim <= max_dim) { // small size + const SoftmaxMode mode = SoftmaxMode::kCrossEntropy; + SwitchWarpSoftmaxForward( + loss_data, softmax_data, logits_data, labels_data, N, dim, dim, + ignore_index, stream); + } else { + ScopedTensorDescriptor desc; + std::vector tensor_dims = {N, dim, D, 1}; + DataLayout layout = DataLayout::kNCHW; +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); +#else + cudnnTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); +#endif + + auto handle = ctx.cudnn_handle(); + +#ifdef PADDLE_WITH_HIP + auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE + : MIOPEN_SOFTMAX_MODE_CHANNEL; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( + handle, platform::CudnnDataType::kOne(), descp, logits_data, + platform::CudnnDataType::kZero(), descp, softmax_data, + MIOPEN_SOFTMAX_LOG, mode)); +#else + auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE + : CUDNN_SOFTMAX_MODE_CHANNEL; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( + handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), + descp, logits_data, platform::CudnnDataType::kZero(), descp, + softmax_data)); +#endif + int threads = 128; + int blocks = (N * dim * D + threads - 1) / threads; + // compute cross entropy, input is log softmax + CrossEntropyExpHardLabel<<>>( + loss_data, softmax_data, labels_data, N, dim, D, ignore_index); + } +} + +/* + Wrapper of softmax with cross entropy grad hard label. +*/ template -__global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num, - const int64_t d, const int64_t remain, - const int64_t* labels, const int ignore_index) { - CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) { - int64_t idx_n = index / d; - int64_t idx_remain = index % remain; - int64_t idx_lbl = idx_n * remain + idx_remain; - if (labels[idx_lbl] == ignore_index) { - logit_grad[index] = static_cast(0.); +__global__ void SoftmaxWithCrossEntropyGradHardLabel( + T* logits_grad, const T* loss_grad, const int64_t* labels, const int64_t n, + const int64_t dim, const int64_t d, const int ignore_index) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t idx_n = idx / (d * dim); + int64_t idx_dim = (idx / d) % dim; + int64_t idx_d = idx % d; + int64_t ids = idx_n * d + idx_d; + + if (idx < n * dim * d) { + if (labels[ids] == ignore_index) { + logits_grad[idx] = static_cast(0.0); + } else if (labels[ids] == idx_dim) { + logits_grad[idx] = + (logits_grad[idx] - static_cast(1.0)) * loss_grad[ids]; } else { - logit_grad[index] *= loss_grad[idx_lbl]; + logits_grad[idx] *= loss_grad[ids]; } } } @@ -123,8 +560,6 @@ __global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad, } } -} // namespace - static __device__ __forceinline__ platform::float16 exp_on_device( platform::float16 x) { return ::Eigen::numext::exp(x); @@ -396,278 +831,6 @@ static __global__ void RowReductionForCrossEntropy(const T* logits_data, if (threadIdx.x == 0) loss_data[blockIdx.x] = loss; } -template -struct HardLabelCrossEntropyFunctor { - public: - HardLabelCrossEntropyFunctor(const int64_t* labels, T* loss, - const T* logits_data, int d, int axis_dim) - : labels_(labels), - loss_(loss), - logits_data_(logits_data), - d_(d), - axis_dim_(axis_dim) {} - - __device__ void operator()(int idx) const { - // logits view as [n, axis_dim, remain], where d = axis_dim * remain - int remain = d_ / axis_dim_; - int idx_n = idx / d_; - int idx_axis = (idx % d_) / remain; - int idx_remain = idx % remain; - // labels, loss view as [n, remain] - int idx_lbl = idx_n * remain + idx_remain; - // It also would ignore labels not in range(class_num). - if (idx_axis != labels_[idx_lbl]) { - } else { - loss_[idx_lbl] = -log_on_device(logits_data_[idx]); - } - } - - private: - const int64_t* labels_; - T* loss_; - const T* logits_data_; - int d_; - int axis_dim_; -}; - -template -struct HardLabelCrossEntropyFunctorWithIgnoreIdx { - public: - HardLabelCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, T* loss, - const T* logits_data, int d, - int axis_dim, int ignore_idx) - : labels_(labels), - loss_(loss), - logits_data_(logits_data), - d_(d), - axis_dim_(axis_dim), - ignore_idx_(ignore_idx) {} - - __device__ void operator()(int idx) const { - // logits view as [n, axis_dim, remain], where d = axis_dim * remain - int remain = d_ / axis_dim_; - int idx_n = idx / d_; - int idx_axis = (idx % d_) / remain; - int idx_remain = idx % remain; - // labels, loss view as [n, remain] - int idx_lbl = idx_n * remain + idx_remain; - - if (idx_axis == labels_[idx_lbl] && idx_axis != ignore_idx_) { - loss_[idx_lbl] = -log_on_device(logits_data_[idx]); - } - } - - private: - const int64_t* labels_; - T* loss_; - const T* logits_data_; - int d_; - int axis_dim_; - int ignore_idx_; -}; - -template -static void HardLabelCrossEntropy(const platform::CUDADeviceContext& ctx, - const T* logits_data, - const int64_t* labels_data, T* loss_data, - int n, int d, int axis_dim, int ignore_idx) { - constexpr int kMaxBlockDim = 512; - int block_dim = axis_dim >= kMaxBlockDim - ? kMaxBlockDim - : (1 << static_cast(std::log2(axis_dim))); - int grid_dim = n * d / axis_dim; - auto stream = ctx.stream(); - -#define CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ - case BlockDim: { \ - platform::ForRange for_range(ctx, n* d); \ - if (ignore_idx >= 0 && ignore_idx < axis_dim) { \ - for_range(HardLabelCrossEntropyFunctorWithIgnoreIdx( \ - labels_data, loss_data, logits_data, d, axis_dim, ignore_idx)); \ - } else { \ - for_range(HardLabelCrossEntropyFunctor(labels_data, loss_data, \ - logits_data, d, axis_dim)); \ - } \ - } break - - switch (block_dim) { - CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(512); - CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(256); - CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(128); - CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(64); - CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(32); - CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(16); - CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(8); - CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(4); - CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(2); - default: - PADDLE_THROW(platform::errors::Unavailable( - "Block Dimension must be 2^n in softmax_with_cross_entropy_op.")); - break; - } -#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL -} - -template -struct HardLabelSoftmaxWithCrossEntropyFunctor { - public: - HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss, - T* log_softmax, int64_t d, - int axis_dim, int ignore_idx) - : labels_(labels), - loss_(loss), - log_softmax_(log_softmax), - d_(d), - axis_dim_(axis_dim), - ignore_idx_(ignore_idx) {} - - __device__ void operator()(int64_t idx) const { - // logits view as [n, axis_dim, remain], where d = axis_dim * remain - int64_t remain = d_ / axis_dim_; - int64_t idx_n = idx / d_; - int64_t idx_axis = (idx % d_) / remain; - int64_t idx_remain = idx % remain; - // labels, loss view as [n, remain] - int64_t idx_lbl = idx_n * remain + idx_remain; - PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_ || - labels_[idx_lbl] == ignore_idx_, - "The value of label[%ld] expected >= 0 and < %ld, or == %d," - "but got %ld. Please check input value.", - idx_lbl, d_, ignore_idx_, labels_[idx_lbl]); - // It also would ignore labels not in range(class_num). - if (idx_axis != labels_[idx_lbl]) { - log_softmax_[idx] = exp_on_device(log_softmax_[idx]); - } else { - auto softmax = log_softmax_[idx]; - log_softmax_[idx] = exp_on_device(softmax); - loss_[idx_lbl] = -softmax; - } - } - - private: - const int64_t* labels_; - T* loss_; - T* log_softmax_; - int64_t d_; - int axis_dim_; - int ignore_idx_; -}; - -template -struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { - public: - HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, - T* loss, T* log_softmax, - int64_t d, int axis_dim, - int ignore_idx) - : labels_(labels), - loss_(loss), - log_softmax_(log_softmax), - d_(d), - axis_dim_(axis_dim), - ignore_idx_(ignore_idx) {} - - __device__ void operator()(int64_t idx) const { - // logits view as [n, axis_dim, remain], where d = axis_dim * remain - int64_t remain = d_ / axis_dim_; - int64_t idx_n = idx / d_; - int64_t idx_axis = (idx % d_) / remain; - int64_t idx_remain = idx % remain; - // labels, loss view as [n, remain] - int64_t idx_lbl = idx_n * remain + idx_remain; - if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) { - log_softmax_[idx] = exp_on_device(log_softmax_[idx]); - } else { - auto softmax = log_softmax_[idx]; - log_softmax_[idx] = exp_on_device(softmax); - loss_[idx_lbl] = -softmax; - } - } - - private: - const int64_t* labels_; - T* loss_; - T* log_softmax_; - int64_t d_; - int axis_dim_; - int ignore_idx_; -}; - -template -static void HardLabelSoftmaxWithCrossEntropy( - const platform::CUDADeviceContext& ctx, const T* logits_data, - const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n, - int64_t d, int axis_dim, int ignore_idx) { -#ifdef __HIPCC__ - // HIP platform will have loss nan if dim size > 256 - constexpr int kMaxBlockDim = 256; -#else - constexpr int kMaxBlockDim = 512; -#endif - int64_t block_dim = axis_dim >= kMaxBlockDim - ? kMaxBlockDim - : (1 << static_cast(std::log2(axis_dim))); - int64_t grid_dim = n * d / axis_dim; - auto stream = ctx.stream(); - -#ifdef __HIPCC__ -#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ - case BlockDim: { \ - hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax), \ - dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \ - loss_data, d, axis_dim); \ - hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum), \ - dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \ - loss_data, softmax_data, d, axis_dim); \ - hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForDiff), \ - dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \ - loss_data, softmax_data, d, axis_dim); \ - platform::ForRange for_range(ctx, n* d); \ - if (ignore_idx >= 0 && ignore_idx < axis_dim) { \ - for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx( \ - labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ - } else { \ - for_range(HardLabelSoftmaxWithCrossEntropyFunctor( \ - labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ - } \ - } break -#else -#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ - case BlockDim: { \ - RowReductionForMax<<>>( \ - logits_data, loss_data, d, axis_dim); \ - RowReductionForDiffMaxSum<<>>( \ - logits_data, loss_data, softmax_data, d, axis_dim); \ - platform::ForRange for_range(ctx, n* d); \ - if (ignore_idx >= 0 && ignore_idx < axis_dim) { \ - for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx( \ - labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ - } else { \ - for_range(HardLabelSoftmaxWithCrossEntropyFunctor( \ - labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ - } \ - } break -#endif - - switch (block_dim) { - CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512); - CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256); - CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128); - CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64); - CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32); - CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16); - CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); - CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); - CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); - default: - PADDLE_THROW(platform::errors::Unavailable( - "Block Dimension must be 2^n in softmax_with_cross_entropy_op.")); - break; - } -#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL -} - template static void SoftmaxWithCrossEntropyFusedKernel( const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data, @@ -783,7 +946,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { const int rank = softmax->dims().size(); const int axis = CanonicalAxis(context.Attr("axis"), rank); - int axis_dim = softmax->dims()[axis]; + const int axis_dim = softmax->dims()[axis]; const int n = SizeToAxis(axis, softmax->dims()); const int d = SizeFromAxis(axis, softmax->dims()); @@ -826,9 +989,19 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { } else { // HardLabel auto* logits_data = softmax->data(); auto* labels_data = labels->data(); - HardLabelCrossEntropy(context.cuda_device_context(), logits_data, - labels_data, loss_data, n, d, axis_dim, - ignore_index); + int threads = 128; + int blocks = (n * d / axis_dim + threads - 1) / threads; + if (ignore_index >= 0 && ignore_index < axis_dim) { + CrossEntropyHardLabel<<< + blocks, threads, 0, context.cuda_device_context().stream()>>>( + loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim, + ignore_index); + } else { + CrossEntropyHardLabel<<< + blocks, threads, 0, context.cuda_device_context().stream()>>>( + loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim, + ignore_index); + } } // cause of input is softmax @@ -886,9 +1059,17 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { } else { auto* logits_data = logits->data(); auto* labels_data = labels->data(); - HardLabelSoftmaxWithCrossEntropy( - context.cuda_device_context(), logits_data, labels_data, loss_data, - softmax_data, n, d, axis_dim, ignore_index); + if (ignore_index >= 0 && ignore_index < axis_dim) { + SoftmaxWithCrossEntropyHardLabel( + context.cuda_device_context(), rank, axis, logits_data, + labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim, + ignore_index); + } else { + SoftmaxWithCrossEntropyHardLabel( + context.cuda_device_context(), rank, axis, logits_data, + labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim, + ignore_index); + } } } } @@ -959,14 +1140,11 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { SoftCrossEntropyGradientKernel<<>>( logit_grad_data, loss_grad_data, label_data, n, d, remain); } else { - int64_t grid = (n * remain + block - 1) / block; const int64_t* label_data = labels->data(); - CrossEntropyGrad<<>>( - logit_grad_data, label_data, n, d, remain, ignore_index); - int64_t num = n * d; - grid = (num + block - 1) / block; - Scale<<>>(logit_grad_data, loss_grad_data, num, - d, remain, label_data, ignore_index); + int grid = (n * d + block - 1) / block; + SoftmaxWithCrossEntropyGradHardLabel<<>>( + logit_grad_data, loss_grad_data, label_data, n, d / remain, remain, + ignore_index); } } }; -- GitLab