From 1b4515f6dbd4a273c404a0c6668e07105bc35a34 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 6 Aug 2018 09:31:31 +0000 Subject: [PATCH] refine softmax_with_cross_entropy --- .../softmax_with_cross_entropy_op.cu | 218 +++++++++++++++++- 1 file changed, 209 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 8f7840cee1..a559b01ed3 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ limitations under the License. */ #define EIGEN_USE_GPU +#include +#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" namespace paddle { @@ -53,8 +55,196 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, logit_grad[ids] = loss_grad[row_ids] * (logit_grad[ids] - labels[ids]); } } + } // namespace +static __device__ __forceinline__ float real_exp(float x) { return expf(x); } +static __device__ __forceinline__ double real_exp(double x) { return exp(x); } +static __device__ __forceinline__ float real_log(float x) { + return math::TolerableValue()(logf(x)); +} +static __device__ __forceinline__ double real_log(double x) { + return math::TolerableValue()(log(x)); +} + +/** In the following codes, 3 CUDA kernels are implemented to calculate softmax + * and loss **/ +/* + Supposing the x is `logits` and y is `labels`, the equations are as +followings: + + cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})] + = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})] + = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})] + = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)] + = \sum_{j}(-y_i_j * tmp_i_j) + + softmax_i_j = e^{tmp_i_j} + +where: + max_i = \max_{j}{x_i_j} + logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i} + tmp_i_j = x_i_j - max_i - logDiffMaxSum_i + +Therefore, the calculation can be separated into 3 steps: +Step 1: row-wise operation to calculate max_i +Step 2: row-wise operation to calculate logDiffMaxSum_i +Step 3: caculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i + +To save memory, we can share memory among max_i, logDiffMaxSum_i and +cross\_entropy_i. +In this way, the 3 steps should be changed to: +Step 1 (RowReductionForMax): row-wise operation to calculate max_i +Step 2 (RowReductionForDiffMaxSum): calculate immediate result of softmax'_i_j = +x_i_j - max_i, and row-wise operation to calculate logDiffMaxSum_i +Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j = softmax'_i_j +- logDiffMaxSum_i, and finally get softmax_i_j and cross\_entropy_i +*/ + +// There are 3 kinds of reduce algorithms in cub: +// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY +// BLOCK_REDUCE_RAKING +// BLOCK_REDUCE_WARP_REDUCTIONS (default) +template +using BlockReduce = + cub::BlockReduce; + +template +using BlockReduceTempStorage = typename BlockReduce::TempStorage; + +// Make sure that BlockDim <= feature_size +// This kernel is used to calculate the max element of each row +template +__global__ void RowReductionForMax(const T* logits_data, T* max_data, + int feature_size) { + __shared__ BlockReduceTempStorage temp_storage; + + auto beg_idx = feature_size * blockIdx.x + threadIdx.x; + auto end_idx = feature_size * (blockIdx.x + 1); + + T cur_max = logits_data[beg_idx]; + beg_idx += BlockDim; + while (beg_idx < end_idx) { + if (cur_max < logits_data[beg_idx]) { + cur_max = logits_data[beg_idx]; + } + beg_idx += BlockDim; + } + + cur_max = BlockReduce(temp_storage).Reduce(cur_max, cub::Max()); + + if (threadIdx.x == 0) { + max_data[blockIdx.x] = cur_max < -64 ? -64 : cur_max; + } +} + +// Make sure that BlockDim <= feature_size +template +__global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data, + T* softmax, int feature_size) { + __shared__ BlockReduceTempStorage temp_storage; + + auto beg_idx = feature_size * blockIdx.x + threadIdx.x; + auto end_idx = feature_size * (blockIdx.x + 1); + + auto block_max = max_data[blockIdx.x]; + + softmax[beg_idx] = logits_data[beg_idx] - block_max; + T diff_max_sum = real_exp(softmax[beg_idx]); + beg_idx += BlockDim; + while (beg_idx < end_idx) { + softmax[beg_idx] = logits_data[beg_idx] - block_max; + diff_max_sum += real_exp(softmax[beg_idx]); + beg_idx += BlockDim; + } + + diff_max_sum = + BlockReduce(temp_storage).Reduce(diff_max_sum, cub::Sum()); + if (threadIdx.x == 0) max_data[blockIdx.x] = real_log(diff_max_sum); +} + +// Make sure that BlockDim <= feature_size +template +__global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data, + const T* labels_data, + T* loss_data, T* softmax, + int feature_size) { + __shared__ BlockReduceTempStorage temp_storage; + + auto beg_idx = feature_size * blockIdx.x + threadIdx.x; + auto end_idx = feature_size * (blockIdx.x + 1); + + // log_diff_max_sum shares memory with loss + auto block_log_diff_max_sum = loss_data[blockIdx.x]; + auto tmp = softmax[beg_idx] - block_log_diff_max_sum; + softmax[beg_idx] = real_exp(tmp); + auto loss = -labels_data[beg_idx] * tmp; + beg_idx += BlockDim; + while (beg_idx < end_idx) { + tmp = softmax[beg_idx] - block_log_diff_max_sum; + softmax[beg_idx] = real_exp(tmp); + loss -= (labels_data[beg_idx] * tmp); + beg_idx += BlockDim; + } + + loss = BlockReduce(temp_storage).Reduce(loss, cub::Sum()); + if (threadIdx.x == 0) loss_data[blockIdx.x] = loss; +} + +template +__global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int batch_size) { + auto idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < batch_size) out[idx] = static_cast(1); +} + +template +static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data, + const T* labels_data, + T* softmax_data, T* loss_data, + int batch_size, int feature_size, + cudaStream_t stream) { + constexpr int kMaxBlockDim = 512; + int block_dim = feature_size >= kMaxBlockDim + ? kMaxBlockDim + : (1 << static_cast(std::log2(feature_size))); + +#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ + case BlockDim: \ + RowReductionForMax<<>>( \ + logits_data, loss_data, feature_size); \ + RowReductionForDiffMaxSum<<>>( \ + logits_data, loss_data, softmax_data, feature_size); \ + RowReductionForSoftmaxAndCrossEntropy< \ + T, BlockDim><<>>( \ + logits_data, labels_data, loss_data, softmax_data, feature_size); \ + break + + switch (block_dim) { + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); + case 1: + SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) / + kMaxBlockDim, + kMaxBlockDim, 0, stream>>>( + softmax_data, batch_size); + cudaMemsetAsync(loss_data, 0, batch_size, stream); + break; + default: + PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op"); + break; + } + +#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL +} + template class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { public: @@ -66,14 +256,24 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { Tensor* softmax = context.Output("Softmax"); Tensor* loss = context.Output("Loss"); - softmax->mutable_data(context.GetPlace()); - loss->mutable_data(context.GetPlace()); - - math::SoftmaxFunctor()( - context.cuda_device_context(), logits, softmax); - math::CrossEntropyFunctor()( - context.cuda_device_context(), loss, softmax, labels, - context.Attr("soft_label")); + auto* softmax_data = softmax->mutable_data(context.GetPlace()); + auto* loss_data = loss->mutable_data(context.GetPlace()); + + auto soft_label = context.Attr("soft_label"); + if (soft_label) { + int batch_size = logits->dims()[0]; + int feature_size = logits->dims()[1]; + auto* logits_data = logits->data(); + auto* labels_data = labels->data(); + SoftmaxWithCrossEntropyFusedKernel( + logits_data, labels_data, softmax_data, loss_data, batch_size, + feature_size, context.cuda_device_context().stream()); + } else { + math::SoftmaxCUDNNFunctor()(context.cuda_device_context(), logits, + softmax); + math::CrossEntropyFunctor()( + context.cuda_device_context(), loss, softmax, labels, false); + } } }; -- GitLab