/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#include "paddle/fluid/platform/dynload/curand.h"
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#include "paddle/fluid/platform/dynload/hiprand.h"
#endif

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/functors.h"

DECLARE_bool(use_curand);
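// FLAGS_use_curand selects between the two random paths in
// VectorizedRandomGenerator below: when set, the DstMaskFunctor path keeps an
// element if rand < retain_prob; otherwise the legacy DstMaskGenerator path
// drops an element if rand < dropout_prob.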

namespace paddle {
namespace operators {

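// DstMaskGenerator (legacy path): drops an element when rand < dropout_prob.
// With upscale_in_train, kept elements are scaled by 1 / (1 - dropout_prob) so
// the expected value of the output matches the input; otherwise kept elements
// pass through unscaled.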
template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskGenerator {
  const float dropout_prob_;
  const bool is_upscale_in_train_;
  using MT = typename details::MPTypeTrait<T1>::Type;
  MT factor;
  HOSTDEVICE inline DstMaskGenerator(const float dropout_prob,
                                     const bool is_upscale_in_train)
      : dropout_prob_(dropout_prob), is_upscale_in_train_(is_upscale_in_train) {
    factor = static_cast<MT>(1.0f / (1.0f - dropout_prob_));
  }

  HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
                                    const T2* rand, int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<T2>::kReturnsCount;
// dst[0 .. kCount - 1] holds the output values; dst[kCount .. 2 * kCount - 1] holds the mask
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] < dropout_prob_) {
        dst[i] = static_cast<T1>(0);
        dst[i + kCount] = dst[i];
      } else {
        dst[i] = is_upscale_in_train_
                     ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<T1>(src_val[i]);
        dst[i + kCount] = static_cast<T1>(1);
      }
    }
  }
};

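// DstMaskFunctor (cuRAND-aligned path): keeps an element when
// rand < retain_prob (note the comparison is inverted relative to
// DstMaskGenerator above). With upscale_in_train, kept elements are scaled by
// 1 / retain_prob.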
template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskFunctor {
  const float retain_prob_;
  const bool is_upscale_in_train_;
  using MT = typename details::MPTypeTrait<T1>::Type;
  MT factor;
  HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                   const bool is_upscale_in_train)
      : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
                                    const T2* rand, int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<T2>::kReturnsCount;
// dst[0 .. kCount - 1] holds the output values; dst[kCount .. 2 * kCount - 1] holds the mask
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] < retain_prob_) {
        dst[i] = is_upscale_in_train_
                     ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<T1>(src_val[i]);
        dst[i + kCount] = static_cast<T1>(1);
      } else {
        dst[i] = static_cast<T1>(0);
        dst[i + kCount] = dst[i];
      }
    }
  }
};

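// Vectorized dropout forward kernel. Each thread owns a Philox4_32_10 state
// and handles kCount elements per iteration of a grid-stride loop: the
// [0, main_offset) range uses unchecked vectorized reads/writes and the
// [main_offset, n) tail uses boundary-checked ones.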
template <typename T, typename MaskType>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask, T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment,
                                          size_t main_offset, bool use_curand) {
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  T dst_mask[kCount * 2];  // 0 ~ kCount - 1: dst; kCount ~ 2 * kCount - 1: mask
  float rands[kCount];
  MaskType mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;

  size_t fix = idx * kCount;
  if (use_curand) {
    auto dst_functor =
        DstMaskFunctor<T, float>(1.0f - dropout_prob, is_upscale_in_train);
    for (; fix < main_offset; fix += stride) {
      kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                            &state);
      // dst
      kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
      kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0],
                                             deal_size);
      // mask
      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
          &mask_result[0], &dst_mask[kCount], Cast());
      kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
                                                    deal_size);
      if (fix > idx * kCount + 1) {
        __syncthreads();
      }
    }
    int remainder = n - fix;
    if (remainder > 0) {
      kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                            &state);
      // dst
      kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
      kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
      // mask
      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
          &mask_result[0], &dst_mask[kCount], Cast());
      kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
                                                   remainder);
      __syncthreads();
    }
  } else {
    auto dst_functor =
        DstMaskGenerator<T, float>(dropout_prob, is_upscale_in_train);
    for (; fix < main_offset; fix += stride) {
      kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                            &state);
      // dst
      kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
      kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0],
                                             deal_size);
      // mask
      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
          &mask_result[0], &dst_mask[kCount], Cast());
      kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
                                                    deal_size);
    }
    int remainder = n - fix;
    if (remainder > 0) {
      kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                            &state);
      // dst
      kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
      kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
      // mask
      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
          &mask_result[0], &dst_mask[kCount], Cast());
      kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
                                                   remainder);
    }
  }
}
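
// Forward driver. In training mode it launches VectorizedRandomGenerator to
// compute y and the uint8 mask in one pass, with a memset fast path when
// dropout_prob == 1.0f. In test mode it copies x into y (upscale_in_train) or
// scales x by 1 - dropout_prob.
//
// A minimal usage sketch, assuming pre-allocated same-shape GPU tensors
// (names here are illustrative, not from this file):
//   framework::Tensor x, mask, y;
//   DropoutFwGPUKernelDriver<float>(dev_ctx, /*is_test=*/false,
//                                   "upscale_in_train", /*dropout_prob=*/0.5f,
//                                   /*upscale_in_train=*/true,
//                                   /*is_fix_seed=*/false, /*seed_val=*/0,
//                                   x, /*seed=*/nullptr, &mask, &y);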

template <typename T>
void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
                              const std::string dropout_implementation,
                              float dropout_prob, bool upscale_in_train,
                              bool is_fix_seed, int seed_val,
                              const framework::Tensor& x,
                              const framework::Tensor* seed,
                              framework::Tensor* mask, framework::Tensor* y) {
  auto& place = *dev_ctx.eigen_device();
  int64_t x_numel = x.numel();
  auto stream = dev_ctx.stream();
  auto* x_data = x.data<T>();
  auto* y_data = y->data<T>();

  if (!is_test) {
    auto* mask_data = mask->data<uint8_t>();
    size_t size = phi::product(mask->dims());

    if (dropout_prob == 1.0f) {
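      // Every element is dropped: zero both the output and the mask without
      // launching the RNG kernel.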
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#endif
      return;
    }

    uint64_t seed_data;
    uint64_t increment;
    // VectorizedRandomGenerator uses curand_uniform4, so kVecSize is 4.
    constexpr int kVecSize =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
    size_t grid_size = gpu_config.GetGridSize();
    size_t block_size = gpu_config.GetBlockSize();

    if (FLAGS_use_curand) {
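      // Clamp the grid so no more blocks are launched than can be resident at
      // once (max threads per SM * SM count / threads per block); extra work
      // is absorbed by the kernel's grid-stride loop.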
      int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
      const auto& prop = platform::GetDeviceProperties(device_id);
      size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
                             prop.multiProcessorCount / block_size;
      grid_size = std::min(grid_size, max_grid_size);
    }

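    // offset upper-bounds the count of random values each thread draws per
    // launch (loop iterations * kVecSize); GetSeedDataAndIncrement uses it to
    // compute the Philox increment so the next launch continues from fresh
    // counters.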
    auto offset =
        ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
    GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
                            &seed_data, &increment);
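    // main_offset rounds size down to a whole number of block_size * kVecSize
    // tiles, so the kernel's main loop runs without bounds checks and only the
    // tail uses masked reads/writes.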
    size_t main_offset =
        size / (block_size * kVecSize) * (block_size * kVecSize);

    VectorizedRandomGenerator<T, uint8_t><<<grid_size, block_size, 0, stream>>>(
        size, seed_data, dropout_prob, x_data, mask_data, y_data,
        upscale_in_train, increment, main_offset, FLAGS_use_curand);
  } else {
    if (upscale_in_train) {
// TODO: can y share data directly with x?
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
                         hipMemcpyDeviceToDevice, stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
                          cudaMemcpyDeviceToDevice, stream));
#endif
    } else {
      using MT = typename details::MPTypeTrait<T>::Type;
      MT factor = static_cast<MT>(1.0f - dropout_prob);
      std::vector<const framework::Tensor*> ins = {&x};
      std::vector<framework::Tensor*> outs = {y};
      auto functor = phi::funcs::ScaleFunctor<T>(factor);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                                &outs, functor);
    }
  }
}

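// Backward functor: grad_x = grad_y * mask * factor, accumulated in the
// higher-precision type MT to limit rounding error for fp16 inputs.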
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
  using MT = typename details::MPTypeTrait<T>::Type;

  explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const MaskType mask) const {
    return static_cast<T>(static_cast<MT>(dout) * static_cast<MT>(mask) *
                          factor_);
  }

 private:
  MT factor_;
};

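// Backward driver. In test mode grad_x is a pure scale of grad_y (factor 1
// for upscale_in_train, 1 - dropout_prob otherwise); in training mode the
// saved mask is applied via CudaDropoutGradFunctor, with factor
// 1 / (1 - dropout_prob) for upscale_in_train and 1 otherwise.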
template <typename T>
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                const std::string dropout_implementation,
                                float dropout_prob,
                                const framework::Tensor& grad_y,
                                const framework::Tensor& mask, int64_t size,
                                framework::Tensor* grad_x,
                                bool is_test = false) {
  using MT = typename details::MPTypeTrait<T>::Type;
  auto stream = dev_ctx.stream();
  MT factor;
  if (is_test) {
    if (dropout_implementation == "upscale_in_train") {
      factor = static_cast<MT>(1.0f);
    } else {
      factor = static_cast<MT>(1.0f - dropout_prob);
    }
    std::vector<const framework::Tensor*> ins = {&grad_y};
    std::vector<framework::Tensor*> outs = {grad_x};
    auto functor = phi::funcs::ScaleFunctor<T>(factor);
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  } else {
    std::vector<const framework::Tensor*> ins = {&grad_y, &mask};
    std::vector<framework::Tensor*> outs = {grad_x};
    if (dropout_implementation == "upscale_in_train") {
      if (dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
        hipMemset(grad_x->data<T>(), 0, size * sizeof(T));
#else
        cudaMemset(grad_x->data<T>(), 0, size * sizeof(T));
#endif
      } else {
        factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
        paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
            dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
      }
    } else {
      factor = static_cast<MT>(1.0f);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
          dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
    }
  }
}

}  // namespace operators
}  // namespace paddle