From b81358d12a9f83855897b7ac58d7ddf2f95efb7e Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Tue, 15 Feb 2022 10:41:41 +0800
Subject: [PATCH] add dropout fp32 (#39501)

---
 paddle/fluid/operators/dropout_impl.cu.h | 38 ++++++++++++++++--------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h
index d7c49466d5..96af5ac26d 100644
--- a/paddle/fluid/operators/dropout_impl.cu.h
+++ b/paddle/fluid/operators/dropout_impl.cu.h
@@ -30,6 +30,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/dropout_impl_util.h"
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
@@ -45,6 +46,7 @@ __global__ void RandomGenerator(const size_t n, uint64_t seed,
                                 const float dropout_prob, const T* src,
                                 MaskType* mask, T* dst,
                                 bool is_upscale_in_train, uint64_t increment) {
+  using MT = typename details::MPTypeTrait<T>::Type;
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
 #ifdef PADDLE_WITH_HIP
   hiprandStatePhilox4_32_10_t state;
@@ -56,7 +58,7 @@
 
   MaskType mask_val;
   T dst_val;
-  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+  MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
   for (; idx < n; idx += blockDim.x * gridDim.x) {
     T src_val = src[idx];
 #ifdef PADDLE_WITH_HIP
@@ -68,7 +70,9 @@
       dst_val = 0;
     } else {
       mask_val = 1;
-      dst_val = is_upscale_in_train ? src_val * factor : src_val;
+      dst_val = is_upscale_in_train
+                    ? static_cast<T>(static_cast<MT>(src_val) * factor)
+                    : src_val;
     }
     mask[idx] = mask_val;
     dst[idx] = dst_val;
@@ -81,6 +85,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                           const T* src, MaskType* mask, T* dst,
                                           bool is_upscale_in_train,
                                           uint64_t increment) {
+  using MT = typename details::MPTypeTrait<T>::Type;
   using LoadT = platform::AlignedVector<T, VecSize>;
   using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
 
@@ -94,7 +99,7 @@
   curand_init(seed, idx, increment, &state);
 #endif
 
-  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+  MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
   for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
     LoadT src_val;
     platform::Load<T, VecSize>(&src[i], &src_val);
@@ -114,7 +119,9 @@
         dst_val[j] = 0;
         mask_val[j] = 0;
       } else {
-        dst_val[j] = is_upscale_in_train ? src_val[j] * factor : src_val[j];
+        dst_val[j] = is_upscale_in_train
+                         ? static_cast<T>(static_cast<MT>(src_val[j]) * factor)
+                         : src_val[j];
         mask_val[j] = 1;
       }
     }
@@ -126,21 +133,26 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
 
 template <typename T, typename MaskType>
 struct CudaDropoutGradFunctor {
-  explicit CudaDropoutGradFunctor(const T factor) : factor_(factor) {}
+  using MT = typename details::MPTypeTrait<T>::Type;
+
+  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}
 
   __device__ __forceinline__ T operator()(const T dout,
                                           const MaskType mask) const {
-    return dout * static_cast<T>(mask) * factor_;
+    return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
+                          factor_);
   }
 
  private:
-  T factor_;
+  MT factor_;
 };
 
 template <typename T, typename MaskType, int VecSize>
-__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
-                                      const T factor, const int64_t size,
-                                      T* dx) {
+__global__ void DropoutGradCUDAKernel(
+    const T* dout, const MaskType* mask,
+    const typename details::MPTypeTrait<T>::Type factor, const int64_t size,
+    T* dx) {
+  using MT = typename details::MPTypeTrait<T>::Type;
   using LoadT = platform::AlignedVector<T, VecSize>;
   using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
 
@@ -156,7 +168,8 @@ __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
 
 #pragma unroll
     for (int j = 0; j < VecSize; j++) {
-      dx_val[j] = dout_val[j] * static_cast<T>(mask_val[j]) * factor;
+      dx_val[j] = static_cast<T>(static_cast<MT>(dout_val[j]) *
+                                 static_cast<MT>(mask_val[j]) * factor);
     }
 
     platform::Store<T, VecSize>(dx_val, &dx[i]);
@@ -257,6 +270,7 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                 float dropout_prob, const Tensor& grad_y,
                                 const Tensor& mask, int64_t size,
                                 Tensor* grad_x, bool is_test = false) {
+  using MT = typename details::MPTypeTrait<T>::Type;
   auto dX = EigenVector<T>::Flatten(*grad_x);
   auto dY = EigenVector<T>::Flatten(grad_y);
 
@@ -273,7 +287,7 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
       if (dropout_prob == 1.0f) {
         dX.device(place) = static_cast<T>(0) * dY;
       } else {
-        auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+        auto factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
         auto stream = dev_ctx.stream();
         std::vector<const framework::Tensor*> ins = {&grad_y, &mask};
         std::vector<framework::Tensor*> outs = {grad_x};
-- 
GitLab
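
Note (not part of the patch): the change above follows a common mixed-precision pattern, where the dropout scale factor 1 / (1 - p) is computed and applied in a wider compute type (MT, float for fp16 inputs) and only the final result is narrowed back to the storage type T. The standalone CUDA sketch below illustrates that pattern under simplified assumptions; MPType and DropoutGradScale are illustrative names, not PaddlePaddle APIs, and the kernel mirrors only the grad-scaling arithmetic, not Paddle's vectorized load/store path.

// Illustrative sketch of compute-in-float, store-in-T dropout scaling.
#include <cuda_fp16.h>
#include <cstddef>

// Map the storage type T to a wider compute type (float for __half).
template <typename T>
struct MPType { using type = T; };
template <>
struct MPType<__half> { using type = float; };

template <typename T, typename MaskType>
__global__ void DropoutGradScale(const T* dout, const MaskType* mask,
                                 float dropout_prob, size_t n, T* dx) {
  using MT = typename MPType<T>::type;
  // Compute 1 / (1 - p) once, in the wider type, as the patch does.
  const MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    // Multiply in MT (float), then narrow back to T only when storing.
    MT val = static_cast<MT>(dout[i]) * static_cast<MT>(mask[i]) * factor;
    dx[i] = static_cast<T>(val);
  }
}

Doing the multiply in float avoids the precision loss (and, for small 1 - p, the limited range) of performing the scaling directly in fp16, while the tensors themselves stay in the reduced-precision storage type.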