From 2b88057f7df1992237d8c726a173b34170efe76d Mon Sep 17 00:00:00 2001
From: Li Min <11663212+limin2021@users.noreply.github.com>
Date: Wed, 15 Sep 2021 11:24:30 +0800
Subject: [PATCH] Refactor dropout cuda impl for code reuse. (#35621)

---
 paddle/fluid/operators/dropout_impl.cu.h | 297 +++++++++++++++++++++++
 paddle/fluid/operators/dropout_op.cu     | 239 +++----------------
 paddle/fluid/operators/dropout_op.h      |  51 +---
 3 files changed, 336 insertions(+), 251 deletions(-)
 create mode 100644 paddle/fluid/operators/dropout_impl.cu.h

diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h
new file mode 100644
index 0000000000..4261a5f253
--- /dev/null
+++ b/paddle/fluid/operators/dropout_impl.cu.h
@@ -0,0 +1,297 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+
+#ifdef PADDLE_WITH_CUDA
+#include <cuda.h>
+#include <curand_kernel.h>
+#include "paddle/fluid/platform/dynload/curand.h"
+#endif
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#include <hiprand_kernel.h>
+#include "paddle/fluid/platform/dynload/hiprand.h"
+#endif
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/dropout_op.h"
+#include "paddle/fluid/platform/aligned_vector.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T, typename MaskType>
+__global__ void RandomGenerator(const size_t n, uint64_t seed,
+                                const float dropout_prob, const T* src,
+                                MaskType* mask, T* dst,
+                                bool is_upscale_in_train, uint64_t increment) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+#ifdef PADDLE_WITH_HIP
+  hiprandStatePhilox4_32_10_t state;
+  hiprand_init(seed, idx, increment, &state);
+#else
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed, idx, increment, &state);
+#endif
+
+  MaskType mask_val;
+  T dst_val;
+  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+  for (; idx < n; idx += blockDim.x * gridDim.x) {
+    T src_val = src[idx];
+#ifdef PADDLE_WITH_HIP
+    if (hiprand_uniform(&state) < dropout_prob) {
+#else
+    if (curand_uniform(&state) < dropout_prob) {
+#endif
+      mask_val = 0;
+      dst_val = 0;
+    } else {
+      mask_val = 1;
+      dst_val = is_upscale_in_train ? src_val * factor : src_val;
+    }
+    mask[idx] = mask_val;
+    dst[idx] = dst_val;
+  }
+}
+
+template <typename T, typename MaskType, int VecSize>
+__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
+                                          const float dropout_prob,
+                                          const T* src, MaskType* mask, T* dst,
+                                          bool is_upscale_in_train,
+                                          uint64_t increment) {
+  using LoadT = platform::AlignedVector<T, VecSize>;
+  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
+
+#ifdef PADDLE_WITH_HIP
+  int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
+  hiprandStatePhilox4_32_10_t state;
+  hiprand_init(seed, idx, increment, &state);
+#else
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed, idx, increment, &state);
+#endif
+
+  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
+    LoadT src_val;
+    platform::Load<T, VecSize>(&src[i], &src_val);
+
+#ifdef PADDLE_WITH_HIP
+    float4 rand = hiprand_uniform4(&state);
+#else
+    float4 rand = curand_uniform4(&state);
+#endif
+
+    LoadT dst_val;
+    MaskLoadT mask_val;
+
+#pragma unroll
+    for (int j = 0; j < VecSize; j++) {
+      if ((&rand.x)[j] < dropout_prob) {
+        dst_val[j] = 0;
+        mask_val[j] = 0;
+      } else {
+        dst_val[j] = is_upscale_in_train ? src_val[j] * factor : src_val[j];
+        mask_val[j] = 1;
+      }
+    }
+
+    platform::Store<T, VecSize>(dst_val, &dst[i]);
+    platform::Store<MaskType, VecSize>(mask_val, &mask[i]);
+  }
+}
+
+template <typename T, typename MaskType, int VecSize>
+__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
+                                      const T factor, const int64_t size,
+                                      T* dx) {
+  using LoadT = platform::AlignedVector<T, VecSize>;
+  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
+
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
+    LoadT dout_val;
+    platform::Load<T, VecSize>(&dout[i], &dout_val);
+
+    MaskLoadT mask_val;
+    platform::Load<MaskType, VecSize>(&mask[i], &mask_val);
+
+    LoadT dx_val;
+
+#pragma unroll
+    for (int j = 0; j < VecSize; j++) {
+      dx_val[j] = dout_val[j] * static_cast<T>(mask_val[j]) * factor;
+    }
+
+    platform::Store<T, VecSize>(dx_val, &dx[i]);
+  }
+}
+
+template <typename T>
+void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
+                              bool is_test,
+                              const std::string dropout_implementation,
+                              float dropout_prob, bool upscale_in_train,
+                              bool is_fix_seed, int seed_val, const Tensor& x,
+                              const Tensor* seed, Tensor* mask, Tensor* y) {
+  auto& place = *dev_ctx.eigen_device();
+
+  if (!is_test) {
+    int64_t x_numel = x.numel();
+    auto stream = dev_ctx.stream();
+    auto* mask_data = mask->data<uint8_t>();
+    size_t size = framework::product(mask->dims());
+
+    auto* x_data = x.data<T>();
+    auto* y_data = y->data<T>();
+    if (dropout_prob == 1.0f) {
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
+#else
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          cudaMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
+#endif
+      return;
+    }
+
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, size);
+
+    // increment is used to set the args(offset) of curand_init, which defines
+    // offset in subsequence.
+    // The detail:
+    // https://docs.nvidia.com/cuda/curand/device-api-overview.html
+    // Increment should be at least the number of curand() random numbers used
+    // in each thread to avoid the random number generated this time being the
+    // same as the previous calls.
+    uint64_t seed_data;
+    uint64_t increment;
+    int vec_size = platform::GetVectorizedSize<T>(x_data);
+    auto offset = ((x_numel - 1) / (config.block_per_grid.x *
+                                    config.thread_per_block.x * vec_size) +
+                   1) *
+                  vec_size;
+    int device_id =
+        BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
+    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+
+    if ((seed) && platform::is_gpu_place(seed->place())) {
+      framework::Tensor seed_cpu_tensor;
+      TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
+      seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
+      increment = offset;
+    } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
+      auto seed_offset = gen_cuda->IncrementOffset(offset);
+      seed_data = seed_offset.first;
+      increment = seed_offset.second;
+    } else {
+      if (seed) {
+        seed_data = *(seed->data<int>());
+      } else {
+        std::random_device rnd;
+        seed_data = is_fix_seed ? seed_val : rnd();
+      }
+      increment = offset;
+    }
+
+#ifdef __HIPCC__
+    if (vec_size == 4 && size % 4 == 0) {
+      hipLaunchKernelGGL(
+          HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
+          config.block_per_grid, config.thread_per_block, 0, stream, size,
+          seed_data, dropout_prob, x_data, mask_data, y_data, upscale_in_train,
+          increment);
+    } else {
+      hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
+                         config.block_per_grid, config.thread_per_block, 0,
+                         stream, size, seed_data, dropout_prob, x_data,
+                         mask_data, y_data, upscale_in_train, increment);
+    }
+#else
+    if (vec_size == 4 && size % 4 == 0) {
+      VectorizedRandomGenerator<
+          T, uint8_t,
+          4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+          size, seed_data, dropout_prob, x_data, mask_data, y_data,
+          upscale_in_train, increment);
+    } else {
+      RandomGenerator<T, uint8_t><<<config.block_per_grid,
+                                    config.thread_per_block, 0, stream>>>(
+          size, seed_data, dropout_prob, x_data, mask_data, y_data,
+          upscale_in_train, increment);
+    }
+#endif
+  } else {
+    auto X = EigenMatrix<T>::Reshape(x, 1);
+    auto Y = EigenMatrix<T>::Reshape(*y, 1);
+    if (upscale_in_train) {
+      Y.device(place) = X;
+    } else {
+      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+    }
+  }
+}
+
+template <typename T>
+void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
+                                const std::string dropout_implementation,
+                                float dropout_prob, const Tensor& grad_y,
+                                const Tensor& mask, int64_t size,
+                                Tensor* grad_x) {
+  auto M = EigenVector<uint8_t>::Flatten(mask);
+  auto dX = EigenVector<T>::Flatten(*grad_x);
+  auto dY = EigenVector<T>::Flatten(grad_y);
+
+  auto& place = *dev_ctx.eigen_device();
+  if (dropout_implementation == "upscale_in_train") {
+    if (dropout_prob == 1.0f) {
+      dX.device(place) = static_cast<T>(0) * dY;
+    } else {
+      int vec_size = platform::GetVectorizedSize<T>(grad_y.data<T>());
+      if (vec_size == 4 && size % 4 == 0) {
+        auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+        auto stream = dev_ctx.stream();
+        platform::GpuLaunchConfig config =
+            platform::GetGpuLaunchConfig1D(dev_ctx, size);
+        DropoutGradCUDAKernel<
+            T, uint8_t,
+            4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+            grad_y.data<T>(), mask.data<uint8_t>(), factor, size,
+            grad_x->data<T>());
+      } else {
+        dX.device(place) =
+            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+      }
+    }
+  } else {
+    dX.device(place) = dY * M.cast<T>();
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu
index 958f037a04..447184b948 100644
--- a/paddle/fluid/operators/dropout_op.cu
+++ b/paddle/fluid/operators/dropout_op.cu
@@ -12,113 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef PADDLE_WITH_CUDA
-#include <cuda.h>
-#include <curand_kernel.h>
-#include "paddle/fluid/platform/dynload/curand.h"
-#endif
-#ifdef PADDLE_WITH_HIP
-#include <hip/hip_runtime.h>
-#include <hiprand_kernel.h>
-#include "paddle/fluid/platform/dynload/hiprand.h"
-#endif
-#include <thrust/device_ptr.h>
-#include <thrust/iterator/counting_iterator.h>
-#include <thrust/random.h>
-#include <thrust/transform.h>
-#include <algorithm>
 #include <string>
-#include "paddle/fluid/memory/memcpy.h"
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/dropout_impl.cu.h"
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename T, typename MaskType>
-__global__ void RandomGenerator(const size_t n, uint64_t seed,
-                                const float dropout_prob, const T* src,
-                                MaskType* mask, T* dst,
-                                bool is_upscale_in_train, uint64_t increment) {
-  int idx = blockDim.x * blockIdx.x + threadIdx.x;
-#ifdef PADDLE_WITH_HIP
-  hiprandStatePhilox4_32_10_t state;
-  hiprand_init(seed, idx, increment, &state);
-#else
-  curandStatePhilox4_32_10_t state;
-  curand_init(seed, idx, increment, &state);
-#endif
-
-  MaskType mask_val;
-  T dst_val;
-  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
-  for (; idx < n; idx += blockDim.x * gridDim.x) {
-    T src_val = src[idx];
-#ifdef PADDLE_WITH_HIP
-    if (hiprand_uniform(&state) < dropout_prob) {
-#else
-    if (curand_uniform(&state) < dropout_prob) {
-#endif
-      mask_val = 0;
-      dst_val = 0;
-    } else {
-      mask_val = 1;
-      dst_val = is_upscale_in_train ? src_val * factor : src_val;
-    }
-    mask[idx] = mask_val;
-    dst[idx] = dst_val;
-  }
-}
-
-template <typename T, typename MaskType, int VecSize>
-__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
-                                          const float dropout_prob,
-                                          const T* src, MaskType* mask, T* dst,
-                                          bool is_upscale_in_train,
-                                          uint64_t increment) {
-  using LoadT = platform::AlignedVector<T, VecSize>;
-  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
-
-#ifdef PADDLE_WITH_HIP
-  int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
-  hiprandStatePhilox4_32_10_t state;
-  hiprand_init(seed, idx, increment, &state);
-#else
-  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
-  curandStatePhilox4_32_10_t state;
-  curand_init(seed, idx, increment, &state);
-#endif
-
-  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
-  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
-    LoadT src_val;
-    platform::Load<T, VecSize>(&src[i], &src_val);
-
-#ifdef PADDLE_WITH_HIP
-    float4 rand = hiprand_uniform4(&state);
-#else
-    float4 rand = curand_uniform4(&state);
-#endif
-
-    LoadT dst_val;
-    MaskLoadT mask_val;
-
-#pragma unroll
-    for (int j = 0; j < VecSize; j++) {
-      if ((&rand.x)[j] < dropout_prob) {
-        dst_val[j] = 0;
-        mask_val[j] = 0;
-      } else {
-        dst_val[j] = is_upscale_in_train ? src_val[j] * factor : src_val[j];
-        mask_val[j] = 1;
-      }
-    }
-
-    platform::Store<T, VecSize>(dst_val, &dst[i]);
-    platform::Store<MaskType, VecSize>(mask_val, &mask[i]);
-  }
-}
-
 // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
 // Use std::random and thrust::random(thrust is a std library in CUDA) to
 // implement uniform random.
@@ -137,109 +40,41 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("dropout_implementation");
     bool upscale_in_train = (dropout_implementation == "upscale_in_train");
 
-    auto& place = *context.template device_context<Place>().eigen_device();
-    if (!context.Attr<bool>("is_test")) {
-      int64_t x_numel = x->numel();
-      auto stream = context.cuda_device_context().stream();
-
-      auto* mask = context.Output<Tensor>("Mask");
-      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
-      size_t size = framework::product(mask->dims());
-      auto* x_data = x->data<T>();
-      auto* y_data = y->mutable_data<T>(context.GetPlace());
-      if (dropout_prob == 1.0f) {
-#ifdef PADDLE_WITH_HIP
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
-#else
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
-        PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
-            mask_data, 0, x_numel * sizeof(*mask_data), stream));
-#endif
-        return;
-      }
+    bool is_test = context.Attr<bool>("is_test");
 
-      const auto& dev_ctx = context.cuda_device_context();
-      platform::GpuLaunchConfig config =
-          platform::GetGpuLaunchConfig1D(dev_ctx, size);
+    auto& dev_ctx = context.cuda_device_context();
+    auto* mask = context.Output<Tensor>("Mask");
+    mask->mutable_data<uint8_t>(context.GetPlace());
 
-      // increment is used to set the args(offset) of curand_init, which defines
-      // offset in subsequence.
-      // The detail:
-      // https://docs.nvidia.com/cuda/curand/device-api-overview.html
-      // Increment should be at least the number of curand() random numbers used
-      // in each thread to avoid the random number generated this time being the
-      // same as the previous calls.
-      uint64_t seed_data;
-      uint64_t increment;
-      int vec_size = platform::GetVectorizedSize<T>(x_data);
-      auto offset = ((x_numel - 1) / (config.block_per_grid.x *
-                                      config.thread_per_block.x * vec_size) +
-                     1) *
-                    vec_size;
-      int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
-                          .GetDeviceId();
-      auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
+    bool is_fix_seed = context.Attr<bool>("fix_seed");
+    int seed_val = context.Attr<int>("seed");
+    DropoutFwGPUKernelDriver<T>(dev_ctx, is_test, dropout_implementation,
+                                dropout_prob, upscale_in_train, is_fix_seed,
+                                seed_val, *x, seed, mask, y);
+  }
+};
 
-      if (seed && platform::is_gpu_place(seed->place())) {
-        framework::Tensor seed_cpu_tensor;
-        TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
-        seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
-        increment = offset;
-      } else if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
-        auto seed_offset = gen_cuda->IncrementOffset(offset);
-        seed_data = seed_offset.first;
-        increment = seed_offset.second;
-      } else {
-        if (seed) {
-          seed_data = *(seed->data<int>());
-        } else {
-          std::random_device rnd;
-          seed_data = context.Attr<bool>("fix_seed") ? context.Attr<int>("seed")
-                                                     : rnd();
-        }
-        increment = offset;
-      }
+template <typename DeviceContext, typename T>
+class GPUDropoutGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
+                      platform::errors::PreconditionNotMet(
+                          "GradOp is only callable when is_test is false"));
+
+    auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* mask = context.Input<Tensor>("Mask");
+    grad_x->mutable_data<T>(context.GetPlace());
+    auto size = grad_x->numel();
+    auto& dropout_implementation =
+        context.Attr<std::string>("dropout_implementation");
+    float dropout_prob = context.Attr<float>("dropout_prob");
 
-#ifdef __HIPCC__
-      if (vec_size == 4 && size % 4 == 0) {
-        hipLaunchKernelGGL(
-            HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
-            config.block_per_grid, config.thread_per_block, 0, stream, size,
-            seed_data, dropout_prob, x_data, mask_data, y_data,
-            upscale_in_train, increment);
-      } else {
-        hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
-                           config.block_per_grid, config.thread_per_block, 0,
-                           stream, size, seed_data, dropout_prob, x_data,
-                           mask_data, y_data, upscale_in_train, increment);
-      }
-#else
-      if (vec_size == 4 && size % 4 == 0) {
-        VectorizedRandomGenerator<
-            T, uint8_t,
-            4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
-            size, seed_data, dropout_prob, x_data, mask_data, y_data,
-            upscale_in_train, increment);
-      } else {
-        RandomGenerator<T, uint8_t><<<config.block_per_grid,
-                                      config.thread_per_block, 0, stream>>>(
-            size, seed_data, dropout_prob, x_data, mask_data, y_data,
-            upscale_in_train, increment);
-      }
-#endif
-    } else {
-      auto X = EigenMatrix<T>::Reshape(*x, 1);
-      auto Y = EigenMatrix<T>::Reshape(*y, 1);
-      if (upscale_in_train) {
-        Y.device(place) = X;
-      } else {
-        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
-      }
-    }
+    auto& dev_ctx =
+        context.template device_context<DeviceContext>();
+    DropoutGradGPUKernelDriver<T>(dev_ctx, dropout_implementation, dropout_prob,
+                                  *grad_y, *mask, size, grad_x);
   }
 };
 
@@ -253,6 +88,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
     ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
-    dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
-    ops::DropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);
+    dropout_grad, ops::GPUDropoutGradKernel<plat::CUDADeviceContext, float>,
+    ops::GPUDropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::GPUDropoutGradKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h
index 96e6725212..831255bc1d 100644
--- a/paddle/fluid/operators/dropout_op.h
+++ b/paddle/fluid/operators/dropout_op.h
@@ -21,40 +21,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/aligned_vector.h"
-#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
 
-#if defined(__NVCC__) || defined(__HIPCC__)
-template <typename T, typename MaskType, int VecSize>
-__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
-                                      const T factor, const int64_t size,
-                                      T* dx) {
-  using LoadT = platform::AlignedVector<T, VecSize>;
-  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
-
-  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
-    LoadT dout_val;
-    platform::Load<T, VecSize>(&dout[i], &dout_val);
-
-    MaskLoadT mask_val;
-    platform::Load<MaskType, VecSize>(&mask[i], &mask_val);
-
-    LoadT dx_val;
-
-#pragma unroll
-    for (int j = 0; j < VecSize; j++) {
-      dx_val[j] = dout_val[j] * static_cast<T>(mask_val[j]) * factor;
-    }
-
-    platform::Store<T, VecSize>(dx_val, &dx[i]);
-  }
-}
-#endif
-
 using Tensor = framework::Tensor;
 
 template <typename DeviceContext, typename T>
@@ -137,7 +107,6 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
     }
   }
 };
-
 template <typename DeviceContext, typename T>
 class DropoutGradKernel : public framework::OpKernel<T> {
  public:
@@ -146,7 +115,6 @@ class DropoutGradKernel : public framework::OpKernel<T> {
     auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* mask = context.Input<Tensor>("Mask");
     grad_x->mutable_data<T>(context.GetPlace());
-    auto size = grad_x->numel();
 
     auto dX = EigenVector<T>::Flatten(*grad_x);
     auto dY = EigenVector<T>::Flatten(*grad_y);
@@ -169,23 +137,8 @@ class DropoutGradKernel : public framework::OpKernel<T> {
       if (dropout_prob == 1.0f) {
         dX.device(place) = static_cast<T>(0) * dY;
       } else {
-        int vec_size = platform::GetVectorizedSize<T>(grad_y->data<T>());
-        if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
-            size % 4 == 0) {
-#if defined(__NVCC__) || defined(__HIPCC__)
-          auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
-          auto stream = context.cuda_device_context().stream();
-          platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
-              context.cuda_device_context(), size);
-          DropoutGradCUDAKernel<T, uint8_t, 4><<<
-              config.block_per_grid, config.thread_per_block, 0, stream>>>(
-              grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
-              grad_x->data<T>());
-#endif
-        } else {
-          dX.device(place) =
-              dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
-        }
+        dX.device(place) =
+            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
       }
     } else {
       dX.device(place) = dY * M.cast<T>();
-- 
GitLab
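The point of the refactor is that other CUDA operators can now call the extracted drivers instead of duplicating the dropout kernels. The sketch below is illustrative only and not part of the patch: FusedDropoutForwardExample and the chosen argument values (dropout_prob = 0.5, upscale_in_train mode, no external seed tensor) are hypothetical, while DropoutFwGPUKernelDriver and its signature come from dropout_impl.cu.h above.

// Illustrative sketch only: a hypothetical caller reusing the extracted
// forward driver, assuming the same headers and build setup as dropout_op.cu.
#include "paddle/fluid/operators/dropout_impl.cu.h"

namespace paddle {
namespace operators {

template <typename T>
void FusedDropoutForwardExample(const platform::CUDADeviceContext& dev_ctx,
                                const framework::Tensor& x,
                                framework::Tensor* mask,
                                framework::Tensor* y) {
  // The driver expects the caller to have allocated mask/y storage, just as
  // GPUDropoutKernel does before delegating to it.
  y->Resize(x.dims());
  mask->Resize(x.dims());
  y->mutable_data<T>(dev_ctx.GetPlace());
  mask->mutable_data<uint8_t>(dev_ctx.GetPlace());
  // Hypothetical arguments: training mode, p = 0.5, upscale_in_train,
  // no external seed tensor, non-fixed seed.
  DropoutFwGPUKernelDriver<T>(dev_ctx, /*is_test=*/false, "upscale_in_train",
                              /*dropout_prob=*/0.5f, /*upscale_in_train=*/true,
                              /*is_fix_seed=*/false, /*seed_val=*/0, x,
                              /*seed=*/nullptr, mask, y);
}

}  // namespace operators
}  // namespace paddle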