// File: dropout_impl.cu.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#include "paddle/fluid/platform/dynload/curand.h"
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#include "paddle/fluid/platform/dynload/hiprand.h"
#endif

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/functors.h"

namespace paddle {
namespace operators {

// Scalar dropout forward kernel. For every element of `src` one uniform random
// number is drawn; when it is below `dropout_prob` the element is dropped
// (mask = 0, dst = 0), otherwise it is kept (mask = 1, dst = src, multiplied by
// 1 / (1 - dropout_prob) when `is_upscale_in_train` is set). The scaling is
// done in MT, the multi-precision type associated with T.
//
// Launch: any 1-D grid; the grid-stride loop covers all `n` elements.
// RNG: each thread owns one Philox state with subsequence = its global thread
// index and offset = `increment`, and draws one number per element it visits.
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, uint64_t seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask, T* dst,
                                bool is_upscale_in_train, uint64_t increment) {
  using MT = typename details::MPTypeTrait<T>::Type;
  // 64-bit indexing: a 32-bit int index (and a 32-bit stride product) would
  // overflow for tensors with more than INT_MAX elements.
  size_t idx = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif

  const MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
  for (; idx < n; idx += stride) {
    const T src_val = src[idx];
#ifdef PADDLE_WITH_HIP
    const bool drop = hiprand_uniform(&state) < dropout_prob;
#else
    const bool drop = curand_uniform(&state) < dropout_prob;
#endif
    MaskType mask_val;
    T dst_val;
    if (drop) {
      mask_val = 0;
      dst_val = 0;
    } else {
      mask_val = 1;
      dst_val = is_upscale_in_train
                    ? static_cast<T>(static_cast<MT>(src_val) * factor)
                    : src_val;
    }
    mask[idx] = mask_val;
    dst[idx] = dst_val;
  }
}

// Vectorized dropout forward kernel. Each loop iteration handles VecSize
// contiguous elements with a single 4-wide uniform draw (curand_uniform4 /
// hiprand_uniform4). Lane j is dropped (dst = 0, mask = 0) when component j of
// the draw is below `dropout_prob`. Otherwise mask = 1 and dst = src, which is
// multiplied by 1 / (1 - dropout_prob) in MT when `is_upscale_in_train` is set.
//
// Preconditions (not checked at run time):
//   * n is a multiple of VecSize: whole vectors are loaded/stored at every
//     i < n. DropoutFwGPUKernelDriver only picks this kernel when size % 4 == 0.
//   * src/mask/dst presumably need the alignment that
//     platform::AlignedVector<..., VecSize> expects — TODO confirm.
// RNG: per-thread Philox state, subsequence = global thread index,
// offset = `increment`; one float4 (four numbers) is drawn per iteration.
template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask, T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment) {
  // All lanes of one vector take their random number from the same float4,
  // so more than four lanes would index past it.
  static_assert(VecSize >= 1 && VecSize <= 4,
                "VectorizedRandomGenerator supports VecSize in [1, 4]");
  using MT = typename details::MPTypeTrait<T>::Type;
  using LoadT = platform::AlignedVector<T, VecSize>;
  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

#ifdef PADDLE_WITH_HIP
  int64_t idx =
      static_cast<int64_t>(hipBlockDim_x) * hipBlockIdx_x + hipThreadIdx_x;
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif

  const MT factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
  // 64-bit loop index and stride: an int index (or a 32-bit stride product)
  // overflows for tensors with more than INT_MAX elements.
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  const int64_t total = static_cast<int64_t>(n);
  for (int64_t i = idx * VecSize; i < total; i += stride) {
    LoadT src_val;
    platform::Load<T, VecSize>(&src[i], &src_val);

#ifdef PADDLE_WITH_HIP
    float4 rand = hiprand_uniform4(&state);
#else
    float4 rand = curand_uniform4(&state);
#endif

    LoadT dst_val;
    MaskLoadT mask_val;

#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      if ((&rand.x)[j] < dropout_prob) {
        dst_val[j] = 0;
        mask_val[j] = 0;
      } else {
        dst_val[j] = is_upscale_in_train
                         ? static_cast<T>(static_cast<MT>(src_val[j]) * factor)
                         : src_val[j];
        mask_val[j] = 1;
      }
    }

    platform::Store<T, VecSize>(dst_val, &dst[i]);
    platform::Store<MaskType, VecSize>(mask_val, &mask[i]);
  }
}

// Elementwise functor for the dropout backward pass: returns
// dout * mask * scale, with the arithmetic carried out in MT (the
// multi-precision type associated with T) and the result cast back to T.
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
  using MT = typename details::MPTypeTrait<T>::Type;

  explicit CudaDropoutGradFunctor(const MT scale) : scale_(scale) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const MaskType mask) const {
    // Same association as dout * mask * scale: mask first, then scale.
    const MT masked = static_cast<MT>(dout) * static_cast<MT>(mask);
    const MT result = masked * scale_;
    return static_cast<T>(result);
  }

 private:
  MT scale_;
};

// Vectorized dropout backward kernel: dx[i] = dout[i] * mask[i] * factor,
// computed in MT (the multi-precision type associated with T) and cast back to
// T. Each loop iteration of a thread handles VecSize contiguous elements.
//
// Preconditions (not checked here):
//   * `size` is a multiple of VecSize, because whole VecSize-wide vectors are
//     loaded and stored at every i < size.
//   * The pointers presumably need the alignment that
//     platform::AlignedVector<..., VecSize> expects — TODO confirm.
template <typename T, typename MaskType, int VecSize>
__global__ void DropoutGradCUDAKernel(
    const T* dout, const MaskType* mask,
    const typename details::MPTypeTrait<T>::Type factor, const int64_t size,
    T* dx) {
  using MT = typename details::MPTypeTrait<T>::Type;
  using LoadT = platform::AlignedVector<T, VecSize>;
  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

  // 64-bit index and stride: an int loop index (or a 32-bit stride product)
  // overflows for tensors with more than INT_MAX elements.
  const int64_t idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  for (int64_t i = idx * VecSize; i < size; i += stride) {
    LoadT dout_val;
    platform::Load<T, VecSize>(&dout[i], &dout_val);

    MaskLoadT mask_val;
    platform::Load<MaskType, VecSize>(&mask[i], &mask_val);

    LoadT dx_val;

#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      dx_val[j] = static_cast<T>(static_cast<MT>(dout_val[j]) *
                                 static_cast<MT>(mask_val[j]) * factor);
    }

    platform::Store<T, VecSize>(dx_val, &dx[i]);
  }
}

// Host-side driver for the dropout forward pass on GPU.
//
// Training (is_test == false):
//   * dropout_prob == 1: `y` and `mask` are zero-filled with async memsets.
//   * otherwise: a seed and a Philox offset come from GetSeedDataAndIncrement.
//     VectorizedRandomGenerator fills `y` and the uint8 `mask` when the
//     vectorized size of `x` is 4 and the element count is a multiple of 4;
//     RandomGenerator does so otherwise.
// Inference (is_test == true):
//   * upscale_in_train: `y` is a device-to-device copy of `x`.
//   * otherwise: y = x * (1 - dropout_prob).
// Memsets, copies and kernel launches issued directly here use
// dev_ctx.stream(). `dropout_implementation` is not read by this function;
// the scaling mode is selected by `upscale_in_train`.
template <typename T>
void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                              bool is_test,
                              const std::string dropout_implementation,
                              float dropout_prob, bool upscale_in_train,
                              bool is_fix_seed, int seed_val, const Tensor& x,
                              const Tensor* seed, Tensor* mask, Tensor* y) {
  int64_t x_numel = x.numel();
  // Empty input: there is nothing to write. Returning early also avoids
  // building a 1-D launch configuration and an RNG offset from x_numel - 1.
  if (x_numel == 0) {
    return;
  }
  auto stream = dev_ctx.stream();
  auto* x_data = x.data<T>();
  auto* y_data = y->data<T>();

  if (!is_test) {
    auto* mask_data = mask->data<uint8_t>();
    size_t size = phi::product(mask->dims());

    if (dropout_prob == 1.0f) {
      // Every element is dropped: zero the output and the mask directly.
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#endif
      return;
    }

    // increment is used to set the args(offset) of curand_init, which defines
    // offset in subsequence.
    // The detail:
    // https://docs.nvidia.com/cuda/curand/device-api-overview.html
    // Increment should be at least the number of curand() random numbers used
    // in each thread to avoid the random number generated this time being the
    // same as the previous calls.
    uint64_t seed_data;
    uint64_t increment;
    // VectorizedRandomGenerator uses curand_uniform4, so the vectorized path
    // only supports vec_size == 4.
    int vec_size = (platform::GetVectorizedSize<T>(x_data) == 4) ? 4 : 1;
    auto gpu_config = GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size);
    // Upper bound on the random numbers any single thread draws in this call,
    // rounded up to a multiple of vec_size; passed to GetSeedDataAndIncrement.
    auto offset =
        ((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size;

    GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
                            &seed_data, &increment);

#ifdef __HIPCC__
    if (vec_size == 4 && size % 4 == 0) {
      hipLaunchKernelGGL(
          HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
          gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0, stream, size,
          seed_data, dropout_prob, x_data, mask_data, y_data, upscale_in_train,
          increment);
    } else {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
                         gpu_config.GetGridSize(), gpu_config.GetBlockSize(), 0,
                         stream, size, seed_data, dropout_prob, x_data,
                         mask_data, y_data, upscale_in_train, increment);
    }
#else
    if (vec_size == 4 && size % 4 == 0) {
      VectorizedRandomGenerator<T, uint8_t, 4><<<
          gpu_config.block_per_grid, gpu_config.thread_per_block, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
    } else {
      RandomGenerator<T, uint8_t><<<gpu_config.block_per_grid,
                                    gpu_config.thread_per_block, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
    }
#endif
  } else {
    if (upscale_in_train) {
      // TODO: can y share its data with x directly instead of copying?
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
                         hipMemcpyDeviceToDevice, stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
                          cudaMemcpyDeviceToDevice, stream));
#endif
    } else {
      // Inference with "downgrade_in_infer"-style scaling: y = x * (1 - p).
      T factor = static_cast<T>(1.0f - dropout_prob);
      std::vector<const framework::Tensor*> ins = {&x};
      std::vector<framework::Tensor*> outs = {y};
      auto functor = phi::funcs::ScaleFunctor<T>(factor);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                                &outs, functor);
    }
  }
}

// Host-side driver for the dropout backward pass on GPU.
//
// is_test == true:
//   grad_x = grad_y * factor, where factor is 1 for "upscale_in_train" and
//   (1 - dropout_prob) otherwise.
// is_test == false:
//   "upscale_in_train": grad_x = grad_y * mask / (1 - dropout_prob), or
//   zero-filled when dropout_prob == 1 (the divisor would be zero).
//   otherwise:          grad_x = grad_y * mask.
// `size` is the number of T elements zero-filled in the dropout_prob == 1
// case; that zero-fill is enqueued on dev_ctx.stream().
template <typename T>
void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                const std::string dropout_implementation,
                                float dropout_prob, const Tensor& grad_y,
                                const Tensor& mask, int64_t size,
                                Tensor* grad_x, bool is_test = false) {
  using MT = typename details::MPTypeTrait<T>::Type;
  auto stream = dev_ctx.stream();
  MT factor;
  if (is_test) {
    if (dropout_implementation == "upscale_in_train") {
      factor = static_cast<MT>(1.0f);
    } else {
      factor = static_cast<MT>(1.0f - dropout_prob);
    }
    std::vector<const framework::Tensor*> ins = {&grad_y};
    std::vector<framework::Tensor*> outs = {grad_x};
    auto functor = phi::funcs::ScaleFunctor<T>(factor);
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  } else {
    std::vector<const framework::Tensor*> ins = {&grad_y, &mask};
    std::vector<framework::Tensor*> outs = {grad_x};
    if (dropout_implementation == "upscale_in_train") {
      if (dropout_prob == 1.0f) {
        // Enqueue on the op's stream (not the default stream used by the
        // synchronous *Memset variants) so the fill is ordered with the rest
        // of the op's work, and check the returned error code.
#ifdef PADDLE_WITH_HIP
        PADDLE_ENFORCE_GPU_SUCCESS(
            hipMemsetAsync(grad_x->data<T>(), 0, size * sizeof(T), stream));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(
            cudaMemsetAsync(grad_x->data<T>(), 0, size * sizeof(T), stream));
#endif
      } else {
        factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
        paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
            dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
      }
    } else {
      factor = static_cast<MT>(1.0f);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
          dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
    }
  }
}

}  // namespace operators
}  // namespace paddle