From 95d3ebc812cb7a4c01bc9c3651dc6c3eec284ec2 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Wed, 23 Mar 2022 11:09:14 +0800
Subject: [PATCH] Modified dropout Kernel with Kernel Primitive API (#40766)

---
 paddle/fluid/operators/dropout_impl.cu.h      | 255 +++++++-----------
 .../phi/kernels/funcs/distribution_helper.h   |  18 +-
 .../kernels/gpu/masked_select_grad_kernel.cu  |   5 +-
 3 files changed, 121 insertions(+), 157 deletions(-)

diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h
index 144198367d5..94db4c62e39 100644
--- a/paddle/fluid/operators/dropout_impl.cu.h
+++ b/paddle/fluid/operators/dropout_impl.cu.h
@@ -35,143 +35,99 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/functors.h"
-
 namespace paddle {
 namespace operators {
+template <typename T1, typename T2 = T1, typename OutT = T1>
+struct DstMaskGenerator {
+  const float dropout_prob_;
+  const bool is_upscale_in_train_;
+  using MT = typename details::MPTypeTrait<T1>::Type;
+  MT factor;
+  HOSTDEVICE inline DstMaskGenerator(const float dropout_prob,
+                                     const bool is_upscale_in_train)
+      : dropout_prob_(dropout_prob), is_upscale_in_train_(is_upscale_in_train) {
+    factor = static_cast<MT>(1.0f / (1.0f - dropout_prob_));
+  }
 
-template <typename T, typename MaskType>
-__global__ void RandomGenerator(const size_t n, uint64_t seed,
-                                const float dropout_prob, const T* src,
-                                MaskType* mask, T* dst,
-                                bool is_upscale_in_train, uint64_t increment) {
-  using MT = typename details::MPTypeTrait<T>::Type;
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-#ifdef PADDLE_WITH_HIP
-  hiprandStatePhilox4_32_10_t state;
-  hiprand_init(seed, idx, increment, &state);
-#else
-  curandStatePhilox4_32_10_t state;
-  curand_init(seed, idx, increment, &state);
-#endif
-
-  MaskType mask_val;
-  T dst_val;
-  MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
-  for (; idx < n; idx += blockDim.x * gridDim.x) {
-    T src_val = src[idx];
-#ifdef PADDLE_WITH_HIP
-    if (hiprand_uniform(&state) < dropout_prob) {
-#else
-    if (curand_uniform(&state) < dropout_prob) {
-#endif
-      mask_val = 0;
-      dst_val = 0;
-    } else {
-      mask_val = 1;
-      dst_val = is_upscale_in_train
-                    ? static_cast<T>(static_cast<MT>(src_val) * factor)
-                    : src_val;
+  HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
+                                    const T2* rand, int num) const {
+    static constexpr int kCount =
+        phi::funcs::uniform_distribution<T2>::kReturnsCount;
+// 0 ~ kCount - 1 is dist, kCount ~ 2 * kCount - 1 is mask
+#pragma unroll
+    for (int i = 0; i < kCount; i++) {
+      if (rand[i] < dropout_prob_) {
+        dst[i] = static_cast<T1>(0);
+        dst[i + kCount] = dst[i];
+      } else {
+        dst[i] = is_upscale_in_train_
+                     ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
+                     : static_cast<T1>(src_val[i]);
+        dst[i + kCount] = static_cast<T1>(1);
+      }
     }
-    mask[idx] = mask_val;
-    dst[idx] = dst_val;
   }
-}
+};
 
-template <typename T, typename MaskType, int VecSize>
+template <typename T, typename MaskType>
 __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                           const float dropout_prob,
                                           const T* src, MaskType* mask, T* dst,
                                           bool is_upscale_in_train,
-                                          uint64_t increment) {
-  using MT = typename details::MPTypeTrait<T>::Type;
-  using LoadT = phi::AlignedVector<T, VecSize>;
-  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
-
+                                          uint64_t increment,
+                                          size_t main_offset) {
+  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
+  static constexpr int kCount =
+      phi::funcs::uniform_distribution<float>::kReturnsCount;
+  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
 #ifdef PADDLE_WITH_HIP
-  int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
   hiprandStatePhilox4_32_10_t state;
-  hiprand_init(seed, idx, increment, &state);
+  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
+  using SType = hiprandStatePhilox4_32_10_t;
 #else
-  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
   curandStatePhilox4_32_10_t state;
-  curand_init(seed, idx, increment, &state);
-#endif
-
-  MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
-  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
-    LoadT src_val;
-    phi::Load<T, VecSize>(&src[i], &src_val);
-
-#ifdef PADDLE_WITH_HIP
-    float4 rand = hiprand_uniform4(&state);
-#else
-    float4 rand = curand_uniform4(&state);
+  curand_init(seed, idx + THREAD_ID_X, increment, &state);
+  using SType = curandStatePhilox4_32_10_t;
 #endif
-
-    LoadT dst_val;
-    MaskLoadT mask_val;
-
-#pragma unroll
-    for (int j = 0; j < VecSize; j++) {
-      if ((&rand.x)[j] < dropout_prob) {
-        dst_val[j] = 0;
-        mask_val[j] = 0;
-      } else {
-        dst_val[j] = is_upscale_in_train
-                         ? static_cast<T>(static_cast<MT>(src_val[j]) * factor)
-                         : src_val[j];
-        mask_val[j] = 1;
-      }
-    }
-
-    phi::Store<T, VecSize>(dst_val, &dst[i]);
-    phi::Store<MaskType, VecSize>(mask_val, &mask[i]);
+  T dst_mask[kCount * 2];  // 0 ~ kCount - 1: dst; kCount ~ 2 * kCount - 1: mask
+  float rands[kCount];
+  MaskType mask_result[kCount];
+  using Rand = phi::funcs::uniform_distribution<float>;
+  using Cast = kps::IdentityFunctor<T>;
+  int deal_size = BLOCK_NUM_X * kCount;
+  auto dst_functor =
+      DstMaskGenerator<T, float>(dropout_prob, is_upscale_in_train);
+  size_t fix = idx * kCount;
+  for (; fix < main_offset; fix += stride) {
+    kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
+    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
+                                                          &state);
+    // dst
+    kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
+        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
+    kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
+    // mask
+    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
+        &mask_result[0], &dst_mask[kCount], Cast());
+    kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
+                                                  deal_size);
   }
-}
-
-template <typename T, typename MaskType>
-struct CudaDropoutGradFunctor {
-  using MT = typename details::MPTypeTrait<T>::Type;
-
-  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}
-
-  __device__ __forceinline__ T operator()(const T dout,
-                                          const MaskType mask) const {
-    return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
-                          factor_);
-  }
-
- private:
-  MT factor_;
-};
-
-template <typename T, typename MaskType, int VecSize>
-__global__ void DropoutGradCUDAKernel(
-    const T* dout, const MaskType* mask,
-    const typename details::MPTypeTrait<T>::Type factor, const int64_t size,
-    T* dx) {
-  using MT = typename details::MPTypeTrait<T>::Type;
-  using LoadT = phi::AlignedVector<T, VecSize>;
-  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;
-
-  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
-    LoadT dout_val;
-    phi::Load<T, VecSize>(&dout[i], &dout_val);
-
-    MaskLoadT mask_val;
-    phi::Load<MaskType, VecSize>(&mask[i], &mask_val);
-
-    LoadT dx_val;
-
-#pragma unroll
-    for (int j = 0; j < VecSize; j++) {
-      dx_val[j] = static_cast<T>(static_cast<MT>(dout_val[j]) *
-                                 static_cast<MT>(mask_val[j]) * factor);
-    }
-
-    phi::Store<T, VecSize>(dx_val, &dx[i]);
+  int remainder = n - fix;
+  if (remainder > 0) {
+    kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
+    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
+                                                          &state);
+    // dst
+    kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
+        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
+    kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
+    // mask
+    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
+        &mask_result[0], &dst_mask[kCount], Cast());
+    kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
+                                                 remainder);
   }
 }
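
The packed buffer is the crux of the new kernel: kps::OperatorTernary hands DstMaskGenerator one register array of 2 * kCount elements, the functor fills slots [0, kCount) with the scaled or zeroed outputs and slots [kCount, 2 * kCount) with the keep/drop flags, and kps::ElementwiseUnary later casts that upper half out to the mask tensor. A minimal host-side sketch of this contract, assuming float data and eliding the MT mixed-precision cast (DstMaskGeneratorSketch and its driver are hypothetical names, not part of the patch):

    #include <cstdio>

    // Mirrors DstMaskGenerator's operator(): writes num outputs followed by
    // num mask values into a single packed buffer.
    template <typename T1, typename T2 = T1, typename OutT = T1>
    struct DstMaskGeneratorSketch {
      float dropout_prob_;
      bool is_upscale_in_train_;
      float factor;
      DstMaskGeneratorSketch(float p, bool upscale)
          : dropout_prob_(p),
            is_upscale_in_train_(upscale),
            factor(1.0f / (1.0f - p)) {}

      void operator()(OutT* dst, const T1* src, const T2* rand, int num) const {
        for (int i = 0; i < num; i++) {
          if (rand[i] < dropout_prob_) {
            dst[i] = OutT(0);
            dst[i + num] = OutT(0);  // mask slot: dropped
          } else {
            dst[i] = is_upscale_in_train_ ? OutT(src[i] * factor) : OutT(src[i]);
            dst[i + num] = OutT(1);  // mask slot: kept
          }
        }
      }
    };

    int main() {
      constexpr int kCount = 4;  // curand_uniform4 yields 4 floats per call
      float src[kCount] = {1.f, 2.f, 3.f, 4.f};
      float rand[kCount] = {0.1f, 0.9f, 0.4f, 0.7f};
      float dst_mask[2 * kCount];  // [0, kCount): dst, [kCount, 2*kCount): mask
      DstMaskGeneratorSketch<float> gen(0.5f, /*upscale=*/true);
      gen(dst_mask, src, rand, kCount);
      for (float v : dst_mask) printf("%g ", v);  // prints: 0 4 0 8 0 1 0 1
      printf("\n");
      return 0;
    }

Keeping the draw, the output, and the mask in one register tile means the kernel only touches global memory through the coalesced ReadData/WriteData calls.
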
@@ -218,42 +174,21 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
     uint64_t seed_data;
     uint64_t increment;
     // VectorizedRandomGenerator use curand_uniform4, so we only support
-    // vec_size is 4;
-    int vec_size = (phi::GetVectorizedSize<T>(x_data) == 4) ? 4 : 1;
+    // kVecSize is 4;
+    constexpr int kVecSize =
+        phi::funcs::uniform_distribution<float>::kReturnsCount;
     auto gpu_config =
-        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size);
+        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
     auto offset =
-        ((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size;
-
+        ((x_numel - 1) / (gpu_config.GetThreadNum() * kVecSize) + 1) * kVecSize;
     GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
                             &seed_data, &increment);
-
-#ifdef __HIPCC__
-    if (vec_size == 4 && size % 4 == 0) {
-      hipLaunchKernelGGL(
-          HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
-          gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0, stream, size,
-          seed_data, dropout_prob, x_data, mask_data, y_data, upscale_in_train,
-          increment);
-    } else {
-      hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
-                         gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0,
-                         stream, size, seed_data, dropout_prob, x_data,
-                         mask_data, y_data, upscale_in_train, increment);
-    }
-#else
-    if (vec_size == 4 && size % 4 == 0) {
-      VectorizedRandomGenerator<T, uint8_t, 4><<<
-          gpu_config.block_per_grid, gpu_config.thread_per_block, 0, stream>>>(
-          size, seed_data, dropout_prob, x_data, mask_data, y_data,
-          upscale_in_train, increment);
-    } else {
-      RandomGenerator<T, uint8_t><<<gpu_config.block_per_grid,
-                                    gpu_config.thread_per_block, 0, stream>>>(
-          size, seed_data, dropout_prob, x_data, mask_data, y_data,
-          upscale_in_train, increment);
-    }
-#endif
+    size_t main_offset = size / (gpu_config.GetBlockSize() * kVecSize) *
+                         (gpu_config.GetBlockSize() * kVecSize);
+    VectorizedRandomGenerator<T, uint8_t><<<
+        gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0, stream>>>(
+        size, seed_data, dropout_prob, x_data, mask_data, y_data,
+        upscale_in_train, increment, main_offset);
   } else {
     if (upscale_in_train) {
       // todo: can y share with data with x directly?
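
The launch math above replaces the old runtime choice between the scalar and VecSize-4 kernels: main_offset rounds size down to a whole number of block tiles (GetBlockSize() * kVecSize), so the kernel's grid-stride loop covers [0, main_offset) with unguarded vector accesses and only the ragged tail takes the bounds-checked IsBoundary = true path. A standalone check of that arithmetic, with made-up launch numbers:

    #include <cstddef>
    #include <cstdio>

    int main() {
      const size_t size = 10000;  // hypothetical element count
      const size_t block = 256;   // hypothetical GetBlockSize()
      const size_t kVecSize = 4;  // uniform_distribution<float>::kReturnsCount
      const size_t tile = block * kVecSize;  // elements one block consumes per step
      const size_t main_offset = size / tile * tile;
      printf("tile=%zu main_offset=%zu tail=%zu\n",
             tile, main_offset, size - main_offset);
      // prints: tile=1024 main_offset=9216 tail=784
      return 0;
    }

Because the tail is smaller than one block tile, the divergent boundary code runs at most once instead of guarding every loop iteration.
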
@@ -278,6 +213,22 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
   }
 }
 
+template <typename T, typename MaskType>
+struct CudaDropoutGradFunctor {
+  using MT = typename details::MPTypeTrait<T>::Type;
+
+  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}
+
+  __device__ __forceinline__ T operator()(const T dout,
+                                          const MaskType mask) const {
+    return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
+                          factor_);
+  }
+
+ private:
+  MT factor_;
+};
+
 template <typename T, typename MaskType>
 void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                 const std::string dropout_implementation,
diff --git a/paddle/phi/kernels/funcs/distribution_helper.h b/paddle/phi/kernels/funcs/distribution_helper.h
index acc31d68b78..f752ec0c5cf 100644
--- a/paddle/phi/kernels/funcs/distribution_helper.h
+++ b/paddle/phi/kernels/funcs/distribution_helper.h
@@ -114,13 +114,19 @@ struct normal_transform {
 namespace kps = phi::kps;
 
 /*********************** Distribution Function *************************/
-template <typename T>
-struct uniform_distribution;
 
 template <typename T>
 struct normal_distribution;
 
 #if defined(__NVCC__)
+template <typename T>
+struct uniform_distribution {
+  __device__ inline T operator()(curandStatePhilox4_32_10_t *state) const {
+    return static_cast<T>(curand_uniform(state));
+  }
+  static constexpr int kReturnsCount = 1;
+};
+
 template <>
 struct uniform_distribution<float> {
   __device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const {
@@ -177,6 +183,14 @@ struct normal_distribution {
 };
 
 #else
+template <typename T>
+struct uniform_distribution {
+  __device__ inline T operator()(hiprandStatePhilox4_32_10_t *state) const {
+    return hiprand_uniform(state);
+  }
+  static constexpr int kReturnsCount = 1;
+};
+
 template <>
 struct uniform_distribution<float> {
   __device__ inline float4 operator()(
diff --git a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
index 5d0097af2ca..5a4ce3a2679 100644
--- a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
@@ -17,11 +17,10 @@
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
 
+#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/select_impl.cu.h"
 #include "paddle/phi/kernels/masked_select_grad_kernel.h"
 
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
 namespace phi {
 
 template <typename MT, typename InT, typename OutT>
@@ -50,7 +49,7 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
                             const DenseTensor& mask,
                             DenseTensor* x_grad) {
   auto mask_size = mask.numel();
-  auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace());
+  dev_ctx.template Alloc<T>(x_grad);
   if (mask_size <= 0) return;
   using Functor = MaskedSelectGradFunctor<bool, T, T>;
   phi::funcs::SelectKernel<bool, T, T, 2, Functor>(
--
GitLab
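
For reference, the kReturnsCount constant added to uniform_distribution is what lets DropoutFwGPUKernelDriver derive kVecSize at compile time instead of hard-coding 4. A compile-only sketch of that contract (the two struct bodies mirror the hunks above; the sketch name, the trailing consumer lines, and the static_assert are illustrative additions, and upstream also specializes double via curand_uniform2_double, omitted here):

    #include <curand_kernel.h>

    template <typename T>
    struct uniform_distribution_sketch {  // generic fallback: one value per draw
      __device__ inline T operator()(curandStatePhilox4_32_10_t *state) const {
        return static_cast<T>(curand_uniform(state));
      }
      static constexpr int kReturnsCount = 1;
    };

    template <>
    struct uniform_distribution_sketch<float> {  // Philox yields 4 floats per call
      __device__ inline float4 operator()(
          curandStatePhilox4_32_10_t *state) const {
        return curand_uniform4(state);
      }
      static constexpr int kReturnsCount = 4;
    };

    // Callers size their per-thread buffers and pick the vector width from the
    // trait rather than hard-coding 4, as the driver above now does:
    constexpr int kVecSize = uniform_distribution_sketch<float>::kReturnsCount;
    static_assert(kVecSize == 4, "float path vectorizes by four");
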