Unverified commit 101c9bb0 authored by limingshu, committed by GitHub

Optimization for DropoutNd on Host side (#51934)

* first commit

* fix bugs

* remove useless sync
Parent f8a8dd5e
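In short, the hunks below make three host-side changes to the DropoutNd path: the MaskFunctor is constructed once on the host and passed to DropOutNdForwardKernel as an argument instead of being built inside the kernel; GetSeedDataAndIncrement gains a use_copy flag and a boolean return value, which the driver uses to hand the kernel a device pointer (seed_ptr) so the seed can be picked up on the GPU; and the per-iteration __syncthreads() calls in the vectorized generators are dropped. A cleanup of CppTypeToDataType call sites in the fused linear-param-grad-add kernel rides along.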
@@ -41,7 +41,7 @@ namespace funcs {
 template <typename T>
 struct DstFunctor {
   using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-  MT factor;
   HOSTDEVICE inline DstFunctor(const float retain_prob,
                                const bool is_upscale_in_train,
                                const int64_t num)
@@ -67,17 +67,12 @@ struct DstFunctor {
   const float retain_prob_;
   const bool is_upscale_in_train_;
   const int64_t num_;
+  MT factor;
 };

 template <typename T>
 struct MaskFunctor {
-  const float retain_prob_;
-  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-  MT factor;
-  HOSTDEVICE inline MaskFunctor(const float retain_prob)
-      : retain_prob_(retain_prob) {
-    factor = static_cast<MT>(1.0f / retain_prob_);
-  }
+  explicit MaskFunctor(const float retain_prob) : retain_prob_(retain_prob) {}

   HOSTDEVICE inline void operator()(T* dst, const float* rand, int num) const {
     static constexpr int kCount =
@@ -88,14 +83,14 @@ struct MaskFunctor {
       dst[i] = rand[i] < retain_prob_ ? static_cast<T>(1) : static_cast<T>(0);
     }
   }
+
+ private:
+  float retain_prob_;
 };

 template <typename T>
 struct DstMaskFunctor {
-  const float retain_prob_;
-  const bool is_upscale_in_train_;
   using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-  MT factor;
   HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                    const bool is_upscale_in_train)
       : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
@@ -122,6 +117,11 @@ struct DstMaskFunctor {
       }
     }
   }
+
+ private:
+  MT factor;
+  float retain_prob_;
+  bool is_upscale_in_train_;
 };

 template <typename T>
@@ -172,9 +172,6 @@ __global__ void VectorizedRandomGenerator(const size_t n,
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<uint8_t, kCount, 1, false>(
         mask + fix, &mask_result[0], deal_size);
-    if (fix > idx * kCount + 1) {
-      __syncthreads();
-    }
   }
   int remainder = n - fix;
   if (remainder > 0) {
@@ -190,7 +187,6 @@ __global__ void VectorizedRandomGenerator(const size_t n,
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<uint8_t, kCount, 1, true>(
         mask + fix, &mask_result[0], remainder);
-    __syncthreads();
   }
 }
@@ -204,11 +200,17 @@ __global__ void DropOutNdForwardKernel(
     uint64_t increment,
     size_t main_offset,
     DstFunctor<T> dst_functor,
+    MaskFunctor<T> mask_functor,
     T* y,
     int64_t N,
-    kps::details::BroadcastConfig broadcast_config) {
+    kps::details::BroadcastConfig broadcast_config,
+    const uint64_t* seed_ptr) {
   // Vectorized Generate Mask
   // kCount is 4 for curand_uniform4 is used
+  if (seed_ptr) {
+    seed = seed_ptr[0];
+  }
+
   constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
   size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
   size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
@@ -229,8 +231,6 @@ __global__ void DropOutNdForwardKernel(
   int deal_size = BLOCK_NUM_X * kCount;
   size_t fix = idx * kCount;
-  auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
   for (; fix < main_offset; fix += stride) {
     kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
     kps::ElementwiseRandom<SType, float, kCount, Rand>(
@@ -244,9 +244,6 @@ __global__ void DropOutNdForwardKernel(
         &mask_result[0], &dst_mask[0], Cast());
     kps::WriteData<uint8_t, kCount, 1, false>(
         mask + fix, &mask_result[0], deal_size);
-    if (fix > idx * kCount + 1) {
-      __syncthreads();
-    }
   }
   int remainder = n - fix;
   if (remainder > 0) {
@@ -261,7 +258,6 @@ __global__ void DropOutNdForwardKernel(
         &mask_result[0], &dst_mask[0], Cast());
     kps::WriteData<uint8_t, kCount, 1, true>(
         mask + fix, &mask_result[0], remainder);
-    __syncthreads();
   }
   // Broadcast mask data and do elementwise operaiton with DstFunctor
   CUDA_KERNEL_LOOP(i, N) {
@@ -347,8 +343,6 @@ void DropoutFwGPUKernelDriver(
     auto offset =
         ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
-    GetSeedDataAndIncrement(
-        dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
     size_t main_offset =
         size / (block_size * kVecSize) * (block_size * kVecSize);
@@ -356,15 +350,25 @@ void DropoutFwGPUKernelDriver(
       auto dst_functor =
           DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);
-      auto input_x_dims = x.dims();
-      auto mask_dims = mask->dims();
-      std::vector<int64_t> out_dims = phi::vectorize<int64_t>(input_x_dims);
-      std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask_dims);
-      reverse(out_dims.begin(), out_dims.end());
-      reverse(in_dims.begin(), in_dims.end());
+      std::vector<int64_t> out_dims = phi::vectorize<int64_t>(x.dims());
+      std::vector<int64_t> in_dims = phi::vectorize<int64_t>(mask->dims());
+      std::reverse(out_dims.begin(), out_dims.end());
+      std::reverse(in_dims.begin(), in_dims.end());
       kps::details::BroadcastConfig broadcast_config(
           out_dims, in_dims, x.dims().size());
+      auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
+      bool copy_in_kernel = GetSeedDataAndIncrement(dev_ctx,
+                                                    seed,
+                                                    is_fix_seed,
+                                                    seed_val,
+                                                    offset,
+                                                    &seed_data,
+                                                    &increment,
+                                                    true);
+      const uint64_t* seed_ptr =
+          copy_in_kernel ? seed->data<uint64_t>() : nullptr;
+
       DropOutNdForwardKernel<T>
           <<<grid_size, block_size, 0, stream>>>(size,
                                                  seed_data,
@@ -374,10 +378,15 @@ void DropoutFwGPUKernelDriver(
                                                  increment,
                                                  main_offset,
                                                  dst_functor,
+                                                 mask_functor,
                                                  y_data,
                                                  y->numel(),
-                                                 broadcast_config);
+                                                 broadcast_config,
+                                                 seed_ptr);
     } else {
+      bool copy_in_kernel = GetSeedDataAndIncrement(
+          dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
+
 #define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T>
       PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!is_fix_seed,
                                          PD_DROPOUT_KERNEL_NAME,
......
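The kernel-side piece of this change is the optional seed_ptr argument: when a seed tensor already lives on the device, the kernel overrides its scalar seed argument from that pointer, so the host does not have to read the seed value itself before launch. A minimal standalone sketch of that pattern follows (illustrative only; the kernel name, the trivial fill logic, and the launch configuration are placeholders, not Paddle code):

// Sketch (assumed names): a kernel that accepts an optional device-side seed
// pointer and falls back to the scalar argument when none is given.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void FillWithSeed(uint64_t seed,
                             const uint64_t* seed_ptr,  // may be nullptr
                             uint64_t* out,
                             int n) {
  if (seed_ptr) {
    seed = seed_ptr[0];  // read the seed on the device; no host read needed
  }
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = seed + static_cast<uint64_t>(i);
  }
}

int main() {
  const int n = 8;
  uint64_t host_seed = 42;  // stand-in for a device-resident seed tensor
  uint64_t *d_seed = nullptr, *d_out = nullptr;
  cudaMalloc(&d_seed, sizeof(uint64_t));
  cudaMalloc(&d_out, n * sizeof(uint64_t));
  cudaMemcpy(d_seed, &host_seed, sizeof(uint64_t), cudaMemcpyHostToDevice);

  // Pass 0 as the fallback seed; the kernel picks up the device value instead.
  FillWithSeed<<<1, 32>>>(0, d_seed, d_out, n);

  uint64_t out[8];
  cudaMemcpy(out, d_out, n * sizeof(uint64_t), cudaMemcpyDeviceToHost);
  printf("out[0] = %llu\n", static_cast<unsigned long long>(out[0]));
  cudaFree(d_seed);
  cudaFree(d_out);
  return 0;
}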
@@ -22,27 +22,33 @@ limitations under the License. */
 namespace phi {
 namespace funcs {
-inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
+inline bool GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
                                     const phi::DenseTensor* seed,
                                     const bool is_fix_seed,
                                     const int seed_val,
                                     const int offset,
                                     uint64_t* seed_data,
-                                    uint64_t* increment) {
+                                    uint64_t* increment,
+                                    bool use_copy = true) {
   auto gen_cuda = dev_ctx.GetGenerator();
   if (seed) {
-    phi::DenseTensor seed_cpu_tensor;
-    phi::Copy(dev_ctx, *seed, phi::CPUPlace(), true, &seed_cpu_tensor);
-    *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
+    if (use_copy) {
+      phi::DenseTensor seed_cpu_tensor;
+      phi::Copy(dev_ctx, *seed, phi::CPUPlace(), true, &seed_cpu_tensor);
+      *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
+    }
     *increment = offset;
+    return true;
   } else if (!is_fix_seed) {
     auto seed_offset = gen_cuda->IncrementOffset(offset);
     *seed_data = seed_offset.first;
     *increment = seed_offset.second;
+    return false;
   } else {
     *seed_data = seed_val;
     *increment = offset;
+    return false;
   }
 }
......
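A note on the new contract, as read from the hunk above (not a documented API guarantee): the return value reports whether a seed DenseTensor was supplied, which is the only case in which a kernel can fetch the seed from device memory. When use_copy is false and a seed tensor exists, the blocking phi::Copy to CPU is skipped and *seed_data is left untouched, so the launched kernel is expected to read the real value through the tensor's device buffer, as DropOutNdForwardKernel does via seed_ptr. With no seed tensor, the generator or fixed-seed branch fills *seed_data on the host as before and returns false.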
@@ -67,18 +67,10 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
     dout_copy.Resize({M, N});
     if (kIsMultiPrecision) {
       *dbias_out = phi::Sum<T, Context>(
-          ctx,
-          dout_copy,
-          {0},
-          paddle::experimental::CppTypeToDataType<MT>::Type(),
-          false);
+          ctx, dout_copy, {0}, phi::CppTypeToDataType<MT>::Type(), false);
     } else {
       *dbias_out = phi::Sum<T, Context>(
-          ctx,
-          dout_copy,
-          {0},
-          paddle::experimental::CppTypeToDataType<T>::Type(),
-          false);
+          ctx, dout_copy, {0}, phi::CppTypeToDataType<T>::Type(), false);
     }
   }
@@ -141,12 +133,12 @@ void FusedLinearParamGradAdd(const Context &ctx,
   if (multi_precision) {
     PADDLE_ENFORCE_EQ(
         dweight_out->dtype(),
-        paddle::experimental::CppTypeToDataType<MT>::Type(),
+        phi::CppTypeToDataType<MT>::Type(),
         phi::errors::InvalidArgument("Invaid data type error."));
   } else {
     PADDLE_ENFORCE_EQ(
         dweight_out->dtype(),
-        paddle::experimental::CppTypeToDataType<T>::Type(),
+        phi::CppTypeToDataType<T>::Type(),
         phi::errors::InvalidArgument("Invaid data type error."));
   }
 } else {
......