From f895560264dc0afdaa957b60da8d7fdf6aa40f54 Mon Sep 17 00:00:00 2001 From: Feng Xing <79969986+xingfeng01@users.noreply.github.com> Date: Mon, 20 Dec 2021 16:41:37 +0800 Subject: [PATCH] optimize softmax with cross entropy soft label (#32387) softmax_with_cross_entropy optimization with soft label. This PR includes optimization of "SoftmaxWithCrossEntropySoftLabel" : compute log_softmax and then compute loss. "CrossEntropySoftLabel" : compute loss with softmax as input. These optimization includes following technics: read data to buffer with vectorization compute max and sum in warp fixed loop size with macro Performance (computation time): softmax_with_cross_entropy_0 (forward) : -40.1% softmax_with_cross_entropy_0 (backward): -41% --- .../softmax_with_cross_entropy_op.cu | 765 +++++++++--------- 1 file changed, 389 insertions(+), 376 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 520c95b6f34..b6f89c7af4a 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -20,6 +20,7 @@ namespace cub = hipcub; #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/softmax_cudnn_op.cu.h" #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/for_range.h" @@ -391,8 +392,8 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src, using AccT = typename details::MPTypeTrait::Type; // use 128 threads per block to maximimize gpu utilization - const int Log2Elements = static_cast(Log2Ceil(element_count)); - const int kDimCeil = 1 << Log2Elements; + const int log2_elements = static_cast(Log2Ceil(element_count)); + const int kDimCeil = 1 << log2_elements; int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; int batches_per_warp = (kDimCeil <= 128) ? 2 : 1; constexpr int threads_per_block = 128; @@ -401,7 +402,7 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src, int blocks = (batch_size + batches_per_block - 1) / batches_per_block; dim3 threads(kWarpSize, warps_per_block, 1); - switch (Log2Elements) { + switch (log2_elements) { SOFTMAX_WARP_FORWARD_CASE(0, T, AccT); SOFTMAX_WARP_FORWARD_CASE(1, T, AccT); SOFTMAX_WARP_FORWARD_CASE(2, T, AccT); @@ -494,6 +495,368 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel( } } +/* + Cross entropy soft label with dynamic size on axis (log2_elements is + varibale). + - if the input is softmax,compute loss with softmax + - if the input is log_softmax, compute loss with log_softmax and update + softmax +*/ +template +__global__ void CrossEntropySoftLabel(T* loss, T* softmaxwrt, const T* softmax, + const T* labels, const int n, + const int dim, const int d, + int log2_elements) { + const int kDimCeil = 1 << log2_elements; + const int kVSize = sizeof(VecT) / sizeof(T); + +#ifdef __HIPCC__ + const int kThreadPerBlock = 256; +#else + const int kThreadPerBlock = 512; +#endif + const int kBatchPerBlock = 1; + const int kWarpSize = 32; // (dim < 32) ? dim : 32; + const int kBatchSize = 1; + const int kThreadPerBatch = kThreadPerBlock / kBatchPerBlock; + const int kWarpPerBatch = kThreadPerBatch / kWarpSize; + + const int kIterations = (dim + kThreadPerBatch - 1) / kThreadPerBatch; + const int kIterationsV = (kIterations >= kVSize) ? (kIterations / kVSize) : 1; + + const int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; + + T sum[kBatchSize]{static_cast(0.0)}; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + int ids = first_batch + i; + if (ids >= n * d) break; + int idx_n = ids / d; + int idx_d = ids % d; +#pragma unroll + for (int it = 0; it < kIterations; ++it) { + int idx_dim = it * kThreadPerBatch + threadIdx.x; + int idx = idx_n * dim * d + idx_dim * d + idx_d; + + if (idx_n < n && idx_dim < dim) { + VecT softmaxdata; + if (InLogMode) { + softmaxdata = reinterpret_cast(&softmaxwrt[idx])[0]; + } else { + softmaxdata = reinterpret_cast(&softmax[idx])[0]; + } + VecT labelsdata = reinterpret_cast(&labels[idx])[0]; + T* softmaxptr = reinterpret_cast(&softmaxdata); + T* labelsptr = reinterpret_cast(&labelsdata); +#pragma unroll + for (int s = 0; s < kVSize; s++) { + if (InLogMode) { + sum[i] -= softmaxptr[s] * labelsptr[s]; + softmaxptr[s] = Exp(softmaxptr[s]); + } else { + sum[i] -= Log(softmaxptr[s]) * labelsptr[s]; + } + } + if (InLogMode) { + reinterpret_cast(&softmaxwrt[idx])[0] = softmaxdata; + } + } + } + } + WarpReduceSum(sum); + __syncthreads(); + + __shared__ T sumshare[kWarpPerBatch][kBatchPerBlock][kBatchSize]; + if (threadIdx.x % kWarpSize == 0) { +#pragma unroll + for (int i = 0; i < kBatchSize; i++) { + sumshare[threadIdx.x / kWarpSize][threadIdx.y][i] = sum[i]; + } + } + __syncthreads(); + + // write + if (threadIdx.x == 0) { + for (int i = 0; i < kBatchSize; i++) { + int ids = first_batch + i; + if (ids < n * d) { + loss[ids] = sumshare[0][threadIdx.y][i]; + for (int s = 1; s < kWarpPerBatch; s++) { + loss[ids] += sumshare[s][threadIdx.y][i]; + } + } + } + } +} + +/* +Core function of softmax with cross entropy forward soft label. +The computation includes + - Compute maximum of batch: maxvalue_{i} = max_j src_{i,j} + - Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } + - Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} - +log(sum[i]))} +One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). +For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle +api to compute max (sum) in one warp. +*/ +template +__global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src, + const T* label, + const int batch_size, + const int stride, + const int element_count) { + const bool LogMode = true; + + constexpr int kDimCeil = 1 << Log2Elements; + constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + constexpr int kVSize = sizeof(VecT) / sizeof(T); + constexpr int kIterations = kDimCeil / kWarpSize; + constexpr int kIterationsV = + (kIterations >= kVSize) ? (kIterations / kVSize) : 1; + constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; + int local_batches = batch_size - first_batch; + if (local_batches > kBatchSize) { + local_batches = kBatchSize; + } + + // read data from global memory + VecT srcdata[kBatchSize][kIterationsV]; + VecT labeldata[kBatchSize][kIterationsV]; + + for (int i = 0; i < kBatchSize; ++i) { + const VecT* src_v = + reinterpret_cast(&src[(first_batch + i) * stride]); + const VecT* label_v = + reinterpret_cast(&label[(first_batch + i) * stride]); + + // max index to read + int idx_max = (i < local_batches) ? element_count : 0; + int idx_max_v = idx_max / kVSize; + + // read data + for (int it = 0; it < kIterationsV; ++it) { + int src_idx = threadIdx.x + it * kWarpSize; + if (src_idx < idx_max_v) { + srcdata[i][it] = src_v[src_idx]; + labeldata[i][it] = label_v[src_idx]; + } else { +#pragma unroll + for (int s = 0; s < kVSize; s++) { + reinterpret_cast(&srcdata[i][it])[s] = + -std::numeric_limits::max(); + reinterpret_cast(&labeldata[i][it])[s] = 0.0; + } + } + } + } + + // compute max value + AccT max_value[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + max_value[i] = -std::numeric_limits::infinity(); +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + T* srcptr_v = reinterpret_cast(&srcdata[i][it]); + T valmax = srcptr_v[0]; +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + valmax = (valmax > srcptr_v[s]) ? valmax : srcptr_v[s]; + } + max_value[i] = (max_value[i] > static_cast(valmax)) + ? max_value[i] + : static_cast(valmax); + } + } + WarpReduceMax(max_value); + + // compute sum + AccT sum[kBatchSize]{0.0}; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + T* srcptr_v = reinterpret_cast(&srcdata[i][it]); +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + sum[i] += std::exp(static_cast(srcptr_v[s]) - max_value[i]); + } else { + srcptr_v[s] = std::exp(static_cast(srcptr_v[s]) - max_value[i]); + sum[i] += static_cast(srcptr_v[s]); + } + } + } + } + WarpReduceSum(sum); + + // log_softmax and loss + AccT sumloss[kBatchSize]{0.0}; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + if (i >= local_batches) break; + + VecT* softmax_v = + reinterpret_cast(&softmax[(first_batch + i) * stride]); + + // max index to write + int idx_max = (i < local_batches) ? element_count : 0; + int idx_max_v = idx_max / kVSize; + + if (LogMode) { + sum[i] = std::log(sum[i]); + } +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + T* srcvp = reinterpret_cast(&srcdata[i][it]); + T* labelvp = reinterpret_cast(&labeldata[i][it]); + VecT tmpv; + T* tmpvp = reinterpret_cast(&tmpv); +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + AccT logsoftmax = static_cast(srcvp[s]) - max_value[i] - sum[i]; + sumloss[i] -= logsoftmax * static_cast(labelvp[s]); + tmpvp[s] = std::exp(logsoftmax); + } else { + tmpvp[s] = static_cast(srcvp[s]) / sum[i]; + } + } + + int idx = threadIdx.x + it * kWarpSize; + if (idx < idx_max_v) { + softmax_v[idx] = tmpv; + } + } + } + + // loss + WarpReduceSum(sumloss); + + for (int i = 0; i < kBatchSize; i++) { + if (i >= local_batches) break; + loss[first_batch + i] = sumloss[i]; + } +} + +#define SOFTMAX_WARP_FORWARD_SOFT_CASE(Log2Elements, VecT, AccT) \ + case Log2Elements: \ + WarpSoftmaxForwardSoftLabel<<>>( \ + loss, softmax, src, label, batch_size, stride, element_count); \ + break; + +/* + Wrapper of softmax with cross entropy forward soft label. +*/ +template +void SwitchWarpSoftmaxForwardSoftLabel(const int blocks, const dim3 threads, + gpuStream_t stream, T* loss, T* softmax, + const T* src, const T* label, + const int batch_size, const int stride, + const int element_count, + const int log2_elements) { + using AccT = typename details::MPTypeTrait::Type; + switch (log2_elements) { + SOFTMAX_WARP_FORWARD_SOFT_CASE(0, T, AccT); + SOFTMAX_WARP_FORWARD_SOFT_CASE(1, T, AccT); + SOFTMAX_WARP_FORWARD_SOFT_CASE(2, T, AccT); + SOFTMAX_WARP_FORWARD_SOFT_CASE(3, T, AccT); + SOFTMAX_WARP_FORWARD_SOFT_CASE(4, T, AccT); + SOFTMAX_WARP_FORWARD_SOFT_CASE(5, T, AccT); + SOFTMAX_WARP_FORWARD_SOFT_CASE(6, T, AccT); + SOFTMAX_WARP_FORWARD_SOFT_CASE(7, T, AccT); + SOFTMAX_WARP_FORWARD_SOFT_CASE(8, T, AccT); + SOFTMAX_WARP_FORWARD_SOFT_CASE(9, T, AccT); + default: + break; + } +} + +template +static void SoftmaxWithCrossEntropySoftLabel( + const platform::CUDADeviceContext& ctx, const int rank, const int axis, + const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data, + int N, int dim, int D) { +#ifdef __HIPCC__ + constexpr int kMaxBlockDim = 256; +#else + constexpr int kMaxBlockDim = 512; +#endif + int64_t block_dim = dim >= kMaxBlockDim + ? kMaxBlockDim + : (1 << static_cast(std::log2(dim))); + + int64_t grid_dim = N * D; + constexpr int max_dim = 320; + + const int kDimLog2 = static_cast(Log2Ceil(dim)); + const int kDimCeil = 1 << kDimLog2; + auto stream = ctx.stream(); + + if (D == 1 && dim <= max_dim) { + int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + int batches_per_warp = (kDimCeil <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / kWarpSize); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (N + batches_per_block - 1) / batches_per_block; + dim3 threads(kWarpSize, warps_per_block, 1); + + SwitchWarpSoftmaxForwardSoftLabel(blocks, threads, stream, loss_data, + softmax_data, logits_data, labels_data, + N, dim, dim, kDimLog2); + + } else { + ScopedTensorDescriptor desc; + std::vector tensor_dims = {N, dim, D, 1}; + DataLayout layout = DataLayout::kNCHW; +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); +#else + cudnnTensorDescriptor_t descp = desc.descriptor(layout, tensor_dims); +#endif + + auto handle = ctx.cudnn_handle(); + +#ifdef PADDLE_WITH_HIP + auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE + : MIOPEN_SOFTMAX_MODE_CHANNEL; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( + handle, platform::CudnnDataType::kOne(), descp, logits_data, + platform::CudnnDataType::kZero(), descp, softmax_data, + MIOPEN_SOFTMAX_LOG, mode)); +#else + auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE + : CUDNN_SOFTMAX_MODE_CHANNEL; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward( + handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), + descp, logits_data, platform::CudnnDataType::kZero(), descp, + softmax_data)); +#endif + + const int kDimLog2 = static_cast(Log2Ceil(dim)); + const int kDimCeil = 1 << kDimLog2; +#ifdef __HIPCC__ + int kThreadPerBlock = 256; +#else + int kThreadPerBlock = 512; +#endif + + int kBatchPerBlock = 1; + int blocks = (N * D + kBatchPerBlock - 1) / kBatchPerBlock; + dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1); + + CrossEntropySoftLabel<<>>( + loss_data, softmax_data, NULL, labels_data, N, dim, D, kDimLog2); + } +} + template __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, const T* loss_grad, @@ -560,373 +923,6 @@ __global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad, } } -static __device__ __forceinline__ platform::float16 exp_on_device( - platform::float16 x) { - return ::Eigen::numext::exp(x); -} -static __device__ __forceinline__ float exp_on_device(float x) { - return expf(x); -} -static __device__ __forceinline__ double exp_on_device(double x) { - return exp(x); -} -static __device__ __forceinline__ platform::float16 log_on_device( - platform::float16 x) { - return math::TolerableValue()(::Eigen::numext::log(x)); -} -static __device__ __forceinline__ float log_on_device(float x) { - return math::TolerableValue()(logf(x)); -} -static __device__ __forceinline__ double log_on_device(double x) { - return math::TolerableValue()(log(x)); -} - -/** In the following codes, 3 CUDA kernels are implemented to calculate softmax - * and loss **/ -/* - Supposing the x is `logits` and y is `labels`, the equations are as -followings: - cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})] - = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})] - = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})] - = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)] - = \sum_{j}(-y_i_j * tmp_i_j) - softmax_i_j = e^{tmp_i_j} -where: - max_i = \max_{j}{x_i_j} - logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i} - tmp_i_j = x_i_j - max_i - logDiffMaxSum_i -Therefore, the calculation can be separated into 3 steps: -Step 1: row-wise operation to calculate max_i -Step 2: row-wise operation to calculate logDiffMaxSum_i -Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i -To save memory, we can share memory among max_i, logDiffMaxSum_i and -cross\_entropy_i. -In this way, the 3 steps should be changed to: -Step 1 (RowReductionForMax): row-wise operation to calculate max_i -Step 2 (RowReductionForDiffMaxSum): calculate immediate result of softmax'_i_j = -x_i_j - max_i, and row-wise operation to calculate logDiffMaxSum_i -Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j = softmax'_i_j -- logDiffMaxSum_i, and finally get softmax_i_j and cross\_entropy_i -*/ - -// There are 3 kinds of reduce algorithms in cub: -// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY -// BLOCK_REDUCE_RAKING -// BLOCK_REDUCE_WARP_REDUCTIONS (default) -template -using BlockReduce = - cub::BlockReduce; - -template -using BlockReduceTempStorage = typename BlockReduce::TempStorage; - -// Make sure that BlockDim <= axis_dim -// This kernel is used to calculate the max element of each row -template -static __global__ void RowReductionForMax(const T* logits_data, T* max_data, - int64_t d, int axis_dim) { - __shared__ BlockReduceTempStorage temp_storage; - - // logits_data view as [n, axis_dim, remain] - // max_data view as [n, 1, remain] - // blockDim = n * remain, split blockIdx to idx_n and idx_remain - int64_t remain = d / axis_dim; - int64_t idx_n = blockIdx.x / remain; - int64_t idx_remain = blockIdx.x % remain; - int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; - int64_t end_idx = (idx_n + 1) * d; - - int64_t step = BlockDim * remain; - T cur_max = logits_data[beg_idx]; - beg_idx += step; - while (beg_idx < end_idx) { - if (cur_max < logits_data[beg_idx]) { - cur_max = logits_data[beg_idx]; - } - beg_idx += step; - } - - cur_max = BlockReduce(temp_storage).Reduce(cur_max, cub::Max()); - - if (threadIdx.x == 0) max_data[blockIdx.x] = cur_max; -} - -// Make sure that BlockDim <= axis_dim -template -static __global__ void RowReductionForDiffMaxSum(const T* logits_data, - T* max_data, T* softmax, - int64_t d, int axis_dim) { - __shared__ BlockReduceTempStorage temp_storage; - - // logits, softmax data view as [n, axis_dim, remain] - // max_data view as [n, 1, remain] - // blockDim = n * remain, split blockIdx to idx_n and idx_remain - int64_t remain = d / axis_dim; - int64_t idx_n = blockIdx.x / remain; - int64_t idx_remain = blockIdx.x % remain; - int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; - int64_t end_idx = (idx_n + 1) * d; - - auto block_max = max_data[blockIdx.x]; - int64_t step = BlockDim * remain; - - // In numeric stable mode softmax_with_loss, we calc loss with - // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of - // log(exp(x_i_j - max_i)/DiffMaxSum_i). Therefore, log(0) will not occur. - // Also we calc softmax_i_j = e^{tmp_i_j}, the maximum and minimum value will - // be 1.0 and 0.0, represent prob is 1.0 and 0.0. - // So there is no need to clip on shift_softmax. - softmax[beg_idx] = logits_data[beg_idx] - block_max; - T diff_max_sum = exp_on_device(softmax[beg_idx]); - auto idx = beg_idx + step; - while (idx < end_idx) { - softmax[idx] = logits_data[idx] - block_max; - diff_max_sum += exp_on_device(softmax[idx]); - idx += step; - } - - diff_max_sum = - BlockReduce(temp_storage).Reduce(diff_max_sum, cub::Sum()); - if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum); - - if (!CalculateLogSoftmax) return; - __syncthreads(); - diff_max_sum = max_data[blockIdx.x]; - softmax[beg_idx] -= diff_max_sum; - beg_idx += step; - while (beg_idx < end_idx) { - softmax[beg_idx] -= diff_max_sum; - beg_idx += step; - } - - // Note(zhiqiu): since different threads may use max_data[blockIdx.x] to - // calculate diff_max_sum, __syncthreads() is needed here. - __syncthreads(); - if (threadIdx.x == 0) max_data[blockIdx.x] = 0; -} - -#ifdef __HIPCC__ // @{ HIP Seperate Kernel for RowReductionForDiffMaxSum -// Note(qili93): HIP do not support return in kernel, need to seperate -// RowReductionForDiffMaxSum into two kernels below -template -static __global__ void RowReductionForSum(const T* logits_data, T* max_data, - T* softmax, int64_t d, int axis_dim) { - __shared__ BlockReduceTempStorage temp_storage; - - int64_t remain = d / axis_dim; - int64_t idx_n = blockIdx.x / remain; - int64_t idx_remain = blockIdx.x % remain; - int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; - int64_t end_idx = (idx_n + 1) * d; - - auto block_max = max_data[blockIdx.x]; - int64_t step = BlockDim * remain; - - softmax[beg_idx] = logits_data[beg_idx] - block_max; - T diff_max_sum = exp_on_device(softmax[beg_idx]); - auto idx = beg_idx + step; - while (idx < end_idx) { - softmax[idx] = logits_data[idx] - block_max; - diff_max_sum += exp_on_device(softmax[idx]); - idx += step; - } - - diff_max_sum = - BlockReduce(temp_storage).Reduce(diff_max_sum, cub::Sum()); - if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum); -} - -template -static __global__ void RowReductionForDiff(const T* logits_data, T* max_data, - T* softmax, int d, int axis_dim) { - int remain = d / axis_dim; - int idx_n = blockIdx.x / remain; - int idx_remain = blockIdx.x % remain; - int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; - int end_idx = (idx_n + 1) * d; - int step = BlockDim * remain; - - T diff_max_sum = max_data[blockIdx.x]; - softmax[beg_idx] -= diff_max_sum; - beg_idx += step; - while (beg_idx < end_idx) { - softmax[beg_idx] -= diff_max_sum; - beg_idx += step; - } - - __syncthreads(); - if (threadIdx.x == 0) max_data[blockIdx.x] = 0; -} -#endif // @} End HIP Seperate Kernel for RowReductionForDiffMaxSum - -// Make sure that BlockDim <= axis_dim -template -static __global__ void RowReductionForSoftmaxAndCrossEntropy( - const T* logits_data, const T* labels_data, T* loss_data, T* softmax, - int64_t d, int axis_dim) { - __shared__ BlockReduceTempStorage temp_storage; - - // logits, softmax, labels data view as [n, axis_dim, remain] - // loss_data view as [n, 1, remain] - // blockDim = n * remain, split blockIdx to idx_n and idx_remain - int64_t remain = d / axis_dim; - int64_t idx_n = blockIdx.x / remain; - int64_t idx_remain = blockIdx.x % remain; - int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; - int64_t end_idx = (idx_n + 1) * d; - - // log_diff_max_sum shares memory with loss - auto block_log_diff_max_sum = loss_data[blockIdx.x]; - auto tmp = softmax[beg_idx] - block_log_diff_max_sum; - softmax[beg_idx] = exp_on_device(tmp); - auto loss = -labels_data[beg_idx] * tmp; - int64_t step = BlockDim * remain; - beg_idx += step; - while (beg_idx < end_idx) { - tmp = softmax[beg_idx] - block_log_diff_max_sum; - softmax[beg_idx] = exp_on_device(tmp); - loss -= (labels_data[beg_idx] * tmp); - beg_idx += step; - } - - loss = BlockReduce(temp_storage).Reduce(loss, cub::Sum()); - if (threadIdx.x == 0) loss_data[blockIdx.x] = loss; -} - -// Make sure that BlockDim <= axis_dim -template -static __global__ void RowReductionForCrossEntropy(const T* logits_data, - const T* labels_data, - T* loss_data, int d, - int axis_dim) { - __shared__ BlockReduceTempStorage temp_storage; - - // logits, softmax, labels data view as [n, axis_dim, remain] - // loss_data view as [n, 1, remain] - // blockDim = n * remain, split blockIdx to idx_n and idx_remain - int remain = d / axis_dim; - int idx_n = blockIdx.x / remain; - int idx_remain = blockIdx.x % remain; - int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; - int end_idx = (idx_n + 1) * d; - - // log_diff_max_sum shares memory with loss - auto block_log_diff_max_sum = loss_data[blockIdx.x]; - auto tmp = log_on_device(logits_data[beg_idx]); // when not with softmax, - // softmax is stored in - // logits_data - auto loss = -labels_data[beg_idx] * tmp; - int step = BlockDim * remain; - beg_idx += step; - while (beg_idx < end_idx) { - tmp = log_on_device(logits_data[beg_idx]); // when not with softmax, - // softmax is stored in - // logits_data - loss -= (labels_data[beg_idx] * tmp); - beg_idx += step; - } - - loss = BlockReduce(temp_storage).Reduce(loss, cub::Sum()); - if (threadIdx.x == 0) loss_data[blockIdx.x] = loss; -} - -template -static void SoftmaxWithCrossEntropyFusedKernel( - const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data, - int64_t n, int64_t d, int axis_dim, gpuStream_t stream) { -#ifdef __HIPCC__ - constexpr int kMaxBlockDim = 256; -#else - constexpr int kMaxBlockDim = 512; -#endif - int64_t block_dim = axis_dim >= kMaxBlockDim - ? kMaxBlockDim - : (1 << static_cast(std::log2(axis_dim))); - int64_t grid_dim = n * d / axis_dim; -#ifdef __HIPCC__ -#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ - case BlockDim: \ - hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax), \ - dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \ - loss_data, d, axis_dim); \ - hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum), \ - dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \ - loss_data, softmax_data, d, axis_dim); \ - hipLaunchKernelGGL( \ - HIP_KERNEL_NAME(RowReductionForSoftmaxAndCrossEntropy), \ - dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, labels_data, \ - loss_data, softmax_data, d, axis_dim); \ - break -#else -#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ - case BlockDim: \ - RowReductionForMax<<>>( \ - logits_data, loss_data, d, axis_dim); \ - RowReductionForDiffMaxSum<<>>( \ - logits_data, loss_data, softmax_data, d, axis_dim); \ - RowReductionForSoftmaxAndCrossEntropy< \ - T, BlockDim><<>>( \ - logits_data, labels_data, loss_data, softmax_data, d, axis_dim); \ - break -#endif - - switch (block_dim) { - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); - default: - PADDLE_THROW(platform::errors::Unavailable( - "Block Dimension must be 2^n in softmax_with_cross_entropy_op.")); - break; - } - -#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL -} - -// not with softmax -template -static void CrossEntropyFusedKernel(const T* logits_data, const T* labels_data, - T* loss_data, int n, int d, int axis_dim, - gpuStream_t stream) { - constexpr int kMaxBlockDim = 512; - int block_dim = axis_dim >= kMaxBlockDim - ? kMaxBlockDim - : (1 << static_cast(std::log2(axis_dim))); - int grid_dim = n * d / axis_dim; - -#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ - case BlockDim: \ - RowReductionForCrossEntropy<<>>( \ - logits_data, labels_data, loss_data, d, axis_dim); \ - break - - switch (block_dim) { - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); - CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); - default: - PADDLE_THROW(platform::errors::Unavailable( - "Block Dimension must be 2^n in softmax_with_cross_entropy_op.")); - break; - } - -#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL -} - template class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { public: @@ -983,9 +979,22 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { if (soft_label) { auto* logits_data = softmax->data(); auto* labels_data = labels->data(); - CrossEntropyFusedKernel(logits_data, labels_data, loss_data, n, d, - axis_dim, - context.cuda_device_context().stream()); + + const int kDimLog2 = static_cast(Log2Ceil(axis_dim)); + const int kDimCeil = 1 << kDimLog2; +#ifdef __HIPCC__ + int kThreadPerBlock = 256; +#else + int kThreadPerBlock = 512; +#endif + int kBatchPerBlock = 1; + int blocks = (n * d + kBatchPerBlock - 1) / kBatchPerBlock; + dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1); + + CrossEntropySoftLabel<<< + blocks, threads, 0, context.cuda_device_context().stream()>>>( + loss_data, NULL, logits_data, labels_data, n, axis_dim, + d / axis_dim, kDimLog2); } else { // HardLabel auto* logits_data = softmax->data(); auto* labels_data = labels->data(); @@ -1040,9 +1049,9 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { if (soft_label) { auto* logits_data = logits->data(); auto* labels_data = labels->data(); - SoftmaxWithCrossEntropyFusedKernel( - logits_data, labels_data, softmax_data, loss_data, n, d, axis_dim, - context.cuda_device_context().stream()); + SoftmaxWithCrossEntropySoftLabel( + context.cuda_device_context(), rank, axis, logits_data, labels_data, + softmax_data, loss_data, n, axis_dim, d / axis_dim); } else { if (!context.Attr("numeric_stable_mode")) { // CUDNN kernel only suppoer 2-D tensor and perfome softmax on last dim @@ -1103,7 +1112,11 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { const int64_t d = SizeFromAxis(axis, logit_grad->dims()); const int64_t remain = d / axis_dim; +#ifdef __HIPCC__ + int block = 256; +#else int block = 512; +#endif auto stream = context.cuda_device_context().stream(); auto ignore_index = context.Attr("ignore_index"); auto use_softmax = context.Attr("use_softmax"); -- GitLab