dropout_impl.cu.h 17.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
22

23 24 25 26 27
#include "paddle/fluid/platform/dynload/curand.h"
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
28

29 30 31 32 33 34
#include "paddle/fluid/platform/dynload/hiprand.h"
#endif

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
S
sneaxiy 已提交
35
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
36
#include "paddle/fluid/operators/dropout_impl_util.h"
37
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
H
hong 已提交
38
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
39
#include "paddle/phi/kernels/funcs/broadcast_function.h"
40
#include "paddle/phi/kernels/funcs/distribution_helper.h"
41
#include "paddle/phi/kernels/funcs/functors.h"
42

43 44
namespace paddle {
namespace operators {
45 46 47 48 49 50 51 52 53 54 55 56 57

// Element-wise dropout step applied by VectorizedRandomGenerator through
// kps::OperatorTernary. For each of the kCount slots it compares the uniform
// draw rand[i] with retain_prob_:
//   * kept (rand[i] < retain_prob_): dst[i] = src_val[i], multiplied by
//     factor = 1 / retain_prob_ when is_upscale_in_train_; mask slot = 1.
//   * otherwise: dst[i] = 0 and mask slot = 0.
// `dst` must hold 2 * kCount values: [0, kCount) receives the output and
// [kCount, 2 * kCount) receives the mask. `src_val` may alias `dst` (the
// kernel passes the same buffer): slot i of src_val is only read to produce
// slot i of dst, and the mask half is never read.
// `num` is ignored; exactly kCount slots are always processed.
template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskFunctor {
  const float retain_prob_;
  const bool is_upscale_in_train_;
  using MT = typename details::MPTypeTrait<T1>::Type;
  MT factor;

  HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                   const bool is_upscale_in_train)
      : retain_prob_(retain_prob),
        is_upscale_in_train_(is_upscale_in_train),
        factor(static_cast<MT>(1.0f / retain_prob)) {}

  HOSTDEVICE inline void operator()(OutT* dst,
                                    const T1* src_val,
                                    const T2* rand,
                                    int num) const {
    constexpr int kCount = phi::funcs::uniform_distribution<T2>::kReturnsCount;
    // Second half of the scratch buffer holds the 0/1 mask.
    OutT* const mask_out = dst + kCount;
#pragma unroll
    for (int slot = 0; slot < kCount; ++slot) {
      if (!(rand[slot] < retain_prob_)) {
        dst[slot] = static_cast<T1>(0);
        mask_out[slot] = dst[slot];
        continue;
      }
      dst[slot] =
          is_upscale_in_train_
              ? static_cast<T1>(static_cast<MT>(src_val[slot]) * factor)
              : static_cast<T1>(src_val[slot]);
      mask_out[slot] = static_cast<T1>(1);
    }
  }
};

80
// Fused dropout forward kernel. For each of the `n` elements of `src` it draws
// a uniform float and keeps the element when the draw is below
// 1 - dropout_prob (see DstMaskFunctor). Kept elements are written to `dst`,
// scaled by 1 / (1 - dropout_prob) when is_upscale_in_train, and `mask`
// receives 1. Dropped elements get 0 in both `dst` and `mask`.
//
// Work is split into tiles of BLOCK_NUM_X * kCount contiguous elements. Each
// block handles one tile per iteration, and the loop advances by the whole
// grid's tile span (BLOCK_NUM_X * GRID_NUM_X * kCount). Tiles below
// `main_offset` are read and written without bounds checks. The host must
// therefore pass a `main_offset` that is a multiple of BLOCK_NUM_X * kCount
// and <= n. The single partial tile at/after `main_offset` (if any) is handled
// with bounds checks.
//
// The Philox state is initialised with (`seed`, global thread index,
// `increment`). Keep the parameter order stable: DropoutFwGPUKernelDriver
// refers to `seed` and `increment` as kernel parameters 1 and 7 when
// recording this launch into a CUDA graph.
template <typename T, typename MaskType>
__global__ void VectorizedRandomGenerator(const size_t n,
                                          uint64_t seed,
                                          const float dropout_prob,
                                          const T* src,
                                          MaskType* mask,
                                          T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment,
                                          size_t main_offset) {
  // Global index of this block's first thread. Widen before multiplying so the
  // product is not computed in 32-bit arithmetic.
  size_t idx = static_cast<size_t>(BLOCK_ID_X) * BLOCK_NUM_X;
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t stride = static_cast<size_t>(BLOCK_NUM_X) * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_mask[kCount * 2];  // [0, kCount): dst values; [kCount, 2*kCount): mask
  float rands[kCount];
  MaskType mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  // Elements covered by one block per iteration.
  int deal_size = BLOCK_NUM_X * kCount;

  // Element offset of this block's current tile (uniform across the block).
  size_t fix = idx * kCount;

  auto dst_functor =
      DstMaskFunctor<T, float>(1.0f - dropout_prob, is_upscale_in_train);
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, false>(
        mask + fix, &mask_result[0], deal_size);
    // `fix` and `idx` are block-uniform, so every thread of the block takes
    // the same branch and the barrier is not divergent.
    if (fix > idx * kCount + 1) {
      __syncthreads();
    }
  }
  // Partial tail tile. Compare before subtracting: for blocks already past the
  // end, `n - fix` would wrap around in size_t. When fix < n, the difference
  // is below deal_size (main_offset is the last full-tile boundary), so it
  // fits in int.
  if (fix < n) {
    const int remainder = static_cast<int>(n - fix);
    kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(
        &rands[0], Rand(), &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, true>(
        mask + fix, &mask_result[0], remainder);
    __syncthreads();
  }
}

149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
// Mask-only counterpart of DstMaskFunctor, applied by VectorizedGeneratorMask
// through kps::OperatorBinary. It writes 1 to dst[i] when rand[i] <
// retain_prob_ and 0 otherwise, for exactly kCount slots (`num` is ignored).
// `factor` (1 / retain_prob_) is computed but not read by this functor.
template <typename T1, typename T2 = T1, typename OutT = T1>
struct MaskFunctor {
  const float retain_prob_;
  using MT = typename details::MPTypeTrait<T1>::Type;
  MT factor;

  HOSTDEVICE inline MaskFunctor(const float retain_prob)
      : retain_prob_(retain_prob),
        factor(static_cast<MT>(1.0f / retain_prob)) {}

  HOSTDEVICE inline void operator()(OutT* dst, const T2* rand, int num) const {
    constexpr int kCount = phi::funcs::uniform_distribution<T2>::kReturnsCount;
    const T1 keep_value = static_cast<T1>(1);
    const T1 drop_value = static_cast<T1>(0);
#pragma unroll
    for (int slot = 0; slot < kCount; ++slot) {
      dst[slot] = rand[slot] < retain_prob_ ? keep_value : drop_value;
    }
  }
};

// Applies an already-computed dropout mask element-wise (used with
// phi::funcs::ElementwiseKernel in the is_dropout_nd forward path):
//   mask == 1 -> src_val, multiplied by 1 / retain_prob when
//                is_upscale_in_train (computed in MT, cast back to T)
//   otherwise -> 0
// `num` is accepted and stored for constructor compatibility but is not
// needed: the functor works on one element per call.
template <typename T, typename MaskType>
struct DstFunctor {
  using MT = typename details::MPTypeTrait<T>::Type;
  MT factor;

  HOSTDEVICE inline DstFunctor(const float retain_prob,
                               const bool is_upscale_in_train,
                               const int64_t num)
      : factor(static_cast<MT>(1.0f / retain_prob)),
        retain_prob_(retain_prob),
        is_upscale_in_train_(is_upscale_in_train),
        num_(num) {}

  // Previously this body sat inside `for (i < num_)` and always returned on
  // the first iteration. With num_ <= 0 it fell off the end of a non-void
  // function; every path now returns.
  HOSTDEVICE inline T operator()(const T src_val, const MaskType mask) const {
    if (mask != static_cast<MaskType>(1)) {
      return static_cast<T>(0);
    }
    return is_upscale_in_train_
               ? static_cast<T>(static_cast<MT>(src_val) * factor)
               : static_cast<T>(src_val);
  }

 private:
  const float retain_prob_;
  const bool is_upscale_in_train_;
  const int64_t num_;  // kept for interface/layout compatibility; not read
};

// Generates only the dropout mask (used by the is_dropout_nd forward path).
// For each of the `n` mask elements it draws a uniform float and writes 1 when
// the draw is below 1 - dropout_prob, else 0 (see MaskFunctor).
//
// Tiling matches VectorizedRandomGenerator:
//   * tiles are BLOCK_NUM_X * kCount elements and the grid strides over them;
//   * tiles below `main_offset` skip bounds checks, so `main_offset` must be a
//     multiple of BLOCK_NUM_X * kCount and <= n;
//   * the partial tail tile is bounds-checked.
// The Philox state is initialised with (`seed`, global thread index,
// `increment`).
//
// `src` is kept in the signature for compatibility but is not read: the mask
// depends only on the random stream. The element values that used to be
// loaded were immediately overwritten by MaskFunctor.
template <typename T, typename MaskType>
__global__ void VectorizedGeneratorMask(const size_t n,
                                        uint64_t seed,
                                        const float dropout_prob,
                                        const T* src,
                                        MaskType* mask,
                                        uint64_t increment,
                                        size_t main_offset) {
  constexpr int kCount = phi::funcs::uniform_distribution<float>::kReturnsCount;
  // Global index of this block's first thread, widened before multiplying.
  size_t idx = static_cast<size_t>(BLOCK_ID_X) * BLOCK_NUM_X;
  size_t stride = static_cast<size_t>(BLOCK_NUM_X) * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_mask[kCount];  // 0/1 keep flags (as T) before conversion to MaskType
  float rands[kCount];
  MaskType mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  // Elements covered by one block per iteration.
  int deal_size = BLOCK_NUM_X * kCount;

  // Element offset of this block's current tile (uniform across the block).
  size_t fix = idx * kCount;

  auto mask_functor = MaskFunctor<T, float>(1.0f - dropout_prob);
  for (; fix < main_offset; fix += stride) {
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(
        &rands[0], Rand(), &state);
    // keep flags: MaskFunctor writes all kCount slots of dst_mask
    kps::OperatorBinary<float, T, MaskFunctor<T, float>>(
        &dst_mask[0], &rands[0], mask_functor, kCount);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[0], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, false>(
        mask + fix, &mask_result[0], deal_size);
    // Block-uniform condition, so the barrier is not divergent.
    if (fix > idx * kCount + 1) {
      __syncthreads();
    }
  }
  // Partial tail tile. Compare before subtracting so `n - fix` cannot wrap
  // for blocks already past the end. When fix < n the difference is below
  // deal_size, so it fits in int.
  if (fix < n) {
    const int remainder = static_cast<int>(n - fix);
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(
        &rands[0], Rand(), &state);
    // keep flags
    kps::OperatorBinary<float, T, MaskFunctor<T, float>>(
        &dst_mask[0], &rands[0], mask_functor, kCount);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[0], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, true>(
        mask + fix, &mask_result[0], remainder);
    __syncthreads();
  }
}

inline void CalcBroadcastedMask(const phi::GPUContext& dev_ctx,
                                const framework::Tensor& mask,
                                framework::Tensor* broadcasted_mask) {
  // The broadcast of mask can be combined to the following ElementwiseKernel
  // when the BroadcastKernel supports different input types.
  broadcasted_mask->mutable_data<uint8_t>(dev_ctx.GetPlace());

  std::vector<const framework::Tensor*> ins = {&mask};
  std::vector<framework::Tensor*> outs = {broadcasted_mask};
  phi::funcs::BroadcastKernel<phi::ElementwiseType::kUnary, uint8_t, uint8_t>(
      dev_ctx, ins, &outs, -1, kps::IdentityFunctor<uint8_t>());
}

// Computes y = factor * x element-wise by running phi::funcs::ScaleFunctor<T>
// through phi::funcs::ElementwiseKernel on dev_ctx.
template <typename T, typename MT>
void ScaleByDropoutFactor(const phi::GPUContext& dev_ctx,
                          const framework::Tensor& x,
                          framework::Tensor* y,
                          MT factor) {
  std::vector<const framework::Tensor*> inputs(1, &x);
  std::vector<framework::Tensor*> outputs(1, y);
  phi::funcs::ElementwiseKernel<T>(
      dev_ctx, inputs, &outputs, phi::funcs::ScaleFunctor<T>(factor));
}

293
// Host-side driver for the dropout forward pass.
//
// Training (is_test == false):
//   * dropout_prob == 1: y and mask are zero-filled on dev_ctx's stream.
//   * is_dropout_nd: VectorizedGeneratorMask generates a mask of shape
//     mask->dims(). The mask is broadcast to x's shape and applied to x with
//     DstFunctor.
//   * otherwise: VectorizedRandomGenerator produces y and mask in one pass.
//   Seed and RNG offset come from GetSeedDataAndIncrement(seed, is_fix_seed,
//   seed_val).
// Inference (is_test == true):
//   * upscale_in_train: y = x (copy).
//   * otherwise: y = (1 - dropout_prob) * x.
//
// `y` (and `mask` in training) are expected to be allocated by the caller.
// Only data<>() is called on them here.
template <typename T>
void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx,
                              bool is_test,
                              float dropout_prob,
                              bool upscale_in_train,
                              bool is_fix_seed,
                              int seed_val,
                              const framework::Tensor& x,
                              const framework::Tensor* seed,
                              framework::Tensor* mask,
                              framework::Tensor* y,
                              bool is_dropout_nd = false) {
  int64_t x_numel = x.numel();
  auto stream = dev_ctx.stream();
  auto* x_data = x.data<T>();
  auto* y_data = y->data<T>();

  if (!is_test) {
    // Empty input: nothing to generate. Returning early also keeps the launch
    // arithmetic below away from a zero-sized grid and from the wrap-around
    // of `x_numel - 1` when it is mixed with unsigned sizes.
    if (x_numel == 0) {
      return;
    }

    auto* mask_data = mask->data<uint8_t>();
    // Number of mask elements. In the plain path it is presumably x_numel,
    // since VectorizedRandomGenerator writes `size` elements of both y and
    // mask. With is_dropout_nd it can be smaller; the mask is later broadcast
    // to x's shape.
    size_t size = phi::product(mask->dims());

    if (dropout_prob == 1.0f) {
      // Everything is dropped. The mask buffer holds `size` elements, not
      // x_numel, so clear exactly `size` entries to avoid overrunning it in
      // the is_dropout_nd case.
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(mask_data, 0, size * sizeof(*mask_data), stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(mask_data, 0, size * sizeof(*mask_data), stream));
#endif
      return;
    }

    uint64_t seed_data;
    uint64_t increment;
    // VectorizedRandomGenerator uses curand_uniform4, so kVecSize is 4.
    constexpr int kVecSize =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
    size_t grid_size = gpu_config.GetGridSize();
    size_t block_size = gpu_config.GetBlockSize();

    // Cap the grid so the total thread count does not exceed
    // maxThreadsPerMultiProcessor * multiProcessorCount. The kernels stride
    // over any remaining tiles.
    int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
    const auto& prop = platform::GetDeviceProperties(device_id);
    size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
                           prop.multiProcessorCount / block_size;
    grid_size = std::min(grid_size, max_grid_size);

    // Upper bound on the random values one thread consumes:
    // ceil(x_numel / (threads * kVecSize)) draws of kVecSize values each.
    auto offset =
        ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
    GetSeedDataAndIncrement(
        dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
    // Largest multiple of the per-block tile (block_size * kVecSize) that fits
    // in `size`. The kernels skip bounds checks below this offset.
    size_t main_offset =
        size / (block_size * kVecSize) * (block_size * kVecSize);

    if (is_dropout_nd) {
      // NOTE(review): unlike the branch below, this launch is not wrapped in
      // PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL. Its seed/increment are therefore
      // presumably not refreshed on CUDA-graph replay; confirm this is
      // intended.
      VectorizedGeneratorMask<T, uint8_t>
          <<<grid_size, block_size, 0, stream>>>(size,
                                                 seed_data,
                                                 dropout_prob,
                                                 x_data,
                                                 mask_data,
                                                 increment,
                                                 main_offset);

      framework::Tensor broadcasted_mask;
      broadcasted_mask.Resize(x.dims());
      CalcBroadcastedMask(dev_ctx, *mask, &broadcasted_mask);

      auto dst_functor = DstFunctor<T, uint8_t>(
          1.0f - dropout_prob, upscale_in_train, x_numel);
      std::vector<const framework::Tensor*> ins = {&x, &broadcasted_mask};
      std::vector<framework::Tensor*> outs = {y};
      phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, dst_functor);
    } else {
      // KERNEL_PARAMS.As<uint64_t>(1) and (7) refer by position to the
      // `seed` and `increment` parameters of VectorizedRandomGenerator.
      // Keep them in sync with its signature.
#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T, uint8_t>
      PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!is_fix_seed,
                                         PD_DROPOUT_KERNEL_NAME,
                                         grid_size,
                                         block_size,
                                         0,
                                         stream,
                                         offset,
                                         KERNEL_PARAMS.As<uint64_t>(1),
                                         KERNEL_PARAMS.As<uint64_t>(7),
                                         size,
                                         seed_data,
                                         dropout_prob,
                                         x_data,
                                         mask_data,
                                         y_data,
                                         upscale_in_train,
                                         increment,
                                         main_offset);
#undef PD_DROPOUT_KERNEL_NAME
    }
  } else {
    if (upscale_in_train) {
      // y = x
      framework::TensorCopy(x, dev_ctx.GetPlace(), dev_ctx, y);
    } else {
      using MT = typename details::MPTypeTrait<T>::Type;
      MT factor = static_cast<MT>(1.0f - dropout_prob);
      // y = factor * x
      ScaleByDropoutFactor<T, MT>(dev_ctx, x, y, factor);
    }
  }
}

406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421
// Training-mode dropout backward functor:
//   grad_x = dout * mask * factor
// The product is evaluated left to right in the compute type
// MT = details::MPTypeTrait<T>::Type and cast back to T.
// DropoutGradGPUKernelDriver passes factor = 1 / (1 - dropout_prob) in
// upscale_in_train mode and 1 otherwise.
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
  using MT = typename details::MPTypeTrait<T>::Type;

  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const MaskType mask) const {
    const MT masked_grad = static_cast<MT>(dout) * static_cast<MT>(mask);
    return static_cast<T>(masked_grad * factor_);
  }

 private:
  MT factor_;
};

422
// Host-side driver for the dropout backward pass.
//
// Inference (is_test):
//   grad_x = grad_y * (upscale_in_train ? 1 : 1 - dropout_prob).
// Training:
//   grad_x = grad_y * mask * factor, where factor is 1 / (1 - dropout_prob)
//   when upscale_in_train and 1 otherwise.
//   With upscale_in_train and dropout_prob == 1, grad_x is zero-filled
//   directly on dev_ctx's stream.
//   When is_dropout_nd, `mask` is first broadcast to grad_y's shape.
template <typename T>
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                bool is_test,
                                float dropout_prob,
                                bool upscale_in_train,
                                const framework::Tensor& grad_y,
                                const framework::Tensor& mask,
                                framework::Tensor* grad_x,
                                bool is_dropout_nd = false) {
  using MT = typename details::MPTypeTrait<T>::Type;

  auto stream = dev_ctx.stream();
  if (is_test) {
    MT factor = static_cast<MT>(upscale_in_train ? 1.0f : 1.0f - dropout_prob);
    // grad_x = factor * grad_y
    ScaleByDropoutFactor<T, MT>(dev_ctx, grad_y, grad_x, factor);
    return;
  }

  if (upscale_in_train && dropout_prob == 1.0f) {
    // Every element was dropped, so the gradient is identically zero. Fill
    // asynchronously on dev_ctx's stream so the write is ordered with the rest
    // of the work queued there, and check the status as the forward driver
    // does. A plain cudaMemset/hipMemset would run on the default stream,
    // outside that ordering.
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(
        grad_x->data<T>(), 0, grad_x->numel() * sizeof(T), stream));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
        grad_x->data<T>(), 0, grad_x->numel() * sizeof(T), stream));
#endif
    return;
  }

  framework::Tensor broadcasted_mask;
  if (is_dropout_nd) {
    broadcasted_mask.Resize(grad_y.dims());
    CalcBroadcastedMask(dev_ctx, mask, &broadcasted_mask);
  }

  std::vector<const framework::Tensor*> ins = {
      &grad_y, is_dropout_nd ? &broadcasted_mask : &mask};
  std::vector<framework::Tensor*> outs = {grad_x};
  MT factor = upscale_in_train
                  ? static_cast<MT>(1.0f / (1.0f - dropout_prob))
                  : static_cast<MT>(1.0f);
  phi::funcs::ElementwiseKernel<T>(
      dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
}

}  // namespace operators
}  // namespace paddle