From b4931ab1cda5282749c3e27cba33aa24c5e0ec0c Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Mon, 11 Jan 2021 15:54:39 +0800 Subject: [PATCH] [Cherry pick] improve dropout (#30260) * improve dropout (#29465) * improve drop out * add VectorizedRandomGeneratorWithGenerator * fix bug * modify according to comments * improve dropout grad (#29605) * improve grad perf * fix the bug of dropout_grad (#29813) --- paddle/fluid/operators/dropout_op.cu | 183 +++++++++++++-------------- paddle/fluid/operators/dropout_op.h | 71 ++++++++++- 2 files changed, 156 insertions(+), 98 deletions(-) diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index 49ad67bbca3..cf90b9eb52b 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/dropout_op.h" @@ -27,24 +28,18 @@ namespace paddle { namespace operators { template -__global__ void RandomGenerator(const size_t n, const int seed, +__global__ void RandomGenerator(const size_t n, uint64_t seed, const float dropout_prob, const T* src, MaskType* mask_data, T* dst, - bool is_upscale_in_train) { + bool is_upscale_in_train, uint64_t increment) { curandStatePhilox4_32_10_t state; int idx = blockDim.x * blockIdx.x + threadIdx.x; - int step_size = 0; + curand_init(seed, idx, increment, &state); MaskType mask; T dest; for (; idx < n; idx += blockDim.x * gridDim.x) { T s = src[idx]; - if (step_size == 0) { - curand_init(seed, idx, idx, &state); - step_size = blockDim.x * gridDim.x; - } else { - curand_init(seed, idx, step_size, &state); - } if (curand_uniform(&state) < dropout_prob) { mask = 0; dest = 0; @@ -61,74 +56,49 @@ __global__ void RandomGenerator(const size_t n, const int seed, } } -template -__global__ void RandomGeneratorWithSeed(const size_t n, const int* seed, - const float dropout_prob, const T* src, - MaskType* mask_data, T* dst, - bool is_upscale_in_train) { +template +__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, + const float dropout_prob, + const T* src, MaskType* mask_data, + T* dst, bool is_upscale_in_train, + uint64_t increment) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; curandStatePhilox4_32_10_t state; - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int step_size = 0; + curand_init(seed, idx, increment, &state); MaskType mask; T dest; - for (; idx < n; idx += blockDim.x * gridDim.x) { - T s = src[idx]; - if (step_size == 0) { - curand_init(seed[0], idx, idx, &state); - step_size = blockDim.x * gridDim.x; - } else { - curand_init(seed[0], idx, step_size, &state); - } - if (curand_uniform(&state) < dropout_prob) { - mask = 0; - dest = 0; - } else { - mask = 1; - if (is_upscale_in_train) { - dest = s / static_cast(1.0f - dropout_prob); + using LoadT = AlignedVector; + using MaskLoadT = AlignedVector; + T factor = static_cast(1.0f / (1.0f - dropout_prob)); + for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) { + T src_vec[VecSize]; + LoadT* value = reinterpret_cast(&src_vec); + *value = *reinterpret_cast(&src[i]); + float4 rand = curand_uniform4(&state); + + T dest_vec[VecSize]; + MaskType mask_vec[VecSize]; + +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + if ((&rand.x)[ii] < dropout_prob) { + dest_vec[ii] = 0; + mask_vec[ii] = 0; } else { - dest = s; + if (is_upscale_in_train) { + dest_vec[ii] = src_vec[ii] * factor; + } else { + dest_vec[ii] = src_vec[ii]; + } + mask_vec[ii] = 1; } } - mask_data[idx] = mask; - dst[idx] = dest; - } -} -template -__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed, - const float dropout_prob, - const T* src, MaskType* mask_data, - T* dst, bool is_upscale_in_train, - uint64_t increment) { - curandStatePhilox4_32_10_t state; - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int step_size = 0; - - MaskType mask; - T dest; - for (; idx < n; idx += blockDim.x * gridDim.x) { - T s = src[idx]; - if (step_size == 0) { - curand_init(seed, idx, increment, &state); - step_size = blockDim.x * gridDim.x; - } else { - curand_init(seed, idx, increment, &state); - } - if (curand_uniform(&state) < dropout_prob) { - mask = 0; - dest = 0; - } else { - mask = 1; - if (is_upscale_in_train) { - dest = s / static_cast(1.0f - dropout_prob); - } else { - dest = s; - } - } - mask_data[idx] = mask; - dst[idx] = dest; + *(reinterpret_cast(&dst[i])) = + *reinterpret_cast(&dest_vec[0]); + *(reinterpret_cast(&mask_data[i])) = + *reinterpret_cast(&mask_vec[0]); } } @@ -168,38 +138,61 @@ class GPUDropoutKernel : public framework::OpKernel { return; } - int threads = 512; - int grid = (x_numel + threads - 1) / threads; + const auto& dev_ctx = context.cuda_device_context(); + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(dev_ctx, size); + + // increment is used to set the args(offset) of curand_init, which defines + // offset in subsequence. + // The detail: + // https://docs.nvidia.com/cuda/curand/device-api-overview.html + // Increment should be at least the number of curand() random numbers used + // in each thread to avoid the random number generated this time being the + // same as the previous calls. + uint64_t seed_data; + uint64_t increment; + int vec_size = VectorizedSize(x_data); + auto offset = ((x_numel - 1) / (config.block_per_grid.x * + config.thread_per_block.x * vec_size) + + 1) * + vec_size; + int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()) + .GetDeviceId(); + auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); + if (seed && platform::is_gpu_place(seed->place())) { - auto seed_gpu_data = seed->data(); - RandomGeneratorWithSeed<<>>( - size, seed_gpu_data, dropout_prob, x_data, mask_data, y_data, - upscale_in_train); - return; - } - int seed_data; - std::random_device rnd; - if (seed) { - seed_data = *(seed->data()); + framework::Tensor seed_cpu_tensor; + TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor); + seed_data = static_cast(seed_cpu_tensor.data()[0]); + increment = offset; + } else if (gen_cuda->GetIsInitPy() && (!context.Attr("fix_seed"))) { + auto seed_offset = gen_cuda->IncrementOffset(offset); + seed_data = seed_offset.first; + increment = seed_offset.second; } else { - seed_data = - context.Attr("fix_seed") ? context.Attr("seed") : rnd(); + if (seed) { + seed_data = *(seed->data()); + } else { + std::random_device rnd; + seed_data = context.Attr("fix_seed") ? context.Attr("seed") + : rnd(); + } + increment = offset; } - int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()) - .GetDeviceId(); - auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); - if (gen_cuda->GetIsInitPy() && (!context.Attr("fix_seed"))) { - auto seed_offset = gen_cuda->IncrementOffset(1); - RandomGeneratorWithGenerator<<>>( - size, seed_offset.first, dropout_prob, x_data, mask_data, y_data, - upscale_in_train, seed_offset.second); - return; + if (vec_size == 4 && size % 4 == 0) { + VectorizedRandomGenerator< + T, uint8_t, + 4><<>>( + size, seed_data, dropout_prob, x_data, mask_data, y_data, + upscale_in_train, increment); + } else { + RandomGenerator<<>>( + size, seed_data, dropout_prob, x_data, mask_data, y_data, + upscale_in_train, increment); } - RandomGenerator<<>>( - size, seed_data, dropout_prob, x_data, mask_data, y_data, - upscale_in_train); } else { auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 161c4282ec2..d77193e4851 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -17,13 +17,62 @@ limitations under the License. */ #include #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace operators { +// aligned vector generates vectorized load/store on CUDA +template +struct alignas(sizeof(T) * Size) AlignedVector { + T val[Size]; +}; + +template +inline int VectorizedSize(const T* pointer) { + uint64_t address = reinterpret_cast(pointer); + constexpr int vec4 = std::alignment_of>::value; // NOLINT + if (address % vec4 == 0) { + return 4; + } + return 1; +} + +#ifdef __NVCC__ +template +__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask, + const T factor, const int64_t size, + T* dx) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + + using LoadT = AlignedVector; + using MaskLoadT = AlignedVector; + + for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { + T dout_vec[VecSize]; + LoadT* dout_value = reinterpret_cast(&dout_vec); + *dout_value = *reinterpret_cast(&dout[i]); + + MaskType mask_vec[VecSize]; + MaskLoadT* mask_value = reinterpret_cast(&mask_vec); + *mask_value = *reinterpret_cast(&mask[i]); + + T dx_vec[VecSize]; + +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + dx_vec[ii] = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; + } + + *(reinterpret_cast(&dx[i])) = *reinterpret_cast(&dx_vec[0]); + } +} +#endif + using Tensor = framework::Tensor; template @@ -119,6 +168,7 @@ class DropoutGradKernel : public framework::OpKernel { auto* grad_y = context.Input(framework::GradVarName("Out")); auto* mask = context.Input("Mask"); grad_x->mutable_data(context.GetPlace()); + auto size = grad_x->numel(); auto M = EigenVector::Flatten(*mask); auto dX = EigenVector::Flatten(*grad_x); @@ -126,7 +176,6 @@ class DropoutGradKernel : public framework::OpKernel { auto& place = *context.template device_context().eigen_device(); - auto& dropout_implementation = context.Attr("dropout_implementation"); if (dropout_implementation == "upscale_in_train") { @@ -134,8 +183,24 @@ class DropoutGradKernel : public framework::OpKernel { if (dropout_prob == 1.0f) { dX.device(place) = static_cast(0) * dY; } else { - dX.device(place) = - dY * M.cast() / static_cast(1.0f - dropout_prob); + int vec_size = VectorizedSize(grad_y->data()); + if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 && + size % 4 == 0) { +#ifdef __NVCC__ + auto factor = static_cast(1.0f / (1.0f - dropout_prob)); + auto stream = context.cuda_device_context().stream(); + platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D( + context.cuda_device_context(), size); + DropoutGradCUDAKernel< + T, uint8_t, + 4><<>>( + grad_y->data(), mask->data(), factor, size, + grad_x->data()); +#endif + } else { + dX.device(place) = + dY * M.cast() / static_cast(1.0f - dropout_prob); + } } } else { dX.device(place) = dY * M.cast(); -- GitLab