未验证 提交 6702040e 编写于 作者: Z Zhang Ting 提交者: GitHub

improve dropout (#29465)

* improve drop out

* add VectorizedRandomGeneratorWithGenerator

* fix bug

* modify according to comments
上级 30d9589a
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <algorithm>
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/dropout_op.h"
......@@ -26,60 +27,35 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, const int seed,
const float dropout_prob, const T* src,
MaskType* mask_data, T* dst,
bool is_upscale_in_train) {
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
// aligned vector generates vectorized load/store on CUDA
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
};
MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
if (step_size == 0) {
curand_init(seed, idx, idx, &state);
step_size = blockDim.x * gridDim.x;
} else {
curand_init(seed, idx, step_size, &state);
}
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
} else {
mask = 1;
if (is_upscale_in_train) {
dest = s / static_cast<T>(1.0f - dropout_prob);
} else {
dest = s;
}
}
mask_data[idx] = mask;
dst[idx] = dest;
template <typename T>
inline int VectorizedSize(const T* pointer) {
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value; // NOLINT
if (address % vec4 == 0) {
return 4;
}
return 1;
}
template <typename T, typename MaskType>
__global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
const float dropout_prob, const T* src,
MaskType* mask_data, T* dst,
bool is_upscale_in_train) {
__global__ void RandomGenerator(const size_t n, uint64_t seed,
const float dropout_prob, const T* src,
MaskType* mask_data, T* dst,
bool is_upscale_in_train, uint64_t increment) {
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
curand_init(seed, idx, increment, &state);
MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
if (step_size == 0) {
curand_init(seed[0], idx, idx, &state);
step_size = blockDim.x * gridDim.x;
} else {
curand_init(seed[0], idx, step_size, &state);
}
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
......@@ -96,39 +72,49 @@ __global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
}
}
template <typename T, typename MaskType>
__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
const float dropout_prob,
const T* src, MaskType* mask_data,
T* dst, bool is_upscale_in_train,
uint64_t increment) {
template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
const float dropout_prob,
const T* src, MaskType* mask_data,
T* dst, bool is_upscale_in_train,
uint64_t increment) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
curand_init(seed, idx, increment, &state);
MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
if (step_size == 0) {
curand_init(seed, idx, increment, &state);
step_size = blockDim.x * gridDim.x;
} else {
curand_init(seed, idx, increment, &state);
}
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
} else {
mask = 1;
if (is_upscale_in_train) {
dest = s / static_cast<T>(1.0f - dropout_prob);
using LoadT = AlignedVector<T, VecSize>;
using MaskLoadT = AlignedVector<MaskType, VecSize>;
T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
T src_vec[VecSize];
LoadT* value = reinterpret_cast<LoadT*>(&src_vec);
*value = *reinterpret_cast<const LoadT*>(&src[i]);
float4 rand = curand_uniform4(&state);
T dest_vec[VecSize];
MaskType mask_vec[VecSize];
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
if ((&rand.x)[ii] < dropout_prob) {
dest_vec[ii] = 0;
mask_vec[ii] = 0;
} else {
dest = s;
if (is_upscale_in_train) {
dest_vec[ii] = src_vec[ii] * factor;
} else {
dest_vec[ii] = src_vec[ii];
}
mask_vec[ii] = 1;
}
}
mask_data[idx] = mask;
dst[idx] = dest;
*(reinterpret_cast<LoadT*>(&dst[i])) =
*reinterpret_cast<LoadT*>(&dest_vec[0]);
*(reinterpret_cast<MaskLoadT*>(&mask_data[i])) =
*reinterpret_cast<MaskLoadT*>(&mask_vec[0]);
}
}
......@@ -170,36 +156,57 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int threads = 512;
int grid = (x_numel + threads - 1) / threads;
const auto& dev_ctx = context.cuda_device_context();
int blocks_per_sm =
dev_ctx.GetMaxPhysicalThreadCount() / dev_ctx.GetSMCount() / threads;
grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);
// increment is used to set the args(offset) of curand_init, which defines
// offset in subsequence.
// The detail:
// https://docs.nvidia.com/cuda/curand/device-api-overview.html
// Increment should be at least the number of curand() random numbers used
// in each thread to avoid the random number generated this time being the
// same as the previous calls.
uint64_t seed_data;
uint64_t increment;
int vec_size = VectorizedSize<T>(x_data);
auto offset =
((x_numel - 1) / (threads * grid * vec_size) + 1) * vec_size;
int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
.GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (seed && platform::is_gpu_place(seed->place())) {
auto seed_gpu_data = seed->data<int>();
RandomGeneratorWithSeed<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_gpu_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train);
return;
}
int seed_data;
std::random_device rnd;
if (seed) {
seed_data = *(seed->data<int>());
framework::Tensor seed_cpu_tensor;
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
increment = offset;
} else if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
seed_data = seed_offset.first;
increment = seed_offset.second;
} else {
seed_data =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
if (seed) {
seed_data = *(seed->data<int>());
} else {
std::random_device rnd;
seed_data = context.Attr<bool>("fix_seed") ? context.Attr<int>("seed")
: rnd();
}
increment = offset;
}
int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
.GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
auto seed_offset = gen_cuda->IncrementOffset(1);
RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, seed_offset.second);
return;
if (vec_size == 4) {
VectorizedRandomGenerator<T, uint8_t, 4><<<grid, threads, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment);
} else {
RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, increment);
}
RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train);
} else {
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册