/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#endif

#include "paddle/phi/kernels/funcs/dropout_impl_util.h"

#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"

namespace phi {
namespace funcs {

// Element-wise functor that applies an already generated dropout mask to one
// value: a mask of 1 keeps the value (scaled by 1 / retain_prob when
// is_upscale_in_train is set), any other mask value yields 0.
template <typename T>
struct DstFunctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  // retain_prob:         probability of keeping an element (1 - dropout_prob).
  // is_upscale_in_train: scale kept elements by 1 / retain_prob.
  // num:                 stored for interface compatibility only; it does not
  //                      influence the result.
  HOSTDEVICE inline DstFunctor(const float retain_prob,
                               const bool is_upscale_in_train,
                               const int64_t num)
      : retain_prob_(retain_prob),
        is_upscale_in_train_(is_upscale_in_train),
        num_(num) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  // Returns the dropout output for a single (value, mask) pair.
  // The result depends only on the arguments, so no loop over num_ is needed;
  // every control path returns a value, including when num_ <= 0.
  HOSTDEVICE inline T operator()(const T src_val, const uint8_t mask) const {
    if (mask != static_cast<uint8_t>(1)) {
      return static_cast<T>(0);
    }
    return is_upscale_in_train_
               ? static_cast<T>(static_cast<MT>(src_val) * factor)
               : static_cast<T>(src_val);
  }

 private:
  const float retain_prob_;
  const bool is_upscale_in_train_;
  const int64_t num_;
  MT factor;
};

// Turns kCount uniform random numbers into kCount mask values of type T:
// 1 where rand < retain_prob, 0 otherwise.
template <typename T>
struct MaskFunctor {
  explicit MaskFunctor(const float retain_prob) : retain_prob_(retain_prob) {}

  // dst:  receives kCount mask values.
  // rand: kCount random numbers.
  // num:  unused; the trip count is always kCount.
  HOSTDEVICE inline void operator()(T* dst, const float* rand, int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
// Only dst[0 .. kCount - 1] is written; each entry is a mask value.
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      dst[i] = rand[i] < retain_prob_ ? static_cast<T>(1) : static_cast<T>(0);
    }
  }

 private:
  float retain_prob_;
};

// Computes dropout output and mask together for kCount elements.
// dst must hold 2 * kCount values: dst[0 .. kCount - 1] receives the output
// and dst[kCount .. 2 * kCount - 1] receives the mask (1 = kept, 0 = dropped).
template <typename T>
struct DstMaskFunctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                   const bool is_upscale_in_train)
      : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  // dst:     output buffer of 2 * kCount values (layout described above).
  // src_val: kCount input values; the kernels in this file pass the same
  //          pointer as dst, which is safe because src_val[i] is read before
  //          dst[i] is written.
  // rand:    kCount random numbers.
  // num:     unused; the trip count is always kCount.
  HOSTDEVICE inline void operator()(T* dst,
                                    const T* src_val,
                                    const float* rand,
                                    int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
// 0 ~ kCount - 1 is dst, kCount ~ 2 * kCount - 1 is mask
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] < retain_prob_) {
        dst[i] = is_upscale_in_train_
                     ? static_cast<T>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<T>(src_val[i]);
        dst[i + kCount] = static_cast<T>(1);
      } else {
        dst[i] = static_cast<T>(0);
        dst[i + kCount] = dst[i];
      }
    }
  }

 private:
  MT factor;
  float retain_prob_;
  bool is_upscale_in_train_;
};

// Dropout forward kernel: generates the mask and the output in one pass.
//
// n:                   number of elements in src / dst / mask.
// seed, increment:     passed to curand_init / hiprand_init as the seed and
//                      offset arguments; the sequence argument is
//                      idx + THREAD_ID_X.
// dropout_prob:        probability of dropping an element.
// is_upscale_in_train: forwarded to DstMaskFunctor.
// main_offset:         end of the region processed by the full-tile loop; the
//                      host driver passes n rounded down to a multiple of
//                      block_size * kCount. Elements from there up to n are
//                      handled by the remainder branch.
template <typename T>
__global__ void VectorizedRandomGenerator(const size_t n,
                                          uint64_t seed,
                                          const float dropout_prob,
                                          const T* src,
                                          uint8_t* mask,
                                          T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment,
                                          size_t main_offset) {
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_mask[kCount * 2];  // 0 ~ kCount - 1 : dst, kCount ~ 2 * kCount - 1: mask
  float rands[kCount];
  uint8_t mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;

  size_t fix = idx * kCount;

  auto dst_functor =
      DstMaskFunctor<T>(1.0f - dropout_prob, is_upscale_in_train);
  // Full tiles: each block handles deal_size elements per iteration.
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, false>(dst + fix, &dst_mask[0], deal_size);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<uint8_t, kCount, 1, false>(
        mask + fix, &mask_result[0], deal_size);
  }
  // Tail: only a block whose fix is still below n has a positive remainder.
  int remainder = n - fix;
  if (remainder > 0) {
    kps::ReadData<T, kCount, 1, true>(&dst_mask[0], src + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, true>(dst + fix, &dst_mask[0], remainder);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<uint8_t, kCount, 1, true>(
        mask + fix, &mask_result[0], remainder);
  }
}

// Dropout-nd forward kernel. It works in two phases:
//   1. generate a mask of n elements with mask_functor and store it in mask;
//   2. for each of the N output elements, map the output index to a mask
//      offset through broadcast_config and compute
//      y[i] = dst_functor(src[i], mask[offset]).
//
// seed_ptr: when non-null, seed_ptr[0] replaces the seed argument.
//
// NOTE(review): nothing visible in this kernel synchronizes phase 1 with
// phase 2 across blocks, while phase 2 may read mask entries produced by
// another block -- confirm that this ordering is guaranteed elsewhere.
template <typename T>
__global__ void DropOutNdForwardKernel(
    const size_t n,
    uint64_t seed,
    const float dropout_prob,
    const T* src,
    uint8_t* mask,
    uint64_t increment,
    size_t main_offset,
    DstFunctor<T> dst_functor,
    MaskFunctor<T> mask_functor,
    T* y,
    int64_t N,
    kps::details::BroadcastConfig broadcast_config,
    const uint64_t* seed_ptr) {
  // Vectorized Generate Mask
  // kCount is 4 for curand_uniform4 is used
  if (seed_ptr) {
    seed = seed_ptr[0];
  }

  constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_mask[kCount];  // holds kCount mask values (as T) after mask_functor
  float rands[kCount];
  uint8_t mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;

  size_t fix = idx * kCount;
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorBinary<float, T, MaskFunctor<T>>(
        &dst_mask[0], &rands[0], mask_functor, kCount);

    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[0], Cast());
    kps::WriteData<uint8_t, kCount, 1, false>(
        mask + fix, &mask_result[0], deal_size);
  }
  int remainder = n - fix;
  if (remainder > 0) {
    kps::ReadData<T, kCount, 1, true>(&dst_mask[0], src + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorBinary<float, T, MaskFunctor<T>>(
        &dst_mask[0], &rands[0], mask_functor, kCount);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[0], Cast());
    kps::WriteData<uint8_t, kCount, 1, true>(
        mask + fix, &mask_result[0], remainder);
  }
  // Broadcast mask data and do elementwise operation with DstFunctor
  CUDA_KERNEL_LOOP(i, N) {
    uint32_t offset = 0u;
    uint32_t idx = i;
    // Use (j < phi::DDim::kMaxRank) condition rather than
    // (j < broadcast_config.rank) for (#pragma unroll)
#pragma unroll
    for (int j = 0; j < phi::DDim::kMaxRank; ++j) {
      if (j == broadcast_config.rank) break;
      auto fast_divmoder = broadcast_config.divmoders[j].Divmod(idx);
      idx = fast_divmoder.val[0];
      offset += broadcast_config.strides[j] * fast_divmoder.val[1];
    }
    y[i] = dst_functor(src[i], mask[offset]);
  }
}

// Computes y = factor * x with an element-wise ScaleFunctor kernel.
template <typename T, typename MT>
void ScaleByDropoutFactor(const phi::GPUContext& dev_ctx,
                          const phi::DenseTensor& x,
                          phi::DenseTensor* y,
                          MT factor) {
  std::vector<const phi::DenseTensor*> ins = {&x};
  std::vector<phi::DenseTensor*> outs = {y};
  auto functor = phi::funcs::ScaleFunctor<T>(factor);
  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}

// Host driver for the dropout forward pass.
//
// Training path (!is_test and mask != nullptr):
//   * dropout_prob == 1: y and mask are zero-filled on the context stream.
//   * is_dropout_nd:     launches DropOutNdForwardKernel; mask may have fewer
//                        elements than x and is broadcast to x's shape.
//   * otherwise:         launches VectorizedRandomGenerator.
// Inference path: y = x when upscale_in_train, else y = (1 - dropout_prob) * x.
//
// axis is accepted for interface compatibility and is not read here.
template <typename T>
void DropoutFwGPUKernelDriver(
    const phi::GPUContext& dev_ctx,
    bool is_test,
    float dropout_prob,
    bool upscale_in_train,
    bool is_fix_seed,
    int seed_val,
    const phi::DenseTensor& x,
    const phi::DenseTensor* seed,
    phi::DenseTensor* mask,
    phi::DenseTensor* y,
    bool is_dropout_nd = false,
    const std::vector<int>& axis = std::vector<int>()) {
  int64_t x_numel = x.numel();
  auto stream = dev_ctx.stream();
  auto* x_data = x.data<T>();
  auto* y_data = y->data<T>();

  if (!is_test && mask) {
    auto* mask_data = mask->data<uint8_t>();
    size_t size = phi::product(mask->dims());

    if (dropout_prob == 1.0f) {
      // The mask is cleared using its own element count: for dropout-nd the
      // mask can be smaller than x, so x_numel bytes would overrun it.
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(mask_data, 0, size * sizeof(*mask_data), stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(mask_data, 0, size * sizeof(*mask_data), stream));
#endif
      return;
    }

    uint64_t seed_data;
    uint64_t increment;
    // VectorizedRandomGenerator use curand_uniform4, so kVecSize is 4;
    constexpr int kVecSize =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
    size_t grid_size = gpu_config.GetGridSize();
    size_t block_size = gpu_config.GetBlockSize();

    // Cap the grid at the number of blocks the device can keep resident.
    int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
    const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
    size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
                           prop.multiProcessorCount / block_size;
    grid_size = std::min(grid_size, max_grid_size);

    auto offset =
        ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
    // Largest multiple of (block_size * kVecSize) not exceeding size.
    size_t main_offset =
        size / (block_size * kVecSize) * (block_size * kVecSize);

    if (is_dropout_nd) {
      auto dst_functor =
          DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);

      // Dimensions are reversed before being handed to BroadcastConfig.
      std::vector<int64_t> out_dims = phi::vectorize<int64_t>(x.dims());
      std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask->dims());
      std::reverse(out_dims.begin(), out_dims.end());
      std::reverse(in_dims.begin(), in_dims.end());
      kps::details::BroadcastConfig broadcast_config(
          out_dims, in_dims, x.dims().size());

      auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
      bool copy_in_kernel = GetSeedDataAndIncrement(dev_ctx,
                                                    seed,
                                                    is_fix_seed,
                                                    seed_val,
                                                    offset,
                                                    &seed_data,
                                                    &increment,
                                                    true);
      // When copy_in_kernel is set the kernel reads the seed from the tensor.
      const uint64_t* seed_ptr =
          copy_in_kernel ? seed->data<uint64_t>() : nullptr;

      DropOutNdForwardKernel<T>
          <<<grid_size, block_size, 0, stream>>>(size,
                                                 seed_data,
                                                 dropout_prob,
                                                 x_data,
                                                 mask_data,
                                                 increment,
                                                 main_offset,
                                                 dst_functor,
                                                 mask_functor,
                                                 y_data,
                                                 y->numel(),
                                                 broadcast_config,
                                                 seed_ptr);
    } else {
      bool copy_in_kernel = GetSeedDataAndIncrement(
          dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);

#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T>
      PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!is_fix_seed,
                                         PD_DROPOUT_KERNEL_NAME,
                                         grid_size,
                                         block_size,
                                         0,
                                         stream,
                                         offset,
                                         KERNEL_PARAMS.As<uint64_t>(1),
                                         KERNEL_PARAMS.As<uint64_t>(7),
                                         size,
                                         seed_data,
                                         dropout_prob,
                                         x_data,
                                         mask_data,
                                         y_data,
                                         upscale_in_train,
                                         increment,
                                         main_offset);
#undef PD_DROPOUT_KERNEL_NAME
    }
    VLOG(4) << "Dropout seed: " << seed << ", offset: " << offset
            << ", seed_data:" << seed_data;
  } else {
    if (upscale_in_train) {
      // y = x
      phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, y);
    } else {
      using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
      MT factor = static_cast<MT>(1.0f - dropout_prob);
      // y = factor * x
      ScaleByDropoutFactor<T, MT>(dev_ctx, x, y, factor);
    }
  }
}

// Element-wise backward functor: returns dout * mask * factor, computed in
// the higher-precision type MT and cast back to T.
template <typename T>
struct CudaDropoutGradFunctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const uint8_t mask) const {
    return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
                          factor_);
  }

 private:
  MT factor_;
};

// Host driver for the dropout backward pass.
//
// is_test:  grad_x = factor * grad_y, with factor = 1 when upscale_in_train
//           and 1 - dropout_prob otherwise.
// training: grad_x = grad_y * mask * factor, with factor =
//           1 / (1 - dropout_prob) when upscale_in_train and 1 otherwise.
//           When upscale_in_train and dropout_prob == 1, grad_x is zero-filled
//           instead (the factor would be a division by zero).
// For is_dropout_nd the mask is first broadcast to grad_y's shape.
template <typename T>
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                bool is_test,
                                float dropout_prob,
                                bool upscale_in_train,
                                const phi::DenseTensor& grad_y,
                                const phi::DenseTensor& mask,
                                phi::DenseTensor* grad_x,
                                bool is_dropout_nd = false) {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  auto stream = dev_ctx.stream();
  if (is_test) {
    MT factor = static_cast<MT>(upscale_in_train ? 1.0f : 1.0f - dropout_prob);
    // y = factor * x
    ScaleByDropoutFactor<T, MT>(dev_ctx, grad_y, grad_x, factor);
    return;
  }

  if (upscale_in_train && dropout_prob == 1.0f) {
    // Every element was dropped, so the gradient is zero. The fill is issued
    // on the context stream, like the kernels that consume grad_x, and its
    // status is checked. The mask is not needed, so no broadcast is done.
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(
        grad_x->data<T>(), 0, grad_x->numel() * sizeof(T), stream));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
        grad_x->data<T>(), 0, grad_x->numel() * sizeof(T), stream));
#endif
    return;
  }

  phi::DenseTensor broadcasted_mask;
  if (is_dropout_nd) {
    broadcasted_mask.Resize(grad_y.dims());
    dev_ctx.template Alloc<uint8_t>(&broadcasted_mask);

    std::vector<const phi::DenseTensor*> broadcast_ins = {&mask};
    std::vector<phi::DenseTensor*> broadcast_outs = {&broadcasted_mask};
    phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary,
                                uint8_t,
                                uint8_t>(dev_ctx,
                                         broadcast_ins,
                                         &broadcast_outs,
                                         -1,
                                         kps::IdentityFunctor<uint8_t>());
  }

  std::vector<const phi::DenseTensor*> ins = {
      &grad_y, is_dropout_nd ? &broadcasted_mask : &mask};
  std::vector<phi::DenseTensor*> outs = {grad_x};

  const MT factor = upscale_in_train
                        ? static_cast<MT>(1.0f / (1.0f - dropout_prob))
                        : static_cast<MT>(1.0f);
  phi::funcs::ElementwiseKernel<T>(
      dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
}

}  // namespace funcs
}  // namespace phi