From ef61da86bbcf283158ae6514dfdff077eba938a0 Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Wed, 8 Sep 2021 19:24:45 +0800 Subject: [PATCH] Refactor softmax_cudnn kernel impl for code reuse. (#35350) --- paddle/fluid/operators/softmax_cudnn_op.cu | 604 +---------------- paddle/fluid/operators/softmax_cudnn_op.cu.h | 642 +++++++++++++++++++ paddle/fluid/operators/softmax_impl.cuh | 47 -- 3 files changed, 649 insertions(+), 644 deletions(-) create mode 100644 paddle/fluid/operators/softmax_cudnn_op.cu.h delete mode 100755 paddle/fluid/operators/softmax_impl.cuh diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu b/paddle/fluid/operators/softmax_cudnn_op.cu index 83b7b78aaec..72c2e97c178 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu +++ b/paddle/fluid/operators/softmax_cudnn_op.cu @@ -13,438 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/math/math_cuda_utils.h" -#include "paddle/fluid/operators/softmax_impl.cuh" -#include "paddle/fluid/operators/softmax_op.h" -#include "paddle/fluid/platform/cuda_device_function.h" -#ifdef PADDLE_WITH_HIP -#include "paddle/fluid/platform/miopen_helper.h" -#else -#include "paddle/fluid/platform/cudnn_helper.h" -#endif - -namespace paddle { -namespace platform { -struct CUDAPlace; -struct float16; -} // namespace platform -} // namespace paddle +#include "paddle/fluid/operators/softmax_cudnn_op.cu.h" namespace paddle { namespace operators { -using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; -using DataLayout = platform::DataLayout; -using Tensor = framework::Tensor; - -// Vectorization trait 4 * sizeof(T) -template -class VecT4 {}; -template <> -class VecT4 { - public: - using Type = long4; -}; -template <> -class VecT4 { - public: - using Type = int4; -}; -template <> -class VecT4 { - public: - using Type = int2; -}; - -// Vectorization trait 2 * sizeof(T) -template -class VecT2 {}; -template <> -class VecT2 { - public: - using Type = int4; -}; -template <> -class VecT2 { - public: - using Type = int2; -}; -template <> -class VecT2 { - public: - using Type = int; -}; - -int static inline log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -/* -Core function of computing softmax forward for axis=-1. -The computation includes - - Compute maximum of batch: maxvalue_{i} = max_j src_{i,j} - - Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } - - Compute: (a_{i,j} - maxvalue_{i}) / s_{i} -One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). -For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle -api to compute max (sum) in one warp. -*/ -template -__global__ void WarpSoftmaxForward(T* softmax, const T* src, - const int batch_size, const int stride, - const int element_count) { - constexpr int kDimCeil = 1 << Log2Elements; - constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; - constexpr int kVSize = sizeof(VecT) / sizeof(T); - constexpr int kIterations = kDimCeil / kWarpSize; - constexpr int kIterationsV = - (kIterations >= kVSize) ? (kIterations / kVSize) : 1; - constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1; - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; - - // max index to read - int idx_max_v[kBatchSize]; -#pragma unroll - for (int i = 0; i < kBatchSize; i++) { - int idx_max = ((i + first_batch) < batch_size) ? element_count : 0; - idx_max_v[i] = idx_max / kVSize; - } - - // read data from global memory - AccT srcdata[kBatchSize][kIterationsV][kVSize]; - -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { -// read data -#pragma unroll - for (int it = 0; it < kIterationsV; ++it) { - int src_idx = threadIdx.x + it * kWarpSize; - if (kVSize == 1) { - if (src_idx < idx_max_v[i]) { - srcdata[i][it][0] = - static_cast(src[(first_batch + i) * stride + src_idx]); - } else { - srcdata[i][it][0] = -std::numeric_limits::infinity(); - } - } else { - const VecT* src_v = - reinterpret_cast(&src[(first_batch + i) * stride]); - if (src_idx < idx_max_v[i]) { - VecT srctmp = src_v[src_idx]; - const T* srcinptr = reinterpret_cast(&srctmp); -#pragma unroll - for (int s = 0; s < kVSize; s++) { - srcdata[i][it][s] = static_cast(srcinptr[s]); - } - } else { -#pragma unroll - for (int s = 0; s < kVSize; s++) { - srcdata[i][it][s] = -std::numeric_limits::infinity(); - } - } - } - } - } - - // compute max value - AccT max_value[kBatchSize]; -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - // it = 0 - AccT valmax = srcdata[i][0][0]; -#pragma unroll - for (int s = 1; s < kVSize; ++s) { - valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s]; - } - max_value[i] = valmax; - -// it = 1, 2, ... -#pragma unroll - for (int it = 1; it < kIterationsV; ++it) { - AccT valmax = srcdata[i][it][0]; -#pragma unroll - for (int s = 1; s < kVSize; ++s) { - valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s]; - } - max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax; - } - } - WarpReduceMax(max_value); - - // compute sum - AccT sum[kBatchSize]; -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - // it = 0 - if (LogMode) { - sum[i] = std::exp(srcdata[i][0][0] - max_value[i]); - } else { - srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]); - sum[i] = srcdata[i][0][0]; - } -#pragma unroll - for (int s = 1; s < kVSize; ++s) { - if (LogMode) { - sum[i] += std::exp(srcdata[i][0][s] - max_value[i]); - } else { - srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]); - sum[i] += srcdata[i][0][s]; - } - } - -// it = 1, 2, ... -#pragma unroll - for (int it = 1; it < kIterationsV; ++it) { -#pragma unroll - for (int s = 0; s < kVSize; ++s) { - if (LogMode) { - sum[i] += std::exp(srcdata[i][it][s] - max_value[i]); - } else { - srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]); - sum[i] += srcdata[i][it][s]; - } - } - } - } - WarpReduceSum(sum); - -// write result to global memory -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - if (LogMode) { - sum[i] = std::log(sum[i]); - } - -#pragma unroll - for (int it = 0; it < kIterationsV; ++it) { - int idx = threadIdx.x + it * kWarpSize; - if (kVSize == 1) { - if (idx < idx_max_v[i]) { - if (LogMode) { - softmax[(first_batch + i) * stride + idx] = - srcdata[i][it][0] - max_value[i] - sum[i]; - } else { - softmax[(first_batch + i) * stride + idx] = - srcdata[i][it][0] / sum[i]; - } - } else { - break; - } - } else { - VecT* softmax_v = - reinterpret_cast(&softmax[(first_batch + i) * stride]); - VecT tmpdata; - T* tmpptr = reinterpret_cast(&tmpdata); -#pragma unroll - for (int s = 0; s < kVSize; ++s) { - if (LogMode) { - tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i]; - } else { - tmpptr[s] = srcdata[i][it][s] / sum[i]; - } - } - - if (idx < idx_max_v[i]) { - softmax_v[idx] = tmpdata; - } else { - break; - } - } - } - } -} - -/* -Core function of computing softmax backward for axis=-1. -The computation includes - - Compute sum of exp batch: s_{i} = sum_{j} {src_{i,j} * grad_{i,j} - - Compute src_{i,j} * ( grad_{i,j}) - s_{i} ) -One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). -For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle -api to compute max (sum) in one warp. -*/ -template -__global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, - int batch_size, int stride, - int element_count) { - constexpr int kVSize = sizeof(VecT) / sizeof(T); - constexpr int kDimCeil = 1 << Log2Elements; - constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; - constexpr int kIterations = kDimCeil / kWarpSize; - constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1; - constexpr int kIterationsV = - (kIterations >= kVSize) ? (kIterations / kVSize) : 1; - int element_count_v = element_count / kVSize; - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; - int local_batches = batch_size - first_batch; - if (local_batches > kBatchSize) { - local_batches = kBatchSize; - } - - // read data from global memory - VecT src_reg[kBatchSize][kIterationsV]; - VecT grad_reg[kBatchSize][kIterationsV]; - - for (int i = 0; i < kBatchSize; ++i) { - const VecT* src_v = - reinterpret_cast(&src[(first_batch + i) * stride]); - const VecT* grad_v = - reinterpret_cast(&grad[(first_batch + i) * stride]); - - // max index to read - int idx_max = (i < local_batches) ? element_count : 0; - int idx_max_v = idx_max / kVSize; - - // read data - for (int it = 0; it < kIterationsV; ++it) { - int src_idx = threadIdx.x + it * kWarpSize; - if (src_idx < idx_max_v) { - src_reg[i][it] = src_v[src_idx]; - grad_reg[i][it] = grad_v[src_idx]; - } else { -#pragma unroll - for (int s = 0; s < kVSize; s++) { - reinterpret_cast(&src_reg[i][it])[s] = 0.0; - reinterpret_cast(&grad_reg[i][it])[s] = 0.0; - } - } - } - } - - // compute sum - AccT sum[kBatchSize]{0.0}; -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { -#pragma unroll - for (int it = 0; it < kIterationsV; ++it) { - T* gradptr = reinterpret_cast(&grad_reg[i][it]); - T* srcptr = reinterpret_cast(&src_reg[i][it]); -#pragma unroll - for (int s = 0; s < kVSize; ++s) { - if (LogMode) { - sum[i] += static_cast(gradptr[s]); - } else { - sum[i] += static_cast(gradptr[s] * srcptr[s]); - } - } - } - } - WarpReduceSum(sum); - -// write result -#pragma unroll - for (int i = 0; i < kBatchSize; ++i) { - if (i >= local_batches) break; - - VecT* dst_v = reinterpret_cast(&dst[(first_batch + i) * stride]); - - // max index to write - int idx_max = (i < local_batches) ? element_count : 0; - int idx_max_v = idx_max / kVSize; - -#pragma unroll - for (int it = 0; it < kIterationsV; ++it) { - VecT tmpdata; - T* tmpptr = reinterpret_cast(&tmpdata); - T* gradptr = reinterpret_cast(&grad_reg[i][it]); - T* srcptr = reinterpret_cast(&src_reg[i][it]); -#pragma unroll - for (int s = 0; s < kVSize; ++s) { - if (LogMode) { - tmpptr[s] = static_cast(gradptr[s]) - - std::exp(static_cast(srcptr[s])) * sum[i]; - } else { - tmpptr[s] = static_cast(srcptr[s]) * - (static_cast(gradptr[s]) - sum[i]); - } - } - - int idx = threadIdx.x + it * kWarpSize; - if (idx < idx_max_v) { - dst_v[idx] = tmpdata; - } - } - } -} - -#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \ - case Log2Elements: \ - WarpSoftmaxForward< \ - T, VecT, AccT, Log2Elements, \ - LogMode><<>>( \ - dst, src, batch_size, stride, element_count); \ - break; - -/* - Wrapper of softmax formward with template instantiation on size of input. -*/ -template -void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads, - const framework::ExecutionContext& ctx, T* dst, - const T* src, const int batch_size, - const int stride, const int element_count, - int Log2Elements) { - using AccT = typename details::MPTypeTrait::Type; - switch (Log2Elements) { - SOFTMAX_WARP_FORWARD_CASE(0, AccT); - SOFTMAX_WARP_FORWARD_CASE(1, AccT); - SOFTMAX_WARP_FORWARD_CASE(2, AccT); - SOFTMAX_WARP_FORWARD_CASE(3, AccT); - SOFTMAX_WARP_FORWARD_CASE(4, AccT); - SOFTMAX_WARP_FORWARD_CASE(5, AccT); - SOFTMAX_WARP_FORWARD_CASE(6, AccT); - SOFTMAX_WARP_FORWARD_CASE(7, AccT); - SOFTMAX_WARP_FORWARD_CASE(8, AccT); - SOFTMAX_WARP_FORWARD_CASE(9, AccT); - default: - break; - } -} - -#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \ - case Log2Elements: \ - WarpSoftmaxBackward< \ - T, VecT, AccT, Log2Elements, \ - LogMode><<>>( \ - dst, grad, src, batch_size, stride, element_count); \ - break; - -/* -Wrapper of softmax backward with template instantiation on size of input. -*/ -template -void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads, - const framework::ExecutionContext& ctx, T* dst, - const T* grad, const T* src, - const int batch_size, const int stride, - const int element_count, int Log2Elements) { - using AccT = typename details::MPTypeTrait::Type; - switch (Log2Elements) { - SOFTMAX_WARP_BACKWARD_CASE(0, AccT); - SOFTMAX_WARP_BACKWARD_CASE(1, AccT); - SOFTMAX_WARP_BACKWARD_CASE(2, AccT); - SOFTMAX_WARP_BACKWARD_CASE(3, AccT); - SOFTMAX_WARP_BACKWARD_CASE(4, AccT); - SOFTMAX_WARP_BACKWARD_CASE(5, AccT); - SOFTMAX_WARP_BACKWARD_CASE(6, AccT); - SOFTMAX_WARP_BACKWARD_CASE(7, AccT); - SOFTMAX_WARP_BACKWARD_CASE(8, AccT); - SOFTMAX_WARP_BACKWARD_CASE(9, AccT); - default: - break; - } -} - -#undef SOFTMAX_WARP_FORWARD_CASE -#undef SOFTMAX_WARP_BACKWARD_CASE - template class SoftmaxCUDNNKernel : public framework::OpKernel { public: @@ -452,92 +25,10 @@ class SoftmaxCUDNNKernel : public framework::OpKernel { auto* x = ctx.Input("X"); auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); - auto* out_data = out->data(); - - auto dims = x->dims(); - const int rank = dims.size(); - const int axis = CanonicalAxis(ctx.Attr("axis"), rank); - const int dim = dims[axis]; - const int N = SizeToAxis(axis, dims); - const int D = SizeOutAxis(axis, dims); - - constexpr int max_dim = 320; - constexpr int warps_per_block = 4; - - if (D == 1 && dim <= max_dim && sizeof(T) <= 4) { - const int kDimLog2 = static_cast(log2_ceil(dim)); - const int kDimCeil = 1 << kDimLog2; - int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; - int batches_per_warp = (kDimCeil <= 32) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / kWarpSize); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (N + batches_per_block - 1) / batches_per_block; - dim3 threads(kWarpSize, warps_per_block, 1); - - // vectorization read/write - using T4 = typename VecT4::Type; - using T2 = typename VecT2::Type; - if (dim % 4 == 0) { - SwitchWarpSoftmaxForward(blocks, threads, ctx, out_data, - x->data(), N, dim, dim, - kDimLog2); - } else if (dim % 2 == 0) { - SwitchWarpSoftmaxForward(blocks, threads, ctx, out_data, - x->data(), N, dim, dim, - kDimLog2); - } else { - SwitchWarpSoftmaxForward(blocks, threads, ctx, out_data, - x->data(), N, dim, dim, - kDimLog2); - } - } else { - ScopedTensorDescriptor desc; - std::vector tensor_dims = {N, dim, D, 1}; - DataLayout layout = DataLayout::kNCHW; -#ifdef PADDLE_WITH_HIP - miopenTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); -#else - cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); -#endif - - auto& dev_ctx = - ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - -#ifdef PADDLE_WITH_HIP - auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE - : MIOPEN_SOFTMAX_MODE_CHANNEL; - if (LogMode) { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( - handle, platform::CudnnDataType::kOne(), desc_, x->data(), - platform::CudnnDataType::kZero(), desc_, out_data, - MIOPEN_SOFTMAX_LOG, mode)); - } else { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( - handle, platform::CudnnDataType::kOne(), desc_, x->data(), - platform::CudnnDataType::kZero(), desc_, out_data, - MIOPEN_SOFTMAX_ACCURATE, mode)); - } -#else - auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE - : CUDNN_SOFTMAX_MODE_CHANNEL; - if (LogMode) { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( - handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), - desc_, x->data(), platform::CudnnDataType::kZero(), desc_, - out_data)); - } else { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( - handle, CUDNN_SOFTMAX_ACCURATE, mode, - platform::CudnnDataType::kOne(), desc_, x->data(), - platform::CudnnDataType::kZero(), desc_, out_data)); - } -#endif - } + int input_axis = ctx.Attr("axis"); + auto& dev_ctx = ctx.template device_context(); + SoftmaxForwardCUDAKernelDriver(dev_ctx, *x, input_axis, out); } }; @@ -549,91 +40,10 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); dx->mutable_data(ctx.GetPlace()); - auto* dx_data = dx->data(); - - auto dims = out->dims(); - const int rank = dims.size(); - const int axis = CanonicalAxis(ctx.Attr("axis"), rank); - const int dim = dims[axis]; - const int N = SizeToAxis(axis, dims); - const int D = SizeOutAxis(axis, dims); - - constexpr int max_dim = 320; - constexpr int warps_per_block = 4; - if (D == 1 && dim <= max_dim && sizeof(T) <= 4) { - const int kDimLog2 = log2_ceil(dim); - const int kDimCeil = 1 << kDimLog2; - int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; - int batches_per_warp = (kDimCeil <= 128) ? 2 : 1; - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / kWarpSize); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (N + batches_per_block - 1) / batches_per_block; - dim3 threads(kWarpSize, warps_per_block, 1); - - // vectorization read/write - using T4 = typename VecT4::Type; - using T2 = typename VecT2::Type; - if (dim % 4 == 0) { - SwitchWarpSoftmaxBackward( - blocks, threads, ctx, dx_data, dout->data(), out->data(), N, - dim, dim, kDimLog2); - } else if (dim % 2 == 0) { - SwitchWarpSoftmaxBackward( - blocks, threads, ctx, dx_data, dout->data(), out->data(), N, - dim, dim, kDimLog2); - } else { - SwitchWarpSoftmaxBackward( - blocks, threads, ctx, dx_data, dout->data(), out->data(), N, - dim, dim, kDimLog2); - } - } else { - ScopedTensorDescriptor desc; - std::vector tensor_dims = {N, dim, D, 1}; - DataLayout layout = DataLayout::kNCHW; -#ifdef PADDLE_WITH_HIP - miopenTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); -#else - cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); -#endif - - auto& dev_ctx = - ctx.template device_context(); - auto handle = dev_ctx.cudnn_handle(); - -#ifdef PADDLE_WITH_HIP - auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE - : MIOPEN_SOFTMAX_MODE_CHANNEL; - if (LogMode) { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( - handle, platform::CudnnDataType::kOne(), desc_, out->data(), - desc_, dout->data(), platform::CudnnDataType::kZero(), desc_, - dx_data, MIOPEN_SOFTMAX_LOG, mode)); - } else { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( - handle, platform::CudnnDataType::kOne(), desc_, out->data(), - desc_, dout->data(), platform::CudnnDataType::kZero(), desc_, - dx_data, MIOPEN_SOFTMAX_ACCURATE, mode)); - } -#else - auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE - : CUDNN_SOFTMAX_MODE_CHANNEL; - if (LogMode) { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( - handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), - desc_, out->data(), desc_, dout->data(), - platform::CudnnDataType::kZero(), desc_, dx_data)); - } else { - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( - handle, CUDNN_SOFTMAX_ACCURATE, mode, - platform::CudnnDataType::kOne(), desc_, out->data(), desc_, - dout->data(), platform::CudnnDataType::kZero(), desc_, - dx_data)); - } -#endif - } + int input_axis = ctx.Attr("axis"); + auto& dev_ctx = ctx.template device_context(); + SoftmaxBackwardCUDAKernelDriver(dev_ctx, *out, *dout, input_axis, dx); } }; diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.h b/paddle/fluid/operators/softmax_cudnn_op.cu.h new file mode 100644 index 00000000000..cb63e88d636 --- /dev/null +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.h @@ -0,0 +1,642 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/math/math_cuda_utils.h" +#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/platform/cuda_device_function.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#else +#include "paddle/fluid/platform/cudnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using DataLayout = platform::DataLayout; +using Tensor = framework::Tensor; + +// Vectorization trait 4 * sizeof(T) +template +class VecT4 {}; +template <> +class VecT4 { + public: + using Type = long4; +}; +template <> +class VecT4 { + public: + using Type = int4; +}; +template <> +class VecT4 { + public: + using Type = int2; +}; + +// Vectorization trait 2 * sizeof(T) +template +class VecT2 {}; +template <> +class VecT2 { + public: + using Type = int4; +}; +template <> +class VecT2 { + public: + using Type = int2; +}; +template <> +class VecT2 { + public: + using Type = int; +}; + +static inline int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +__device__ __forceinline__ void WarpReduceSum(T* sum) { +#pragma unroll + for (int offset = WarpSize / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < BatchSize; ++i) { + T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); + sum[i] = sum[i] + sum_val; + } + } +} + +template +__device__ __forceinline__ void WarpReduceMax(T* sum) { +#pragma unroll + for (int offset = WarpSize / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < BatchSize; ++i) { + T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); + sum[i] = max(sum[i], max_val); + } + } +} + +/* +Core function of computing softmax forward for axis=-1. +The computation includes + - Compute maximum of batch: maxvalue_{i} = max_j src_{i,j} + - Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } + - Compute: (a_{i,j} - maxvalue_{i}) / s_{i} +One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). +For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle +api to compute max (sum) in one warp. +*/ +template +__global__ void WarpSoftmaxForward(T* softmax, const T* src, + const int batch_size, const int stride, + const int element_count) { + constexpr int kDimCeil = 1 << Log2Elements; + constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + constexpr int kVSize = sizeof(VecT) / sizeof(T); + constexpr int kIterations = kDimCeil / kWarpSize; + constexpr int kIterationsV = + (kIterations >= kVSize) ? (kIterations / kVSize) : 1; + constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; + + // max index to read + int idx_max_v[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; i++) { + int idx_max = ((i + first_batch) < batch_size) ? element_count : 0; + idx_max_v[i] = idx_max / kVSize; + } + + // read data from global memory + AccT srcdata[kBatchSize][kIterationsV][kVSize]; + +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { +// read data +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + int src_idx = threadIdx.x + it * kWarpSize; + if (kVSize == 1) { + if (src_idx < idx_max_v[i]) { + srcdata[i][it][0] = + static_cast(src[(first_batch + i) * stride + src_idx]); + } else { + srcdata[i][it][0] = -std::numeric_limits::infinity(); + } + } else { + const VecT* src_v = + reinterpret_cast(&src[(first_batch + i) * stride]); + if (src_idx < idx_max_v[i]) { + VecT srctmp = src_v[src_idx]; + const T* srcinptr = reinterpret_cast(&srctmp); +#pragma unroll + for (int s = 0; s < kVSize; s++) { + srcdata[i][it][s] = static_cast(srcinptr[s]); + } + } else { +#pragma unroll + for (int s = 0; s < kVSize; s++) { + srcdata[i][it][s] = -std::numeric_limits::infinity(); + } + } + } + } + } + + // compute max value + AccT max_value[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + // it = 0 + AccT valmax = srcdata[i][0][0]; +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s]; + } + max_value[i] = valmax; + +// it = 1, 2, ... +#pragma unroll + for (int it = 1; it < kIterationsV; ++it) { + AccT valmax = srcdata[i][it][0]; +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s]; + } + max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax; + } + } + WarpReduceMax(max_value); + + // compute sum + AccT sum[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + // it = 0 + if (LogMode) { + sum[i] = std::exp(srcdata[i][0][0] - max_value[i]); + } else { + srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]); + sum[i] = srcdata[i][0][0]; + } +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + if (LogMode) { + sum[i] += std::exp(srcdata[i][0][s] - max_value[i]); + } else { + srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]); + sum[i] += srcdata[i][0][s]; + } + } + +// it = 1, 2, ... +#pragma unroll + for (int it = 1; it < kIterationsV; ++it) { +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + sum[i] += std::exp(srcdata[i][it][s] - max_value[i]); + } else { + srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]); + sum[i] += srcdata[i][it][s]; + } + } + } + } + WarpReduceSum(sum); + +// write result to global memory +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + if (LogMode) { + sum[i] = std::log(sum[i]); + } + +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + int idx = threadIdx.x + it * kWarpSize; + if (kVSize == 1) { + if (idx < idx_max_v[i]) { + if (LogMode) { + softmax[(first_batch + i) * stride + idx] = + srcdata[i][it][0] - max_value[i] - sum[i]; + } else { + softmax[(first_batch + i) * stride + idx] = + srcdata[i][it][0] / sum[i]; + } + } else { + break; + } + } else { + VecT* softmax_v = + reinterpret_cast(&softmax[(first_batch + i) * stride]); + VecT tmpdata; + T* tmpptr = reinterpret_cast(&tmpdata); +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i]; + } else { + tmpptr[s] = srcdata[i][it][s] / sum[i]; + } + } + + if (idx < idx_max_v[i]) { + softmax_v[idx] = tmpdata; + } else { + break; + } + } + } + } +} + +/* +Core function of computing softmax backward for axis=-1. +The computation includes + - Compute sum of exp batch: s_{i} = sum_{j} {src_{i,j} * grad_{i,j} + - Compute src_{i,j} * ( grad_{i,j}) - s_{i} ) +One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). +For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle +api to compute max (sum) in one warp. +*/ +template +__global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, + int batch_size, int stride, + int element_count) { + constexpr int kVSize = sizeof(VecT) / sizeof(T); + constexpr int kDimCeil = 1 << Log2Elements; + constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + constexpr int kIterations = kDimCeil / kWarpSize; + constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1; + constexpr int kIterationsV = + (kIterations >= kVSize) ? (kIterations / kVSize) : 1; + int element_count_v = element_count / kVSize; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; + int local_batches = batch_size - first_batch; + if (local_batches > kBatchSize) { + local_batches = kBatchSize; + } + + // read data from global memory + VecT src_reg[kBatchSize][kIterationsV]; + VecT grad_reg[kBatchSize][kIterationsV]; + + for (int i = 0; i < kBatchSize; ++i) { + const VecT* src_v = + reinterpret_cast(&src[(first_batch + i) * stride]); + const VecT* grad_v = + reinterpret_cast(&grad[(first_batch + i) * stride]); + + // max index to read + int idx_max = (i < local_batches) ? element_count : 0; + int idx_max_v = idx_max / kVSize; + + // read data + for (int it = 0; it < kIterationsV; ++it) { + int src_idx = threadIdx.x + it * kWarpSize; + if (src_idx < idx_max_v) { + src_reg[i][it] = src_v[src_idx]; + grad_reg[i][it] = grad_v[src_idx]; + } else { +#pragma unroll + for (int s = 0; s < kVSize; s++) { + reinterpret_cast(&src_reg[i][it])[s] = 0.0; + reinterpret_cast(&grad_reg[i][it])[s] = 0.0; + } + } + } + } + + // compute sum + AccT sum[kBatchSize]{0.0}; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + T* gradptr = reinterpret_cast(&grad_reg[i][it]); + T* srcptr = reinterpret_cast(&src_reg[i][it]); +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + sum[i] += static_cast(gradptr[s]); + } else { + sum[i] += static_cast(gradptr[s] * srcptr[s]); + } + } + } + } + WarpReduceSum(sum); + +// write result +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + if (i >= local_batches) break; + + VecT* dst_v = reinterpret_cast(&dst[(first_batch + i) * stride]); + + // max index to write + int idx_max = (i < local_batches) ? element_count : 0; + int idx_max_v = idx_max / kVSize; + +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + VecT tmpdata; + T* tmpptr = reinterpret_cast(&tmpdata); + T* gradptr = reinterpret_cast(&grad_reg[i][it]); + T* srcptr = reinterpret_cast(&src_reg[i][it]); +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + tmpptr[s] = static_cast(gradptr[s]) - + std::exp(static_cast(srcptr[s])) * sum[i]; + } else { + tmpptr[s] = static_cast(srcptr[s]) * + (static_cast(gradptr[s]) - sum[i]); + } + } + + int idx = threadIdx.x + it * kWarpSize; + if (idx < idx_max_v) { + dst_v[idx] = tmpdata; + } + } + } +} + +#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \ + case Log2Elements: \ + WarpSoftmaxForward<<>>( \ + dst, src, batch_size, stride, element_count); \ + break; + +/* + Wrapper of softmax formward with template instantiation on size of input. +*/ +template +void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads, + const platform::CUDADeviceContext& dev_ctx, + T* dst, const T* src, const int batch_size, + const int stride, const int element_count, + int Log2Elements) { + using AccT = typename details::MPTypeTrait::Type; + switch (Log2Elements) { + SOFTMAX_WARP_FORWARD_CASE(0, AccT); + SOFTMAX_WARP_FORWARD_CASE(1, AccT); + SOFTMAX_WARP_FORWARD_CASE(2, AccT); + SOFTMAX_WARP_FORWARD_CASE(3, AccT); + SOFTMAX_WARP_FORWARD_CASE(4, AccT); + SOFTMAX_WARP_FORWARD_CASE(5, AccT); + SOFTMAX_WARP_FORWARD_CASE(6, AccT); + SOFTMAX_WARP_FORWARD_CASE(7, AccT); + SOFTMAX_WARP_FORWARD_CASE(8, AccT); + SOFTMAX_WARP_FORWARD_CASE(9, AccT); + default: + break; + } +} + +#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \ + case Log2Elements: \ + WarpSoftmaxBackward<<>>( \ + dst, grad, src, batch_size, stride, element_count); \ + break; + +/* +Wrapper of softmax backward with template instantiation on size of input. +*/ +template +void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads, + const platform::CUDADeviceContext& dev_ctx, + T* dst, const T* grad, const T* src, + const int batch_size, const int stride, + const int element_count, int Log2Elements) { + using AccT = typename details::MPTypeTrait::Type; + switch (Log2Elements) { + SOFTMAX_WARP_BACKWARD_CASE(0, AccT); + SOFTMAX_WARP_BACKWARD_CASE(1, AccT); + SOFTMAX_WARP_BACKWARD_CASE(2, AccT); + SOFTMAX_WARP_BACKWARD_CASE(3, AccT); + SOFTMAX_WARP_BACKWARD_CASE(4, AccT); + SOFTMAX_WARP_BACKWARD_CASE(5, AccT); + SOFTMAX_WARP_BACKWARD_CASE(6, AccT); + SOFTMAX_WARP_BACKWARD_CASE(7, AccT); + SOFTMAX_WARP_BACKWARD_CASE(8, AccT); + SOFTMAX_WARP_BACKWARD_CASE(9, AccT); + default: + break; + } +} + +#undef SOFTMAX_WARP_FORWARD_CASE +#undef SOFTMAX_WARP_BACKWARD_CASE + +template +void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, + const Tensor& x, const int input_axis, + Tensor* out) { + auto* out_data = out->data(); + + auto dims = x.dims(); + const int rank = dims.size(); + const int axis = CanonicalAxis(input_axis, rank); + const int dim = dims[axis]; + const int N = SizeToAxis(axis, dims); + const int D = SizeOutAxis(axis, dims); + + constexpr int max_dim = 320; + constexpr int warps_per_block = 4; + + if (D == 1 && dim <= max_dim && sizeof(T) <= 4) { + const int kDimLog2 = static_cast(log2_ceil(dim)); + const int kDimCeil = 1 << kDimLog2; + int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + int batches_per_warp = (kDimCeil <= 32) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / kWarpSize); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (N + batches_per_block - 1) / batches_per_block; + dim3 threads(kWarpSize, warps_per_block, 1); + + // vectorization read/write + using T4 = typename VecT4::Type; + using T2 = typename VecT2::Type; + if (dim % 4 == 0) { + SwitchWarpSoftmaxForward(blocks, threads, dev_ctx, + out_data, x.data(), N, dim, + dim, kDimLog2); + } else if (dim % 2 == 0) { + SwitchWarpSoftmaxForward(blocks, threads, dev_ctx, + out_data, x.data(), N, dim, + dim, kDimLog2); + } else { + SwitchWarpSoftmaxForward(blocks, threads, dev_ctx, + out_data, x.data(), N, dim, + dim, kDimLog2); + } + } else { + ScopedTensorDescriptor desc; + std::vector tensor_dims = {N, dim, D, 1}; + DataLayout layout = DataLayout::kNCHW; +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); +#else + cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); +#endif + + auto handle = dev_ctx.cudnn_handle(); + +#ifdef PADDLE_WITH_HIP + auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE + : MIOPEN_SOFTMAX_MODE_CHANNEL; + if (LogMode) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( + handle, platform::CudnnDataType::kOne(), desc_, x.data(), + platform::CudnnDataType::kZero(), desc_, out_data, + MIOPEN_SOFTMAX_LOG, mode)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( + handle, platform::CudnnDataType::kOne(), desc_, x.data(), + platform::CudnnDataType::kZero(), desc_, out_data, + MIOPEN_SOFTMAX_ACCURATE, mode)); + } +#else + auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE + : CUDNN_SOFTMAX_MODE_CHANNEL; + if (LogMode) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( + handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), + desc_, x.data(), platform::CudnnDataType::kZero(), desc_, + out_data)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( + handle, CUDNN_SOFTMAX_ACCURATE, mode, + platform::CudnnDataType::kOne(), desc_, x.data(), + platform::CudnnDataType::kZero(), desc_, out_data)); + } +#endif + } +} + +template +void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx, + const Tensor& out, const Tensor& dout, + const int input_axis, Tensor* dx) { + auto* dx_data = dx->data(); + + auto dims = out.dims(); + const int rank = dims.size(); + const int axis = CanonicalAxis(input_axis, rank); + const int dim = dims[axis]; + const int N = SizeToAxis(axis, dims); + const int D = SizeOutAxis(axis, dims); + + constexpr int max_dim = 320; + constexpr int warps_per_block = 4; + + if (D == 1 && dim <= max_dim && sizeof(T) <= 4) { + const int kDimLog2 = log2_ceil(dim); + const int kDimCeil = 1 << kDimLog2; + int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + int batches_per_warp = (kDimCeil <= 128) ? 2 : 1; + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / kWarpSize); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (N + batches_per_block - 1) / batches_per_block; + dim3 threads(kWarpSize, warps_per_block, 1); + + // vectorization read/write + using T4 = typename VecT4::Type; + using T2 = typename VecT2::Type; + if (dim % 4 == 0) { + SwitchWarpSoftmaxBackward( + blocks, threads, dev_ctx, dx_data, dout.data(), out.data(), N, + dim, dim, kDimLog2); + } else if (dim % 2 == 0) { + SwitchWarpSoftmaxBackward( + blocks, threads, dev_ctx, dx_data, dout.data(), out.data(), N, + dim, dim, kDimLog2); + } else { + SwitchWarpSoftmaxBackward( + blocks, threads, dev_ctx, dx_data, dout.data(), out.data(), N, + dim, dim, kDimLog2); + } + } else { + ScopedTensorDescriptor desc; + std::vector tensor_dims = {N, dim, D, 1}; + DataLayout layout = DataLayout::kNCHW; +#ifdef PADDLE_WITH_HIP + miopenTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); +#else + cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); +#endif + + auto handle = dev_ctx.cudnn_handle(); + +#ifdef PADDLE_WITH_HIP + auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE + : MIOPEN_SOFTMAX_MODE_CHANNEL; + if (LogMode) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( + handle, platform::CudnnDataType::kOne(), desc_, out.data(), + desc_, dout.data(), platform::CudnnDataType::kZero(), desc_, + dx_data, MIOPEN_SOFTMAX_LOG, mode)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( + handle, platform::CudnnDataType::kOne(), desc_, out.data(), + desc_, dout.data(), platform::CudnnDataType::kZero(), desc_, + dx_data, MIOPEN_SOFTMAX_ACCURATE, mode)); + } +#else + auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE + : CUDNN_SOFTMAX_MODE_CHANNEL; + if (LogMode) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( + handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), + desc_, out.data(), desc_, dout.data(), + platform::CudnnDataType::kZero(), desc_, dx_data)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( + handle, CUDNN_SOFTMAX_ACCURATE, mode, + platform::CudnnDataType::kOne(), desc_, out.data(), desc_, + dout.data(), platform::CudnnDataType::kZero(), desc_, dx_data)); + } +#endif + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/softmax_impl.cuh b/paddle/fluid/operators/softmax_impl.cuh deleted file mode 100755 index 2acc55d2398..00000000000 --- a/paddle/fluid/operators/softmax_impl.cuh +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/platform/cuda_device_function.h" - -namespace paddle { -namespace operators { - -template -__device__ __forceinline__ void WarpReduceSum(T* sum) { -#pragma unroll - for (int offset = WarpSize / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < BatchSize; ++i) { - T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); - sum[i] = sum[i] + sum_val; - } - } -} - -template -__device__ __forceinline__ void WarpReduceMax(T* sum) { -#pragma unroll - for (int offset = WarpSize / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < BatchSize; ++i) { - T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); - sum[i] = max(sum[i], max_val); - } - } -} - -} // namespace operators -} // namespace paddle \ No newline at end of file -- GitLab