From ef76f664a09dfcb77feb3bc3bdccfbe619fd739f Mon Sep 17 00:00:00 2001
From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com>
Date: Thu, 28 Oct 2021 14:14:40 +0800
Subject: [PATCH] Rewrite Softmax in Kernel Primitive API, test=develop
 (#36706)

---
 paddle/fluid/operators/softmax_cudnn_op.cu.h | 401 +++++++++----------
 1 file changed, 191 insertions(+), 210 deletions(-)

diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.h b/paddle/fluid/operators/softmax_cudnn_op.cu.h
index cb63e88d63..68b694a59f 100644
--- a/paddle/fluid/operators/softmax_cudnn_op.cu.h
+++ b/paddle/fluid/operators/softmax_cudnn_op.cu.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/operators/math/math_cuda_utils.h"
 #include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
@@ -99,6 +100,97 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) {
   }
 }
 
+namespace kps = paddle::operators::kernel_primitives;
+
+template <typename Tx, typename Ty = Tx>
+struct ReduceMaxFunctor {
+  inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return max(a, b);
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct ExpSubFunctor {
+  HOSTDEVICE inline ExpSubFunctor() { y = static_cast<Tx>(0.0f); }
+
+  HOSTDEVICE explicit inline ExpSubFunctor(Tx y) : y((Tx)(y)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::exp(x - y));
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct ExpMulFunctor {
+  HOSTDEVICE inline ExpMulFunctor() { y = static_cast<Tx>(1.0f); }
+
+  HOSTDEVICE explicit inline ExpMulFunctor(Tx y) : y((Tx)(y)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::exp(x) * y);
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnarySubFunctor {
+  HOSTDEVICE inline UnarySubFunctor() { y = static_cast<Tx>(0.0f); }
+
+  HOSTDEVICE explicit inline UnarySubFunctor(Tx y) : y((Tx)(y)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x - y);
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnaryLogFunctor {
+  HOSTDEVICE inline UnaryLogFunctor() {}
+
+  HOSTDEVICE explicit inline UnaryLogFunctor(int n) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::log(x));
+  }
+};
+
+template <typename Tx, typename Ty>
+struct DataTransFunctor {
+  HOSTDEVICE inline DataTransFunctor() {}
+
+  HOSTDEVICE explicit inline DataTransFunctor(int n) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return x == -std::numeric_limits<Tx>::infinity()
+               ? -std::numeric_limits<Ty>::infinity()
+               : static_cast<Ty>(x);
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnaryDivFunctor {
+  HOSTDEVICE inline UnaryDivFunctor() { n_inv = static_cast<Tx>(1.0f); }
+
+  HOSTDEVICE explicit inline UnaryDivFunctor(Tx n) : n_inv((Tx)(1.0 / n)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x * n_inv);
+  }
+
+ private:
+  Tx n_inv;
+};
+
 /*
 Core function of computing softmax forward for axis=-1.
 The computation includes
@@ -117,12 +209,14 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
   constexpr int kDimCeil = 1 << Log2Elements;
   constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
   constexpr int kVSize = sizeof(VecT) / sizeof(T);
-  constexpr int kIterations = kDimCeil / kWarpSize;
-  constexpr int kIterationsV =
-      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
+  constexpr int kLoops = kDimCeil / kWarpSize;
+  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
   constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
-
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
+  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
+  constexpr int kVItem = kLoopsV * kVSize;
+  constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
+  using kMode = kps::details::ReduceMode;
 
   // max index to read
   int idx_max_v[kBatchSize];
@@ -133,146 +227,51 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
   }
 
   // read data from global memory
-  AccT srcdata[kBatchSize][kIterationsV][kVSize];
-
+  AccT srcdata[kBatchSize][kLoopsV][kVSize];
+  kps::Init<AccT, kStep>(&srcdata[0][0][0], kLowInf);
+  T src_tmp[kBatchSize][kLoopsV][kVSize];
+  kps::Init<T, kStep>(&src_tmp[0][0][0], -std::numeric_limits<T>::infinity());
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-// read data
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      int src_idx = threadIdx.x + it * kWarpSize;
-      if (kVSize == 1) {
-        if (src_idx < idx_max_v[i]) {
-          srcdata[i][it][0] =
-              static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
-        } else {
-          srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
-        }
-      } else {
-        const VecT* src_v =
-            reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
-        if (src_idx < idx_max_v[i]) {
-          VecT srctmp = src_v[src_idx];
-          const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
-#pragma unroll
-          for (int s = 0; s < kVSize; s++) {
-            srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
-          }
-        } else {
-#pragma unroll
-          for (int s = 0; s < kVSize; s++) {
-            srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
-          }
-        }
-      }
-    }
+    int ptr = (first_batch + i) * stride;
+    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
+    VecT* reg_v = reinterpret_cast<VecT*>(&src_tmp[i][0][0]);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
+    kps::ElementwiseUnary<T, AccT, kVItem, 1, 1, DataTransFunctor<T, AccT>>(
+        &srcdata[i][0][0], &src_tmp[i][0][0], DataTransFunctor<T, AccT>());
   }
 
-  // compute max value
-  AccT max_value[kBatchSize];
-#pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
-    // it = 0
-    AccT valmax = srcdata[i][0][0];
-#pragma unroll
-    for (int s = 1; s < kVSize; ++s) {
-      valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
-    }
-    max_value[i] = valmax;
-
-// it = 1, 2, ...
-#pragma unroll
-    for (int it = 1; it < kIterationsV; ++it) {
-      AccT valmax = srcdata[i][it][0];
-#pragma unroll
-      for (int s = 1; s < kVSize; ++s) {
-        valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
-      }
-      max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
-    }
-  }
-  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
+  // compute max
+  AccT max[kBatchSize];
+  kps::Init<AccT, kBatchSize>(&max[0], kLowInf);
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, ReduceMaxFunctor<AccT>,
+              kMode::kLocalMode>(&max[0], &srcdata[0][0][0],
+                                 ReduceMaxFunctor<AccT>(), true);
+  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);
 
   // compute sum
-  AccT sum[kBatchSize];
-#pragma unroll
+  AccT sum[kBatchSize] = {0};
   for (int i = 0; i < kBatchSize; ++i) {
-    // it = 0
-    if (LogMode) {
-      sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
-    } else {
-      srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
-      sum[i] = srcdata[i][0][0];
-    }
-#pragma unroll
-    for (int s = 1; s < kVSize; ++s) {
-      if (LogMode) {
-        sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
-      } else {
-        srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
-        sum[i] += srcdata[i][0][s];
-      }
-    }
-
-// it = 1, 2, ...
-#pragma unroll
-    for (int it = 1; it < kIterationsV; ++it) {
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
-        } else {
-          srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
-          sum[i] += srcdata[i][it][s];
-        }
-      }
-    }
+    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
+        &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
   }
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
+              kMode::kLocalMode>(&sum[0], &srcdata[0][0][0],
+                                 kps::AddFunctor<AccT>(), true);
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
 
-// write result to global memory
+  // write result to global memory
+  T out_tmp[kBatchSize][kLoopsV][kVSize];
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-    if (LogMode) {
-      sum[i] = std::log(sum[i]);
-    }
-
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      int idx = threadIdx.x + it * kWarpSize;
-      if (kVSize == 1) {
-        if (idx < idx_max_v[i]) {
-          if (LogMode) {
-            softmax[(first_batch + i) * stride + idx] =
-                srcdata[i][it][0] - max_value[i] - sum[i];
-          } else {
-            softmax[(first_batch + i) * stride + idx] =
-                srcdata[i][it][0] / sum[i];
-          }
-        } else {
-          break;
-        }
-      } else {
-        VecT* softmax_v =
-            reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
-        VecT tmpdata;
-        T* tmpptr = reinterpret_cast<T*>(&tmpdata);
-#pragma unroll
-        for (int s = 0; s < kVSize; ++s) {
-          if (LogMode) {
-            tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
-          } else {
-            tmpptr[s] = srcdata[i][it][s] / sum[i];
-          }
-        }
-
-        if (idx < idx_max_v[i]) {
-          softmax_v[idx] = tmpdata;
-        } else {
-          break;
-        }
-      }
-    }
+    kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
+        &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
+    int softmax_ptr = (first_batch + i) * stride;
+    VecT* softmax_v = reinterpret_cast<VecT*>(&softmax[softmax_ptr]);
+    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
+    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
   }
 }
 
@@ -293,101 +292,82 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
   constexpr int kVSize = sizeof(VecT) / sizeof(T);
   constexpr int kDimCeil = 1 << Log2Elements;
   constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
-  constexpr int kIterations = kDimCeil / kWarpSize;
+  constexpr int kLoops = kDimCeil / kWarpSize;
   constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
-  constexpr int kIterationsV =
-      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
+  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
   int element_count_v = element_count / kVSize;
-
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
-  int local_batches = batch_size - first_batch;
-  if (local_batches > kBatchSize) {
-    local_batches = kBatchSize;
+  int local_batches = min(batch_size - first_batch, kBatchSize);
+
+  // max index to read
+  int idx_max_v[kBatchSize];
+#pragma unroll
+  for (int i = 0; i < kBatchSize; i++) {
+    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
+    idx_max_v[i] = idx_max / kVSize;
   }
 
   // read data from global memory
-  VecT src_reg[kBatchSize][kIterationsV];
-  VecT grad_reg[kBatchSize][kIterationsV];
-
-  for (int i = 0; i < kBatchSize; ++i) {
-    const VecT* src_v =
-        reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
-    const VecT* grad_v =
-        reinterpret_cast<const VecT*>(&grad[(first_batch + i) * stride]);
-
-    // max index to read
-    int idx_max = (i < local_batches) ? element_count : 0;
-    int idx_max_v = idx_max / kVSize;
-
-    // read data
-    for (int it = 0; it < kIterationsV; ++it) {
-      int src_idx = threadIdx.x + it * kWarpSize;
-      if (src_idx < idx_max_v) {
-        src_reg[i][it] = src_v[src_idx];
-        grad_reg[i][it] = grad_v[src_idx];
-      } else {
+  VecT src_reg[kBatchSize][kLoopsV];
+  VecT grad_reg[kBatchSize][kLoopsV];
+  VecT k_value;
+  for (int s = 0; s < kVSize; s++) {
+    reinterpret_cast<T*>(&k_value)[s] = 0.0;
+  }
+  kps::Init<VecT, kBatchSize * kLoopsV>(&src_reg[0][0], k_value);
+  kps::Init<VecT, kBatchSize * kLoopsV>(&grad_reg[0][0], k_value);
 #pragma unroll
-        for (int s = 0; s < kVSize; s++) {
-          reinterpret_cast<T*>(&src_reg[i][it])[s] = 0.0;
-          reinterpret_cast<T*>(&grad_reg[i][it])[s] = 0.0;
-        }
-      }
-    }
+  for (int i = 0; i < kBatchSize; ++i) {
+    int flag = i < local_batches ? 1 : 0;
+    int ptr = (first_batch + i) * stride;
+    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
+    const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[ptr]);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag);
   }
 
+  // change T to AccT
+  AccT src_tmp[kBatchSize][kLoopsV][kVSize];
+  AccT grad_tmp[kBatchSize][kLoopsV][kVSize];
+  const T* src_ptr = reinterpret_cast<const T*>(&src_reg[0][0]);
+  const T* grad_ptr = reinterpret_cast<const T*>(&grad_reg[0][0]);
+  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
+  constexpr int kVItem = kLoopsV * kVSize;
+  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
+      &src_tmp[0][0][0], &src_ptr[0], DataTransFunctor<T, AccT>());
+  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
+      &grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor<T, AccT>());
+
   // compute sum
   AccT sum[kBatchSize]{0.0};
-#pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
-      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          sum[i] += static_cast<AccT>(gradptr[s]);
-        } else {
-          sum[i] += static_cast<AccT>(gradptr[s] * srcptr[s]);
-        }
-      }
-    }
-  }
+  AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
+  AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
+  AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
+  kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
+      &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
+              kps::details::ReduceMode::kLocalMode>(
+      &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
 
-// write result
+  // write result to global memory
+  AccT out[kBatchSize][kLoopsV][kVSize];
+  T out_tmp[kBatchSize][kLoopsV][kVSize];
 #pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (i >= local_batches) break;
-
+    AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
+    AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
+    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
+        &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
+    kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
+        &out_tmp[i][0][0], &srcptr[0], &out[i][0][0], kps::MulFunctor<AccT>());
     VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
-
-    // max index to write
-    int idx_max = (i < local_batches) ? element_count : 0;
-    int idx_max_v = idx_max / kVSize;
-
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      VecT tmpdata;
-      T* tmpptr = reinterpret_cast<T*>(&tmpdata);
-      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
-      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          tmpptr[s] = static_cast<AccT>(gradptr[s]) -
-                      std::exp(static_cast<AccT>(srcptr[s])) * sum[i];
-        } else {
-          tmpptr[s] = static_cast<AccT>(srcptr[s]) *
-                      (static_cast<AccT>(gradptr[s]) - sum[i]);
-        }
-      }
-
-      int idx = threadIdx.x + it * kWarpSize;
-      if (idx < idx_max_v) {
-        dst_v[idx] = tmpdata;
-      }
-    }
+    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
+    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &dst_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
  }
 }
 
@@ -493,6 +473,7 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
     // vectorization read/write
     using T4 = typename VecT4<T>::Type;
     using T2 = typename VecT2<T>::Type;
+
     if (dim % 4 == 0) {
       SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, dev_ctx,
                                                out_data, x.data<T>(), N, dim,
-- 
GitLab
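The kps calls in this patch compose the standard warp-per-row softmax recipe: pad the row with -inf, reduce to a max, apply exp(x - max), reduce to a sum, divide, and write back; the backward kernel computes dst = src * (grad - sum(grad * src)). The following self-contained CUDA sketch restates both recipes with raw __shfl_xor_sync butterflies in place of the kernel_primitives pipeline, so the data flow can be compiled and checked in isolation. Everything in it is illustrative scaffolding added for this explanation, not code from the patch: the kernel names, the float-only types, and the one-warp-per-row, row-length <= 32 restriction (the patch handles longer rows via kLoopsV and vectorized VecT accesses, and fp16 via the T/AccT split).

#include <cfloat>
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

// Forward: one warp per row. Mirrors the patch's Init(-inf) -> Reduce(max) ->
// ExpSubFunctor -> Reduce(sum) -> UnaryDivFunctor -> WriteData sequence.
__global__ void NaiveWarpSoftmax(float* out, const float* in, int rows,
                                 int cols) {
  int row = blockIdx.x * blockDim.y + threadIdx.y;
  int lane = threadIdx.x;  // 0..31, one warp per row
  if (row >= rows) return;

  float v = (lane < cols) ? in[row * cols + lane] : -FLT_MAX;  // pad like kLowInf
  float m = v;
  for (int offset = 16; offset > 0; offset >>= 1)  // WarpReduceMax counterpart
    m = fmaxf(m, __shfl_xor_sync(0xffffffffu, m, offset));

  float e = (lane < cols) ? expf(v - m) : 0.f;  // ExpSubFunctor counterpart
  float s = e;
  for (int offset = 16; offset > 0; offset >>= 1)  // WarpReduceSum counterpart
    s += __shfl_xor_sync(0xffffffffu, s, offset);

  if (lane < cols) out[row * cols + lane] = e / s;  // UnaryDivFunctor counterpart
}

// Backward (non-LogMode path): dst = y * (dy - sum(dy * y)), i.e. the patch's
// MulFunctor -> Reduce(sum) -> UnarySubFunctor -> MulFunctor sequence.
__global__ void NaiveWarpSoftmaxGrad(float* dst, const float* dy,
                                     const float* y, int rows, int cols) {
  int row = blockIdx.x * blockDim.y + threadIdx.y;
  int lane = threadIdx.x;
  if (row >= rows) return;

  float yv = (lane < cols) ? y[row * cols + lane] : 0.f;
  float gv = (lane < cols) ? dy[row * cols + lane] : 0.f;

  float s = yv * gv;  // per-lane product, then warp-wide sum
  for (int offset = 16; offset > 0; offset >>= 1)
    s += __shfl_xor_sync(0xffffffffu, s, offset);

  if (lane < cols) dst[row * cols + lane] = yv * (gv - s);
}

int main() {
  const int rows = 2, cols = 8;
  float h_in[rows * cols], h_out[rows * cols];
  for (int i = 0; i < rows * cols; ++i) h_in[i] = static_cast<float>(i % cols);

  float *d_in, *d_out, *d_grad;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMalloc(&d_grad, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  dim3 block(32, 4);  // 4 warps per block, one row each
  dim3 grid((rows + block.y - 1) / block.y);
  NaiveWarpSoftmax<<<grid, block>>>(d_out, d_in, rows, cols);
  // Reuse the input as an arbitrary upstream gradient for the backward sketch.
  NaiveWarpSoftmaxGrad<<<grid, block>>>(d_grad, d_in, d_out, rows, cols);

  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int j = 0; j < cols; ++j) printf("%f ", h_out[j]);  // row 0 sums to 1
  printf("\n");

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_grad);
  return 0;
}

The butterfly __shfl_xor_sync loops are the same cross-lane reduction pattern the file's WarpReduceMax/WarpReduceSum helpers implement; what the kps rewrite adds over this sketch is register-tile batching (kBatchSize rows per warp), vectorized global accesses (VecT), boundary-safe reads via idx_max_v, and float accumulation for fp16 inputs.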