未验证 提交 b4931ab1 编写于 作者: Z Zhang Ting 提交者: GitHub

[Cherry pick] improve dropout (#30260)

* improve dropout (#29465)

* improve drop out

* add VectorizedRandomGeneratorWithGenerator

* fix bug

* modify according to comments

* improve dropout grad (#29605)

* improve grad perf

* fix the bug of dropout_grad (#29813)
上级 b80beb16
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <algorithm>
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/dropout_op.h"
......@@ -27,24 +28,18 @@ namespace paddle {
namespace operators {
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, const int seed,
__global__ void RandomGenerator(const size_t n, uint64_t seed,
const float dropout_prob, const T* src,
MaskType* mask_data, T* dst,
bool is_upscale_in_train) {
bool is_upscale_in_train, uint64_t increment) {
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
curand_init(seed, idx, increment, &state);
MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
if (step_size == 0) {
curand_init(seed, idx, idx, &state);
step_size = blockDim.x * gridDim.x;
} else {
curand_init(seed, idx, step_size, &state);
}
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
......@@ -61,74 +56,49 @@ __global__ void RandomGenerator(const size_t n, const int seed,
}
}
template <typename T, typename MaskType>
__global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
const float dropout_prob, const T* src,
MaskType* mask_data, T* dst,
bool is_upscale_in_train) {
template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
const float dropout_prob,
const T* src, MaskType* mask_data,
T* dst, bool is_upscale_in_train,
uint64_t increment) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
curand_init(seed, idx, increment, &state);
MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
if (step_size == 0) {
curand_init(seed[0], idx, idx, &state);
step_size = blockDim.x * gridDim.x;
} else {
curand_init(seed[0], idx, step_size, &state);
}
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
} else {
mask = 1;
if (is_upscale_in_train) {
dest = s / static_cast<T>(1.0f - dropout_prob);
using LoadT = AlignedVector<T, VecSize>;
using MaskLoadT = AlignedVector<MaskType, VecSize>;
T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
T src_vec[VecSize];
LoadT* value = reinterpret_cast<LoadT*>(&src_vec);
*value = *reinterpret_cast<const LoadT*>(&src[i]);
float4 rand = curand_uniform4(&state);
T dest_vec[VecSize];
MaskType mask_vec[VecSize];
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
if ((&rand.x)[ii] < dropout_prob) {
dest_vec[ii] = 0;
mask_vec[ii] = 0;
} else {
dest = s;
if (is_upscale_in_train) {
dest_vec[ii] = src_vec[ii] * factor;
} else {
dest_vec[ii] = src_vec[ii];
}
mask_vec[ii] = 1;
}
}
mask_data[idx] = mask;
dst[idx] = dest;
}
}
template <typename T, typename MaskType>
__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
const float dropout_prob,
const T* src, MaskType* mask_data,
T* dst, bool is_upscale_in_train,
uint64_t increment) {
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
if (step_size == 0) {
curand_init(seed, idx, increment, &state);
step_size = blockDim.x * gridDim.x;
} else {
curand_init(seed, idx, increment, &state);
}
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
} else {
mask = 1;
if (is_upscale_in_train) {
dest = s / static_cast<T>(1.0f - dropout_prob);
} else {
dest = s;
}
}
mask_data[idx] = mask;
dst[idx] = dest;
*(reinterpret_cast<LoadT*>(&dst[i])) =
*reinterpret_cast<LoadT*>(&dest_vec[0]);
*(reinterpret_cast<MaskLoadT*>(&mask_data[i])) =
*reinterpret_cast<MaskLoadT*>(&mask_vec[0]);
}
}
......@@ -168,38 +138,61 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
return;
}
int threads = 512;
int grid = (x_numel + threads - 1) / threads;
const auto& dev_ctx = context.cuda_device_context();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, size);
// increment is used to set the args(offset) of curand_init, which defines
// offset in subsequence.
// The detail:
// https://docs.nvidia.com/cuda/curand/device-api-overview.html
// Increment should be at least the number of curand() random numbers used
// in each thread to avoid the random number generated this time being the
// same as the previous calls.
uint64_t seed_data;
uint64_t increment;
int vec_size = VectorizedSize<T>(x_data);
auto offset = ((x_numel - 1) / (config.block_per_grid.x *
config.thread_per_block.x * vec_size) +
1) *
vec_size;
int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
.GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (seed && platform::is_gpu_place(seed->place())) {
auto seed_gpu_data = seed->data<int>();
RandomGeneratorWithSeed<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_gpu_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train);
return;
}
int seed_data;
std::random_device rnd;
if (seed) {
seed_data = *(seed->data<int>());
framework::Tensor seed_cpu_tensor;
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
increment = offset;
} else if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
seed_data = seed_offset.first;
increment = seed_offset.second;
} else {
seed_data =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
if (seed) {
seed_data = *(seed->data<int>());
} else {
std::random_device rnd;
seed_data = context.Attr<bool>("fix_seed") ? context.Attr<int>("seed")
: rnd();
}
increment = offset;
}
int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
.GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
auto seed_offset = gen_cuda->IncrementOffset(1);
RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, seed_offset.second);
return;
if (vec_size == 4 && size % 4 == 0) {
VectorizedRandomGenerator<
T, uint8_t,
4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment);
} else {
RandomGenerator<T, uint8_t><<<config.block_per_grid,
config.thread_per_block, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment);
}
RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train);
} else {
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
......
......@@ -17,13 +17,62 @@ limitations under the License. */
#include <random>
#include <string>
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace operators {
// aligned vector generates vectorized load/store on CUDA
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
};
template <typename T>
inline int VectorizedSize(const T* pointer) {
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value; // NOLINT
if (address % vec4 == 0) {
return 4;
}
return 1;
}
#ifdef __NVCC__
template <typename T, typename MaskType, int VecSize>
__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
const T factor, const int64_t size,
T* dx) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = AlignedVector<T, VecSize>;
using MaskLoadT = AlignedVector<MaskType, VecSize>;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
T dout_vec[VecSize];
LoadT* dout_value = reinterpret_cast<LoadT*>(&dout_vec);
*dout_value = *reinterpret_cast<const LoadT*>(&dout[i]);
MaskType mask_vec[VecSize];
MaskLoadT* mask_value = reinterpret_cast<MaskLoadT*>(&mask_vec);
*mask_value = *reinterpret_cast<const MaskLoadT*>(&mask[i]);
T dx_vec[VecSize];
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
}
*(reinterpret_cast<LoadT*>(&dx[i])) = *reinterpret_cast<LoadT*>(&dx_vec[0]);
}
}
#endif
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
......@@ -119,6 +168,7 @@ class DropoutGradKernel : public framework::OpKernel<T> {
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace());
auto size = grad_x->numel();
auto M = EigenVector<uint8_t>::Flatten(*mask);
auto dX = EigenVector<T>::Flatten(*grad_x);
......@@ -126,7 +176,6 @@ class DropoutGradKernel : public framework::OpKernel<T> {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto& dropout_implementation =
context.Attr<std::string>("dropout_implementation");
if (dropout_implementation == "upscale_in_train") {
......@@ -134,8 +183,24 @@ class DropoutGradKernel : public framework::OpKernel<T> {
if (dropout_prob == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
} else {
dX.device(place) =
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
int vec_size = VectorizedSize<T>(grad_y->data<T>());
if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
size % 4 == 0) {
#ifdef __NVCC__
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
auto stream = context.cuda_device_context().stream();
platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
context.cuda_device_context(), size);
DropoutGradCUDAKernel<
T, uint8_t,
4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
grad_x->data<T>());
#endif
} else {
dX.device(place) =
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
}
}
} else {
dX.device(place) = dY * M.cast<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册