/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

namespace {

// Hard-label gradient, step 1.
// For every row i of the row-major [batch_size, class_num] buffer
// `logit_grad`, subtract 1 from the entry in column labels[i].
// Grid-stride loop over rows, so any 1-D grid/block configuration is valid.
// Assumes every labels[i] lies in [0, class_num); this is not checked here.
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int batch_size, const int class_num) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
       i += blockDim.x * gridDim.x) {
    int idx = i * class_num + labels[i];
    logit_grad[idx] -= static_cast<T>(1.);
  }
}

// Hard-label gradient, step 2.
// Multiply each of the `num` elements of `logit_grad` by the upstream loss
// gradient of the row it belongs to (row = i / class_num).
// Grid-stride loop over elements.
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int class_num) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    logit_grad[i] *= loss_grad[i / class_num];
  }
}

// Soft-label gradient:
//   logit_grad[r][c] = loss_grad[r] * (logit_grad[r][c] - labels[r][c])
// One thread per element, bounds-checked against batch_size * class_num.
// On entry `logit_grad` is expected to hold the softmax output; the gradient
// kernel class in this file shares its storage with the "Softmax" tensor.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels,
                                               const int batch_size,
                                               const int class_num) {
  int ids = blockIdx.x * blockDim.x + threadIdx.x;
  if (ids < batch_size * class_num) {
    int row_ids = ids / class_num;
    logit_grad[ids] = loss_grad[row_ids] * (logit_grad[ids] - labels[ids]);
  }
}

}  // namespace

// Overloads that let the templated kernels below pick the float or double
// math routine. real_log additionally passes its result through
// math::TolerableValue<T> (presumably to tame non-finite values such as
// log(0); see cross_entropy.h to confirm).
static __device__ __forceinline__ float real_exp(float x) { return expf(x); }
static __device__ __forceinline__ double real_exp(double x) { return exp(x); }
static __device__ __forceinline__ float real_log(float x) {
  return math::TolerableValue<float>()(logf(x));
}
static __device__ __forceinline__ double real_log(double x) {
  return math::TolerableValue<double>()(log(x));
}

/** In the following code, 3 CUDA kernels are implemented to calculate softmax
 * and loss **/
/*
  Supposing the x is `logits` and y is `labels`, the equations are as follows:

  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
        = \sum_{j}(-y_i_j * tmp_i_j)

  softmax_i_j = e^{tmp_i_j}

  where:
    max_i = \max_{j}{x_i_j}
    logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
    tmp_i_j = x_i_j - max_i - logDiffMaxSum_i

  Therefore, the calculation can be separated into 3 steps:
  Step 1: row-wise operation to calculate max_i
  Step 2: row-wise operation to calculate logDiffMaxSum_i
  Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i

  To save memory, we can share memory among max_i, logDiffMaxSum_i and
  cross\_entropy_i.
  In this way, the 3 steps should be changed to:
  Step 1 (RowReductionForMax): row-wise operation to calculate max_i
  Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
    softmax'_i_j = x_i_j - max_i, and row-wise operation to calculate
    logDiffMaxSum_i
  Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate
    tmp_i_j = softmax'_i_j - logDiffMaxSum_i, and finally get softmax_i_j and
    cross\_entropy_i
*/

// There are 3 kinds of reduce algorithms in cub:
// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
// BLOCK_REDUCE_RAKING
// BLOCK_REDUCE_WARP_REDUCTIONS (default)
template <typename T, int BlockDim>
using BlockReduce = cub::BlockReduce<T, BlockDim>;

template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;

// Step 1: writes max_i of each row into max_data[row].
// Launch layout: one block of BlockDim threads per row (grid = batch_size);
// the cub temp storage lives in static shared memory.
// Precondition: BlockDim <= feature_size, because every thread reads its
// first element unconditionally.
template <typename T, int BlockDim>
__global__ void RowReductionForMax(const T* logits_data, T* max_data,
                                   int feature_size) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
  auto end_idx = feature_size * (blockIdx.x + 1);

  // Each thread scans a BlockDim-strided slice of the row.
  T cur_max = logits_data[beg_idx];
  beg_idx += BlockDim;
  while (beg_idx < end_idx) {
    if (cur_max < logits_data[beg_idx]) {
      cur_max = logits_data[beg_idx];
    }
    beg_idx += BlockDim;
  }

  cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());

  // NOTE(review): the stored max is clamped from below at -64. For rows whose
  // logits are all far below -64, x - max becomes very negative and exp() may
  // underflow in step 2; confirm this clamp is intended.
  if (threadIdx.x == 0) {
    max_data[blockIdx.x] = cur_max < -64 ? -64 : cur_max;
  }
}

// Step 2: stores softmax'_i_j = x_i_j - max_i into `softmax`. It then
// overwrites max_data[row] (which held max_i on entry) with logDiffMaxSum_i.
// Same launch layout and BlockDim <= feature_size precondition as step 1.
template <typename T, int BlockDim>
__global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
                                          T* softmax, int feature_size) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
  auto end_idx = feature_size * (blockIdx.x + 1);

  // Every thread reads max_i before the block reduction; thread 0 only
  // overwrites it after the reduction.
  auto block_max = max_data[blockIdx.x];

  softmax[beg_idx] = logits_data[beg_idx] - block_max;
  T diff_max_sum = real_exp(softmax[beg_idx]);
  beg_idx += BlockDim;
  while (beg_idx < end_idx) {
    softmax[beg_idx] = logits_data[beg_idx] - block_max;
    diff_max_sum += real_exp(softmax[beg_idx]);
    beg_idx += BlockDim;
  }

  diff_max_sum =
      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
  if (threadIdx.x == 0) max_data[blockIdx.x] = real_log(diff_max_sum);
}

// Step 3: turns softmax'_i_j into softmax_i_j = exp(softmax'_i_j -
// logDiffMaxSum_i). It then overwrites loss_data[row] (which held
// logDiffMaxSum_i on entry) with cross_entropy_i = sum_j(-y_i_j * tmp_i_j).
// Same launch layout and BlockDim <= feature_size precondition as step 1.
template <typename T, int BlockDim>
__global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
                                                      const T* labels_data,
                                                      T* loss_data, T* softmax,
                                                      int feature_size) {
  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

  auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
  auto end_idx = feature_size * (blockIdx.x + 1);

  // log_diff_max_sum shares memory with loss
  auto block_log_diff_max_sum = loss_data[blockIdx.x];
  auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
  softmax[beg_idx] = real_exp(tmp);
  auto loss = -labels_data[beg_idx] * tmp;
  beg_idx += BlockDim;
  while (beg_idx < end_idx) {
    tmp = softmax[beg_idx] - block_log_diff_max_sum;
    softmax[beg_idx] = real_exp(tmp);
    loss -= (labels_data[beg_idx] * tmp);
    beg_idx += BlockDim;
  }

  loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
  if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}

// Writes 1 into out[0, batch_size). This is the softmax of a single-column
// input. One thread per element, bounds-checked.
template <typename T>
__global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int batch_size) {
  auto idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < batch_size) out[idx] = static_cast<T>(1);
}

// Host-side launcher for the soft-label forward pass on a row-major
// [batch_size, feature_size] input. All work is enqueued asynchronously on
// `stream`.
// The three row-reduction kernels are launched with one block of BlockDim
// threads per row (grid = batch_size). BlockDim is the largest power of two
// that is <= min(feature_size, kMaxBlockDim), which satisfies the kernels'
// BlockDim <= feature_size precondition.
// loss_data doubles as per-row scratch (max_i, then logDiffMaxSum_i) before it
// receives the final loss.
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                               const T* labels_data,
                                               T* softmax_data, T* loss_data,
                                               int batch_size, int feature_size,
                                               cudaStream_t stream) {
  // An empty batch needs no work. Launching a zero-block grid would also be an
  // invalid CUDA launch configuration.
  if (batch_size <= 0) return;
  PADDLE_ENFORCE(feature_size > 0,
                 "feature_size must be positive in "
                 "softmax_with_cross_entropy_op, but got %d",
                 feature_size);

  constexpr int kMaxBlockDim = 512;
  // Exact integer computation of the largest power of two <= feature_size,
  // capped at kMaxBlockDim.
  int block_dim = 1;
  while (block_dim < kMaxBlockDim && block_dim * 2 <= feature_size) {
    block_dim *= 2;
  }

#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)                \
  case BlockDim:                                                              \
    RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(     \
        logits_data, loss_data, feature_size);                                \
    RowReductionForDiffMaxSum<T,                                              \
                              BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
        logits_data, loss_data, softmax_data, feature_size);                  \
    RowReductionForSoftmaxAndCrossEntropy<                                    \
        T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(                    \
        logits_data, labels_data, loss_data, softmax_data, feature_size);     \
    break

  switch (block_dim) {
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
    case 1:
      // A single-class softmax is identically 1, and the loss -y * log(1)
      // is 0.
      SetSoftmaxToOneWhenFeatureSizeIsOne<T><<<
          (batch_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
          stream>>>(softmax_data, batch_size);
      // cudaMemsetAsync takes a byte count, not an element count. Passing
      // batch_size alone left most of the loss buffer uninitialized.
      PADDLE_ENFORCE(
          cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream),
          "cudaMemsetAsync failed in softmax_with_cross_entropy_op");
      break;
    default:
      // Defensive only: block_dim is always a power of two in [1, 512].
      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
      break;
  }

#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}

// Forward CUDA kernel: computes "Softmax" and "Loss" from "Logits" and
// "Label".
// With soft_label, the fused launcher above computes both outputs from
// [batch_size, feature_size] logits and same-shaped label probabilities.
// Otherwise, softmax is computed with the cuDNN functor and the loss with
// math::CrossEntropyFunctor in hard-label mode.
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");

    Tensor* loss = context.Output<Tensor>("Loss");
    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
    auto* loss_data = loss->mutable_data<T>(context.GetPlace());

    auto soft_label = context.Attr<bool>("soft_label");
    if (soft_label) {
      int batch_size = logits->dims()[0];
      int feature_size = logits->dims()[1];
      auto* logits_data = logits->data<T>();
      auto* labels_data = labels->data<T>();
      SoftmaxWithCrossEntropyFusedKernel(
          logits_data, labels_data, softmax_data, loss_data, batch_size,
          feature_size, context.cuda_device_context().stream());
    } else {
      math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
                                     softmax);
      math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
          context.cuda_device_context(), loss, softmax, labels, false);
    }
  }
};

// Backward CUDA kernel: computes the gradient w.r.t. "Logits" from the forward
// "Softmax" output and the upstream gradient of "Loss".
// The Logits gradient is made to share its buffer with "Softmax"
// (ShareDataWith), so the kernels below update that storage in place:
//   soft label: grad = dLoss[row] * (softmax - label)
//   hard label: grad = dLoss[row] * (softmax - one_hot(label))
// All kernels are enqueued on the device context's stream.
template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
    T* logit_grad_data = logit_grad->data<T>();

    const int batch_size = logit_grad->dims()[0];
    const int class_num = logit_grad->dims()[1];
    int block = 512;
    auto stream = context.cuda_device_context().stream();

    if (context.Attr<bool>("soft_label")) {
      // One thread per element.
      int grid = (batch_size * class_num + block - 1) / block;
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, batch_size, class_num);
    } else {
      // One thread per row to subtract the one-hot label...
      int grid = (batch_size + block - 1) / block;
      const int64_t* label_data = labels->data<int64_t>();
      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
          logit_grad_data, label_data, batch_size, class_num);
      // ...then one thread per element to apply the upstream gradient.
      int num = batch_size * class_num;
      grid = (num + block - 1) / block;
      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data,
                                           num, class_num);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(softmax_with_cross_entropy,
                        ops::SoftmaxWithCrossEntropyCUDAKernel<float>,
                        ops::SoftmaxWithCrossEntropyCUDAKernel<double>);
// Gradient kernel registration, instantiated for float and double elements.
REGISTER_OP_CUDA_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>,
    ops::SoftmaxWithCrossEntropyGradCUDAKernel<double>);