Unverified  Commit 1e9127f6 authored by Zhang Ting, committed by GitHub

improve dropout grad (#29605)

* improve grad perf
Parent eab44e1f
@@ -27,22 +27,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-// aligned vector generates vectorized load/store on CUDA
-template <typename T, int Size>
-struct alignas(sizeof(T) * Size) AlignedVector {
-  T val[Size];
-};
-
-template <typename T>
-inline int VectorizedSize(const T* pointer) {
-  uint64_t address = reinterpret_cast<uint64_t>(pointer);
-  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
-  if (address % vec4 == 0) {
-    return 4;
-  }
-  return 1;
-}
-
 template <typename T, typename MaskType>
 __global__ void RandomGenerator(const size_t n, uint64_t seed,
                                 const float dropout_prob, const T* src,
@@ -154,12 +138,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       return;
     }
 
-    int threads = 512;
-    int grid = (x_numel + threads - 1) / threads;
     const auto& dev_ctx = context.cuda_device_context();
-    int blocks_per_sm =
-        dev_ctx.GetMaxPhysicalThreadCount() / dev_ctx.GetSMCount() / threads;
-    grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, size);
 
     // increment is used to set the args(offset) of curand_init, which defines
     // offset in subsequence.
@@ -171,8 +152,10 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     uint64_t seed_data;
     uint64_t increment;
     int vec_size = VectorizedSize<T>(x_data);
-    auto offset =
-        ((x_numel - 1) / (threads * grid * vec_size) + 1) * vec_size;
+    auto offset = ((x_numel - 1) / (config.block_per_grid.x *
+                                    config.thread_per_block.x * vec_size) +
+                   1) *
+                  vec_size;
     int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
                         .GetDeviceId();
     auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
@@ -197,12 +180,15 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       increment = offset;
     }
 
-    if (vec_size == 4) {
-      VectorizedRandomGenerator<T, uint8_t, 4><<<grid, threads, 0, stream>>>(
+    if (vec_size == 4 && size % 4 == 0) {
+      VectorizedRandomGenerator<
+          T, uint8_t,
+          4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
           size, seed_data, dropout_prob, x_data, mask_data, y_data,
           upscale_in_train, increment);
     } else {
-      RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
+      RandomGenerator<T, uint8_t><<<config.block_per_grid,
+                                    config.thread_per_block, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
     }
...
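Note on the launch-config change in this file: the deleted code hard-coded 512 threads per block and capped the grid at `SM count × blocks_per_sm`, while `GetGpuLaunchConfig1D` now encapsulates that choice. The `offset`/`increment` passed to `curand_init` must cover the largest number of random values any thread can consume in its grid-stride loop. A standalone sketch of that arithmetic (the helper name and free-function form are illustrative, not Paddle's API):

```cpp
#include <cstdint>

// Each thread draws vec_size random values per grid-stride iteration, so the
// Philox subsequence offset must advance by iterations * vec_size per launch.
inline uint64_t CurandOffsetSketch(int64_t x_numel, uint64_t grid,
                                   uint64_t threads, int vec_size) {
  uint64_t iterations =
      static_cast<uint64_t>(x_numel - 1) / (grid * threads * vec_size) + 1;
  return iterations * vec_size;  // mirrors ((x_numel-1)/(grid*threads*vec)+1)*vec
}
```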
@@ -17,13 +17,59 @@ limitations under the License. */
 #include <random>
 #include <string>
+#include <algorithm>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
 
+// aligned vector generates vectorized load/store on CUDA
+template <typename T, int Size>
+struct alignas(sizeof(T) * Size) AlignedVector {
+  T val[Size];
+};
+
+template <typename T>
+inline int VectorizedSize(const T* pointer) {
+  uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
+  if (address % vec4 == 0) {
+    return 4;
+  }
+  return 1;
+}
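Aside: `vec4` above equals `alignof(AlignedVector<T, 4>)`, i.e. `sizeof(T) * 4`. A hypothetical check for `float` (the demo function and buffer are illustrative; `cudaMalloc` guarantees at least 256-byte alignment):

```cpp
#include <cuda_runtime.h>

void AlignmentDemo() {
  float* p = nullptr;
  cudaMalloc(&p, 1024 * sizeof(float));
  int v_base = VectorizedSize(p);     // 4: address is 16-byte aligned
  int v_off = VectorizedSize(p + 1);  // 1: +4 bytes breaks 16-byte alignment
  (void)v_base;
  (void)v_off;
  cudaFree(p);
}
```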
+
+#ifdef __NVCC__
+template <typename T, typename MaskType, int VecSize>
+__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
+                                      const T factor, const int64_t size,
+                                      T* dx) {
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+
+  using LoadT = AlignedVector<T, VecSize>;
+  using MaskLoadT = AlignedVector<MaskType, VecSize>;
+
+  for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
+    // vectorized load of the upstream gradient
+    T dout_vec[VecSize];
+    LoadT* dout_value = reinterpret_cast<LoadT*>(&dout_vec);
+    *dout_value = *reinterpret_cast<const LoadT*>(&dout[i]);
+
+    // vectorized load of the mask
+    MaskType mask_vec[VecSize];
+    MaskLoadT* mask_value = reinterpret_cast<MaskLoadT*>(&mask_vec);
+    *mask_value = *reinterpret_cast<const MaskLoadT*>(&mask[i]);
+
+    T dx_vec[VecSize];
+
+#pragma unroll
+    for (int ii = 0; ii < VecSize; ii++) {
+      dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
+    }
+
+    *(reinterpret_cast<LoadT*>(&dx[i])) = *reinterpret_cast<LoadT*>(&dx_vec[0]);
+  }
+}
+#endif
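For intuition, the vectorized kernel added above computes element-wise `dx[i] = dout[i] * mask[i] * factor`; a scalar grid-stride version (a reference sketch for checking equivalence, not part of this commit) would be:

```cpp
template <typename T, typename MaskType>
__global__ void DropoutGradScalarRef(const T* dout, const MaskType* mask,
                                     const T factor, const int64_t size,
                                     T* dx) {
  // One element per thread per grid-stride step; no alignment requirement.
  for (int64_t i = blockDim.x * blockIdx.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    dx[i] = dout[i] * static_cast<T>(mask[i]) * factor;
  }
}
```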
 using Tensor = framework::Tensor;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
@@ -119,6 +165,7 @@ class DropoutGradKernel : public framework::OpKernel<T> {
     auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* mask = context.Input<Tensor>("Mask");
     grad_x->mutable_data<T>(context.GetPlace());
+    auto size = grad_x->numel();
 
     auto M = EigenVector<uint8_t>::Flatten(*mask);
     auto dX = EigenVector<T>::Flatten(*grad_x);
@@ -126,7 +173,6 @@ class DropoutGradKernel : public framework::OpKernel<T> {
 
     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
-
     auto& dropout_implementation =
         context.Attr<std::string>("dropout_implementation");
     if (dropout_implementation == "upscale_in_train") {
@@ -134,8 +180,24 @@ class DropoutGradKernel : public framework::OpKernel<T> {
       if (dropout_prob == 1.0f) {
         dX.device(place) = static_cast<T>(0) * dY;
       } else {
-        dX.device(place) =
-            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+        int vec_size = VectorizedSize<T>(grad_y->data<T>());
+        if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
+            size % 4 == 0) {
+#ifdef __NVCC__
+          auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+          auto stream = context.cuda_device_context().stream();
+          platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
+              context.cuda_device_context(), size);
+          DropoutGradCUDAKernel<
+              T, uint8_t,
+              4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
+              grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
+              grad_x->data<T>());
+#endif
+        } else {
+          dX.device(place) =
+              dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+        }
       }
     } else {
       dX.device(place) = dY * M.cast<T>();
...
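A quick numeric check of the upscale_in_train gradient implemented above: with `dropout_prob = 0.25`, `factor = 1 / (1 - 0.25) = 4/3`, dropped slots (mask 0) get zero gradient and kept slots are scaled up so the expected gradient matches the no-dropout case. A host-side sketch (values are illustrative):

```cpp
#include <cstdint>

void UpscaleGradDemo() {
  float dout[4] = {1.f, 2.f, 3.f, 4.f};
  uint8_t mask[4] = {1, 0, 1, 1};
  float factor = 1.0f / (1.0f - 0.25f);  // 4/3
  float dx[4];
  for (int i = 0; i < 4; ++i) {
    dx[i] = dout[i] * static_cast<float>(mask[i]) * factor;
  }
  // dx == {1.333f, 0.f, 4.f, 5.333f}
}
```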