From bbe5228ca812b27b4c1514d928d894b61fd5a543 Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Fri, 25 Feb 2022 12:27:57 +0800
Subject: [PATCH] Optimize perf of softmax_with_cross_entropy (#39553)

* Optimize perf of softmax_with_cross_entropy

* fix

* fix

* fix accuracy error

* review: fix loss index in the scalar fallback, use a 64-bit row offset
---
 .../softmax_with_cross_entropy_op.cu          | 314 +++++++++++++++++-
 1 file changed, 307 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
index fd035df768d..92e2adb3ee8 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -27,6 +27,8 @@ namespace cub = hipcub;
 namespace paddle {
 namespace operators {
 
+#define ALIGN_BYTES 16
+
 using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
 using DataLayout = platform::DataLayout;
 using Tensor = framework::Tensor;
@@ -47,6 +49,20 @@ static __device__ __forceinline__ T Exp(T x) {
   return math::TolerableValue<T>()(static_cast<T>(expx));
 }
 
+// Reduction functor: folds exp(x - max) into a running sum (result cast to
+// Ty); `max` is fixed at construction.
+template <typename Tx, typename Ty = Tx>
+struct ExpAddFunctor {
+  HOSTDEVICE inline ExpAddFunctor(Tx max) : max(max) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& sum, const Tx& x) const {
+    return static_cast<Ty>(sum + std::exp(x - max));
+  }
+
+ private:
+  Tx max;
+};
+
 // log2(value)
 static inline int Log2Ceil(int value) {
   int log2_value = 0;
@@ -419,10 +435,288 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
   }
 }
 
+// Stores loss_value into loss[label_id] when the element handled by this
+// call, at row index vec_size * tid + offset, is the labelled class. When
+// IgnoreIndex is set, a label equal to ignore_index stores 0 instead.
+template <typename T, bool IgnoreIndex>
+__device__ __forceinline__ void ComputeLoss(T* loss, const T loss_value,
+                                            const int label_id,
+                                            const int64_t label_value,
+                                            const int tid, const int vec_size,
+                                            const int offset,
+                                            const int ignore_index) {
+  int loss_id = vec_size * tid + offset;
+  if (IgnoreIndex) {
+    if (label_value == loss_id) {
+      if (label_value == ignore_index) {
+        loss[label_id] = static_cast<T>(0.0f);
+      } else {
+        loss[label_id] = loss_value;
+      }
+    }
+  } else {
+    if (label_value == loss_id) {
+      loss[label_id] = loss_value;
+    }
+  }
+}
+
+// Per-thread partial reduction of one row of `size` elements. `offset` is
+// how many elements `input` lies past an ALIGN_BYTES boundary: the
+// misaligned head and the tail are read element-wise, the aligned middle
+// through VecSize-wide vector loads.
+template <typename T, typename AccT, int VecSize, class ReduceFunctor>
+__device__ __forceinline__ AccT ThreadReduce(const T* input, int size,
+                                             const int offset, AccT init,
+                                             ReduceFunctor reducer) {
+  using VecT = kps::details::VectorType<T, VecSize>;
+  int tid = threadIdx.x;
+  AccT val = init;
+
+  if (offset > 0) {
+    input -= offset;
+    size += offset;
+    if (tid >= offset) {
+      val = reducer(val, input[tid]);
+    }
+    size -= blockDim.x;
+    input += blockDim.x;
+  }
+  int remain = size % (VecSize * blockDim.x);
+
+  T ins[VecSize];
+  VecT* ins_vec = reinterpret_cast<VecT*>(&ins);
+
+  // vector part
+  for (; VecSize * tid < (size - remain); tid += blockDim.x) {
+    *ins_vec = reinterpret_cast<const VecT*>(input)[tid];
+
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      val = reducer(val, ins[i]);
+    }
+  }
+
+  // scalar part
+  tid = size - remain + threadIdx.x;
+  for (; tid < size; tid += blockDim.x) {
+    val = reducer(val, input[tid]);
+  }
+  return val;
+}
+
+// Softmax and hard-label loss for one row when logits and softmax have the
+// same misalignment `offset`, so both are accessed through vector
+// loads/stores. loss_id_offset maps the shifted indices back to class ids.
+template <typename T, typename AccT, typename LabelT, int VecSize,
+          bool IgnoreIndex>
+__device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
+    T* loss, T* softmax, const T* logits, const LabelT* label, int size,
+    const int offset, const LogSoftmaxForwardFunctor<AccT>& func,
+    const int ignore_index) {
+  using VecT = kps::details::VectorType<T, VecSize>;
+  int tid = threadIdx.x;
+  int label_id = blockIdx.x;
+  auto label_value = static_cast<int64_t>(label[label_id]);
+  const bool label_valid = label_value >= 0 && label_value < size;
+  int loss_id_offset = 0;
+
+  if (offset > 0) {
+    logits -= offset;
+    softmax -= offset;
+    size += offset;
+    loss_id_offset -= offset;
+    if (tid >= offset) {
+      AccT log_softmax = func(static_cast<AccT>(logits[tid]));
+      softmax[tid] = static_cast<T>(std::exp(log_softmax));
+      // loss
+      if (label_valid) {
+        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
+                                    label_id, label_value, tid, 1,
+                                    loss_id_offset, ignore_index);
+      }
+    }
+    size -= blockDim.x;
+    logits += blockDim.x;
+    softmax += blockDim.x;
+    loss_id_offset += blockDim.x;
+  }
+  int remain = size % (VecSize * blockDim.x);
+
+  T ins[VecSize];
+  T outs[VecSize];
+  VecT* ins_vec = reinterpret_cast<VecT*>(&ins);
+  VecT* outs_vec = reinterpret_cast<VecT*>(&outs);
+
+  // vector part
+  for (; VecSize * tid < (size - remain); tid += blockDim.x) {
+    // read
+    *ins_vec = reinterpret_cast<const VecT*>(logits)[tid];
+
+#pragma unroll
+    // compute
+    for (int i = 0; i < VecSize; ++i) {
+      AccT log_softmax = func(static_cast<AccT>(ins[i]));
+      outs[i] = static_cast<T>(std::exp(log_softmax));
+
+      // loss
+      if (label_valid) {
+        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
+                                    label_id, label_value, tid, VecSize,
+                                    loss_id_offset + i, ignore_index);
+      }
+    }
+
+    // write
+    reinterpret_cast<VecT*>(softmax)[tid] = *outs_vec;
+  }
+
+  // scalar part
+  tid = size - remain + threadIdx.x;
+  for (; tid < size; tid += blockDim.x) {
+    AccT log_softmax = func(static_cast<AccT>(logits[tid]));
+    softmax[tid] = static_cast<T>(std::exp(log_softmax));
+
+    // loss
+    if (label_valid) {
+      ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax), label_id,
+                                  label_value, tid, 1, loss_id_offset,
+                                  ignore_index);
+    }
+  }
+
+  // invalid label, write once
+  if (!label_valid && threadIdx.x == 0) {
+    loss[label_id] = static_cast<T>(0.0f);
+  }
+}
+
+// Element-wise fallback for rows where logits and softmax are misaligned
+// differently; each thread handles VecSize elements strided by blockDim.x.
+template <typename T, typename AccT, typename LabelT, int VecSize,
+          bool IgnoreIndex>
+__device__ __forceinline__ void ScalarSoftmaxForwardImpl(
+    T* loss, T* softmax, const T* logits, const LabelT* label, const int size,
+    const LogSoftmaxForwardFunctor<AccT>& func, const int ignore_index) {
+  int tid = threadIdx.x;
+  int remain = size % (VecSize * blockDim.x);
+  int label_id = blockIdx.x;
+  auto label_value = static_cast<int64_t>(label[label_id]);
+  const bool label_valid = label_value >= 0 && label_value < size;
+
+  // main part
+  for (; tid < (size - remain); tid += VecSize * blockDim.x) {
+    T ins[VecSize];
+
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      ins[i] = logits[tid + i * blockDim.x];
+    }
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      AccT log_softmax = func(static_cast<AccT>(ins[i]));
+      softmax[tid + i * blockDim.x] = static_cast<T>(std::exp(log_softmax));
+      // loss: this element's class id is tid + i * blockDim.x
+      if (label_valid) {
+        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
+                                    label_id, label_value,
+                                    tid + i * blockDim.x, 1, 0, ignore_index);
+      }
+    }
+  }
+
+  // tail part
+  for (; tid < size; tid += blockDim.x) {
+    AccT log_softmax = func(static_cast<AccT>(logits[tid]));
+    softmax[tid] = static_cast<T>(std::exp(log_softmax));
+    // loss
+    if (label_valid) {
+      ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax), label_id,
+                                  label_value, tid, 1, 0, ignore_index);
+    }
+  }
+
+  // invalid label, write once
+  if (!label_valid && threadIdx.x == 0) {
+    loss[label_id] = static_cast<T>(0.0f);
+  }
+}
+
+// Kernel: block blockIdx.x handles one row of mid_dim logits - block-wide
+// max, then sum of exp(x - max), then softmax and loss for that row.
+template <typename T, typename AccT, typename LabelT, int VecSize,
+          bool IgnoreIndex>
+__global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
+                                         const LabelT* label,
+                                         const int high_dim, const int mid_dim,
+                                         const int ignore_index) {
+  using VecT = kps::details::VectorType<T, VecSize>;
+
+  // each block deals with one row; 64-bit offset avoids 32-bit overflow
+  logits += static_cast<int64_t>(blockIdx.x) * mid_dim;
+  softmax += static_cast<int64_t>(blockIdx.x) * mid_dim;
+
+  const int input_offset = ((uint64_t)logits) % ALIGN_BYTES / sizeof(T);
+  const int output_offset = ((uint64_t)softmax) % ALIGN_BYTES / sizeof(T);
+
+  // 1. reduce max
+  AccT max = ThreadReduce<T, AccT, VecSize, kps::MaxFunctor<AccT>>(
+      logits, mid_dim, input_offset, -std::numeric_limits<AccT>::infinity(),
+      kps::MaxFunctor<AccT>());
+  max = kps::details::BlockXReduce<AccT, kps::MaxFunctor<AccT>>(
+      max, kps::MaxFunctor<AccT>());
+
+  // 2. reduce sum
+  AccT sum = ThreadReduce<T, AccT, VecSize, ExpAddFunctor<AccT>>(
+      logits, mid_dim, input_offset, static_cast<AccT>(0),
+      ExpAddFunctor<AccT>(max));
+  sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
+      sum, kps::AddFunctor<AccT>());
+
+  // 3. softmax
+  LogSoftmaxForwardFunctor<AccT> func(max, sum);
+  if (input_offset == output_offset) {
+    VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
+        loss, softmax, logits, label, mid_dim, input_offset, func,
+        ignore_index);
+  } else {
+    ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
+        loss, softmax, logits, label, mid_dim, func, ignore_index);
+  }
+}
+
+// Launches VectorizedSoftmaxForward on `stream` with high_dim blocks. The
+// block size is grown by doubling and is never below kWarpSize threads.
+template <typename T, typename LabelT, bool IgnoreIndex>
+void LaunchVectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
+                                    const LabelT* label, const int high_dim,
+                                    const int mid_dim, const int ignore_index,
+                                    gpuStream_t stream) {
+  using AccT = typename details::MPTypeTrait<T>::Type;
+  constexpr int vec_size = sizeof(float4) / sizeof(T);
+  const int max_num_threads = 1024;
+  int max_block_size = std::min(mid_dim / vec_size, max_num_threads);
+  if (vec_size > 1) {
+    max_block_size /= 2;
+  }
+
+  int block_size = 1;
+  while (block_size < max_block_size) {
+    block_size *= 2;
+  }
+  block_size = std::max(block_size, kps::details::kWarpSize);
+  dim3 grids(high_dim);
+  dim3 blocks(block_size);
+  VectorizedSoftmaxForward<T, AccT, LabelT, vec_size,
+                           IgnoreIndex><<<grids, blocks, 0, stream>>>(
+      loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
+}
+
 /*
   Wrapper of softmax with cross entropy hard label.
-  - SwitchWarpSoftmaxForward for small size
-  - cudnn function for large size
+  - SwitchWarpSoftmaxForward for small size when axis == -1
+  - LaunchVectorizedSoftmaxForward for large size when axis == -1
+  - cudnn function for axis != -1
 */
 template <typename T, typename LabelT>
 static void SoftmaxWithCrossEntropyHardLabel(
@@ -431,11 +725,17 @@ static void SoftmaxWithCrossEntropyHardLabel(
     T* softmax_data, int N, int dim, int D, const int ignore_index) {
   auto stream = ctx.stream();
   constexpr int max_dim = 320;
-  if (D == 1 && dim <= max_dim) {  // small size
-    const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
-    SwitchWarpSoftmaxForward<T, LabelT, mode>(
-        loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
-        ignore_index, stream);
+  if (D == 1) {
+    if (dim <= max_dim) {  // small size
+      const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
+      SwitchWarpSoftmaxForward<T, LabelT, mode>(
+          loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
+          ignore_index, stream);
+    } else {  // large size
+      LaunchVectorizedSoftmaxForward<T, LabelT, true>(
+          loss_data, softmax_data, logits_data, labels_data, N, dim,
+          ignore_index, stream);
+    }
   } else {
     ScopedTensorDescriptor desc;
     std::vector<int> tensor_dims = {N, dim, D, 1};
-- 
GitLab