/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#endif

#include "paddle/phi/kernels/funcs/dropout_impl_util.h"

#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"

namespace phi {
namespace funcs {

template <typename T>
struct DstFunctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  HOSTDEVICE inline DstFunctor(const float retain_prob,
                               const bool is_upscale_in_train,
                               const int64_t num)
      : retain_prob_(retain_prob),
        is_upscale_in_train_(is_upscale_in_train),
        num_(num) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline T operator()(const T src_val, const uint8_t mask) const {
    // Kept elements are passed through (rescaled by 1 / retain_prob in
    // upscale_in_train mode); dropped elements become 0.
    if (mask == static_cast<uint8_t>(1)) {
      return is_upscale_in_train_
                 ? static_cast<T>(static_cast<MT>(src_val) * factor)
                 : static_cast<T>(src_val);
    }
    return static_cast<T>(0);
  }

 private:
  const float retain_prob_;
  const bool is_upscale_in_train_;
  const int64_t num_;
  MT factor;
};
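
// Illustrative behaviour of DstFunctor (a hedged sketch; the numbers below are
// made up for illustration only):
//   DstFunctor<float> f(/*retain_prob=*/0.25f, /*is_upscale_in_train=*/true,
//                       /*num=*/1);
//   f(2.0f, 1) == 8.0f  // kept element, scaled by 1 / retain_prob = 4
//   f(2.0f, 0) == 0.0f  // dropped element
//   With is_upscale_in_train == false the kept element stays 2.0f.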

template <typename T>
struct MaskFunctor {
  explicit MaskFunctor(const float retain_prob) : retain_prob_(retain_prob) {}

  HOSTDEVICE inline void operator()(T* dst, const float* rand, int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
// Each of the kCount outputs is a keep (1) / drop (0) flag for one element.
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      dst[i] = rand[i] < retain_prob_ ? static_cast<T>(1) : static_cast<T>(0);
    }
  }

 private:
  float retain_prob_;
};

template <typename T>
struct DstMaskFunctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
  HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                   const bool is_upscale_in_train)
      : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline void operator()(T* dst,
                                    const T* src_val,
                                    const float* rand,
                                    int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
// 0 ~ kCount - 1 is dst, kCount ~ 2 * kCount - 1 is mask
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] < retain_prob_) {
        dst[i] = is_upscale_in_train_
                     ? static_cast<T>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<T>(src_val[i]);
        dst[i + kCount] = static_cast<T>(1);
      } else {
        dst[i] = static_cast<T>(0);
        dst[i + kCount] = dst[i];
      }
    }
  }

 private:
  MT factor;
  float retain_prob_;
  bool is_upscale_in_train_;
};
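
// Packed layout produced by DstMaskFunctor for one vectorized step
// (kCount == 4 when curand_uniform4 / hiprand_uniform4 is used):
//   dst[0..3]  dropout outputs for four consecutive elements
//   dst[4..7]  the matching keep (1) / drop (0) flags, later cast to uint8_t
//              and written to the mask tensor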

template <typename T>
__global__ void VectorizedRandomGenerator(const size_t n,
                                          uint64_t seed,
                                          const float dropout_prob,
                                          const T* src,
                                          uint8_t* mask,
                                          T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment,
                                          size_t main_offset) {
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_mask[kCount *
             2];  // 0 ~ kCount - 1 : dst,  kCount ~ 2 * kCount - 1: mask
  float rands[kCount];
  uint8_t mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;

  size_t fix = idx * kCount;

  auto dst_functor =
      DstMaskFunctor<T>(1.0f - dropout_prob, is_upscale_in_train);
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, false>(dst + fix, &dst_mask[0], deal_size);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<uint8_t, kCount, 1, false>(
        mask + fix, &mask_result[0], deal_size);
  }
  int remainder = n - fix;
  if (remainder > 0) {
    kps::ReadData<T, kCount, 1, true>(&dst_mask[0], src + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, true>(dst + fix, &dst_mask[0], remainder);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<uint8_t, kCount, 1, true>(
        mask + fix, &mask_result[0], remainder);
  }
}
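
// A minimal launch sketch for the fused kernel above (a hedged example; it
// assumes grid_size / block_size come from GetGpuLaunchConfig1D and
// seed / increment from GetSeedDataAndIncrement, mirroring the non-nd launch
// done through PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL in DropoutFwGPUKernelDriver):
//   VectorizedRandomGenerator<float><<<grid_size, block_size, 0, stream>>>(
//       n, seed, dropout_prob, x_data, mask_data, y_data,
//       upscale_in_train, increment, main_offset);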

template <typename T>
__global__ void DropOutNdForwardKernel(
    const size_t n,
    uint64_t seed,
    const float dropout_prob,
    const T* src,
    uint8_t* mask,
    uint64_t increment,
    size_t main_offset,
    DstFunctor<T> dst_functor,
    MaskFunctor<T> mask_functor,
    T* y,
    int64_t N,
    kps::details::BroadcastConfig broadcast_config,
    const uint64_t* seed_ptr) {
  // Vectorized mask generation.
  // kCount is 4 because curand_uniform4 is used.
  if (seed_ptr) {
    seed = seed_ptr[0];
  }

  constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_mask[kCount];  // keep (1) / drop (0) flags before the cast to uint8_t
  float rands[kCount];
  uint8_t mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;

  size_t fix = idx * kCount;
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorBinary<float, T, MaskFunctor<T>>(
        &dst_mask[0], &rands[0], mask_functor, kCount);

    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[0], Cast());
    kps::WriteData<uint8_t, kCount, 1, false>(
        mask + fix, &mask_result[0], deal_size);
  }
  int remainder = n - fix;
  if (remainder > 0) {
    kps::ReadData<T, kCount, 1, true>(&dst_mask[0], src + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorBinary<float, T, MaskFunctor<T>>(
        &dst_mask[0], &rands[0], mask_functor, kCount);
    // mask
    kps::ElementwiseUnary<T, uint8_t, kCount, 1, Cast>(
        &mask_result[0], &dst_mask[0], Cast());
    kps::WriteData<uint8_t, kCount, 1, true>(
        mask + fix, &mask_result[0], remainder);
  }
  // Broadcast mask data and do the elementwise operation with DstFunctor
  CUDA_KERNEL_LOOP(i, N) {
    uint32_t offset = 0u;
    uint32_t idx = i;
    // Use the (j < phi::DDim::kMaxRank) condition rather than
    // (j < broadcast_config.rank) so that #pragma unroll can be applied
#pragma unroll
    for (int j = 0; j < phi::DDim::kMaxRank; ++j) {
      if (j == broadcast_config.rank) break;
      auto fast_divmoder = broadcast_config.divmoders[j].Divmod(idx);
      idx = fast_divmoder.val[0];
      offset += broadcast_config.strides[j] * fast_divmoder.val[1];
    }
    y[i] = dst_functor(src[i], mask[offset]);
  }
}
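
// For dropout_nd the mask covers only a subset of the axes, so the kernel
// above first fills all n mask elements and then, for each of the N output
// elements, maps the output index back to a mask offset through
// broadcast_config before applying DstFunctor. Illustrative example (not
// taken from the file): with x of shape [2, 3] and mask of shape [1, 3],
// both rows of y reuse the same three mask values.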

template <typename T, typename MT>
void ScaleByDropoutFactor(const phi::GPUContext& dev_ctx,
                          const phi::DenseTensor& x,
                          phi::DenseTensor* y,
                          MT factor) {
  std::vector<const phi::DenseTensor*> ins = {&x};
  std::vector<phi::DenseTensor*> outs = {y};
  auto functor = phi::funcs::ScaleFunctor<T>(factor);
  phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}
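
// Hedged usage sketch (assumes `x` and `y` are allocated float tensors of the
// same shape): scale every element of x by (1 - dropout_prob) into y, which is
// the downscale-at-inference path used by the drivers below.
//   ScaleByDropoutFactor<float, float>(dev_ctx, x, &y, 1.0f - dropout_prob);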

template <typename T>
void DropoutFwGPUKernelDriver(
    const phi::GPUContext& dev_ctx,
    bool is_test,
    float dropout_prob,
    bool upscale_in_train,
    bool is_fix_seed,
    int seed_val,
    const phi::DenseTensor& x,
    const phi::DenseTensor* seed,
    phi::DenseTensor* mask,
    phi::DenseTensor* y,
    bool is_dropout_nd = false,
    const std::vector<int>& axis = std::vector<int>()) {
  int64_t x_numel = x.numel();
  auto stream = dev_ctx.stream();
  auto* x_data = x.data<T>();
  auto* y_data = y->data<T>();

  if (!is_test && mask) {
    auto* mask_data = mask->data<uint8_t>();
    size_t size = phi::product(mask->dims());

    if (dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#endif
      return;
    }

    uint64_t seed_data;
    uint64_t increment;
    // VectorizedRandomGenerator uses curand_uniform4, so kVecSize is 4.
    constexpr int kVecSize =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
    size_t grid_size = gpu_config.GetGridSize();
    size_t block_size = gpu_config.GetBlockSize();

    int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
    const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
    size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
                           prop.multiProcessorCount / block_size;
    grid_size = std::min(grid_size, max_grid_size);

    auto offset =
        ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
    size_t main_offset =
        size / (block_size * kVecSize) * (block_size * kVecSize);

    if (is_dropout_nd) {
      auto dst_functor =
          DstFunctor<T>(1.0f - dropout_prob, upscale_in_train, x_numel);

      std::vector<int64_t> out_dims =
          std::move(phi::vectorize<int64_t>(x.dims()));
      std::vector<int64_t> in_dims =
          std::move(phi::vectorize<int64_t>(mask->dims()));
      std::reverse(out_dims.begin(), out_dims.end());
      std::reverse(in_dims.begin(), in_dims.end());
      kps::details::BroadcastConfig broadcast_config(
          out_dims, in_dims, x.dims().size());

      auto mask_functor = MaskFunctor<T>(1.0f - dropout_prob);
      bool copy_in_kernel = GetSeedDataAndIncrement(dev_ctx,
                                                    seed,
                                                    is_fix_seed,
                                                    seed_val,
                                                    offset,
                                                    &seed_data,
                                                    &increment,
                                                    true);
      const uint64_t* seed_ptr =
          copy_in_kernel ? seed->data<uint64_t>() : nullptr;

      DropOutNdForwardKernel<T>
          <<<grid_size, block_size, 0, stream>>>(size,
                                                 seed_data,
                                                 dropout_prob,
                                                 x_data,
                                                 mask_data,
                                                 increment,
                                                 main_offset,
                                                 dst_functor,
                                                 mask_functor,
                                                 y_data,
                                                 y->numel(),
                                                 broadcast_config,
                                                 seed_ptr);
    } else {
      bool copy_in_kernel = GetSeedDataAndIncrement(
          dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);

#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T>
      PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!is_fix_seed,
                                         PD_DROPOUT_KERNEL_NAME,
                                         grid_size,
                                         block_size,
                                         0,
                                         stream,
                                         offset,
                                         KERNEL_PARAMS.As<uint64_t>(1),
                                         KERNEL_PARAMS.As<uint64_t>(7),
                                         size,
                                         seed_data,
                                         dropout_prob,
                                         x_data,
                                         mask_data,
                                         y_data,
                                         upscale_in_train,
                                         increment,
                                         main_offset);
#undef PD_DROPOUT_KERNEL_NAME
    }
    VLOG(4) << "Dropout seed: " << seed << ", offset: " << offset
            << ", seed_data:" << seed_data;
  } else {
    if (upscale_in_train) {
      // y = x
      phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, y);
    } else {
      using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
      MT factor = static_cast<MT>(1.0f - dropout_prob);
      // y = factor * x
      ScaleByDropoutFactor<T, MT>(dev_ctx, x, y, factor);
    }
  }
}
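
// Hedged usage sketch of the forward driver (tensor names and values are
// hypothetical; `mask` and `y` are assumed to be pre-allocated with the
// appropriate shapes and dtypes):
//   DropoutFwGPUKernelDriver<float>(dev_ctx,
//                                   /*is_test=*/false,
//                                   /*dropout_prob=*/0.5f,
//                                   /*upscale_in_train=*/true,
//                                   /*is_fix_seed=*/false,
//                                   /*seed_val=*/0,
//                                   x,
//                                   /*seed=*/nullptr,
//                                   &mask,
//                                   &y);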

template <typename T>
struct CudaDropoutGradFunctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const uint8_t mask) const {
    return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
                          factor_);
  }

 private:
  MT factor_;
};
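
// Backward rule applied by CudaDropoutGradFunctor: dx = dy * mask * factor,
// where factor is 1 / (1 - dropout_prob) for upscale_in_train and 1.0
// otherwise (see DropoutGradGPUKernelDriver below). E.g. with
// dropout_prob = 0.5, a kept element's gradient is doubled and a dropped
// element's gradient is zero.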

template <typename T>
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                bool is_test,
                                float dropout_prob,
                                bool upscale_in_train,
                                const phi::DenseTensor& grad_y,
                                const phi::DenseTensor& mask,
                                phi::DenseTensor* grad_x,
                                bool is_dropout_nd = false) {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  auto stream = dev_ctx.stream();
  if (is_test) {
    MT factor = static_cast<MT>(upscale_in_train ? 1.0f : 1.0f - dropout_prob);
    // grad_x = factor * grad_y
    ScaleByDropoutFactor<T, MT>(dev_ctx, grad_y, grad_x, factor);
  } else {
    phi::DenseTensor broadcasted_mask;
    if (is_dropout_nd) {
      broadcasted_mask.Resize(grad_y.dims());
      dev_ctx.template Alloc<uint8_t>(&broadcasted_mask);

      std::vector<const phi::DenseTensor*> broadcast_ins = {&mask};
      std::vector<phi::DenseTensor*> broadcast_outs = {&broadcasted_mask};
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary,
                                  uint8_t,
                                  uint8_t>(dev_ctx,
                                           broadcast_ins,
                                           &broadcast_outs,
                                           -1,
                                           kps::IdentityFunctor<uint8_t>());
    }

    std::vector<const phi::DenseTensor*> ins = {
        &grad_y, is_dropout_nd ? &broadcasted_mask : &mask};
    std::vector<phi::DenseTensor*> outs = {grad_x};
    if (upscale_in_train) {
      if (dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
        hipMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#else
        cudaMemset(grad_x->data<T>(), 0, grad_x->numel() * sizeof(T));
#endif
      } else {
        MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
        phi::funcs::ElementwiseKernel<T>(
            dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
      }
    } else {
      MT factor = static_cast<MT>(1.0f);
      phi::funcs::ElementwiseKernel<T>(
          dev_ctx, ins, &outs, CudaDropoutGradFunctor<T>(factor));
    }
  }
}
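
// Hedged usage sketch of the backward driver (hypothetical names; `grad_x` is
// assumed to be allocated with grad_y's shape):
//   DropoutGradGPUKernelDriver<float>(dev_ctx,
//                                     /*is_test=*/false,
//                                     /*dropout_prob=*/0.5f,
//                                     /*upscale_in_train=*/true,
//                                     grad_y,
//                                     mask,
//                                     &grad_x);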

}  // namespace funcs
}  // namespace phi