dropout_impl.cu.h 11.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
22

23 24 25 26 27
#include "paddle/fluid/platform/dynload/curand.h"
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
28

29 30 31 32 33 34
#include "paddle/fluid/platform/dynload/hiprand.h"
#endif

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
S
sneaxiy 已提交
35
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
36
#include "paddle/fluid/operators/dropout_impl_util.h"
37
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
H
hong 已提交
38
#include "paddle/fluid/platform/aligned_vector.h"
39
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
H
hong 已提交
40
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
41
#include "paddle/phi/kernels/funcs/distribution_helper.h"
42
#include "paddle/phi/kernels/funcs/functors.h"
43

44 45
namespace paddle {
namespace operators {
// Applies a dropout decision to kCount values at once, where
// kCount = phi::funcs::uniform_distribution<T2>::kReturnsCount.
//
// For every i in [0, kCount):
//   * dst[i]          <- src_val[i] if kept (multiplied by 1 / retain_prob
//                        when is_upscale_in_train), 0 if dropped;
//   * dst[i + kCount] <- 1 if kept, 0 if dropped (the keep mask).
// `dst` must therefore hold 2 * kCount elements. `num` is not read.
//
// An element is kept when rand[i] <= retain_prob. `<=` (not `<`) is used
// because the uniform generator feeding `rand` on GPU (curand_uniform4)
// samples from (0, 1] and can return exactly 1.0f. With `<`, a retain
// probability of 1 (dropout_prob == 0) would still occasionally drop an
// element.
template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskFunctor {
  const float retain_prob_;
  const bool is_upscale_in_train_;
  using MT = typename details::MPTypeTrait<T1>::Type;
  // Upscale factor 1 / retain_prob, stored as MT (MPTypeTrait<T1>::Type).
  // NOTE: retain_prob == 0 makes this infinite.
  MT factor;
  HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                   const bool is_upscale_in_train)
      : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
    factor = static_cast<MT>(1.0f / retain_prob_);
  }

  HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
                                    const T2* rand, int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<T2>::kReturnsCount;
// dst[0, kCount) receives the outputs, dst[kCount, 2 * kCount) the mask.
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] <= retain_prob_) {
        dst[i] = is_upscale_in_train_
                     ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<T1>(src_val[i]);
        dst[i + kCount] = static_cast<T1>(1);
      } else {
        dst[i] = static_cast<T1>(0);
        dst[i + kCount] = dst[i];
      }
    }
  }
};

// Dropout forward kernel. For the first `n` elements of `src` it draws
// uniform random numbers with a Philox generator and writes:
//   * dst[k]  = dropout output (computed by DstMaskFunctor),
//   * mask[k] = 1 if element k was kept, 0 otherwise.
//
// Work layout, as written below:
//   * per step a block covers BLOCK_NUM_X * kCount elements starting at its
//     cursor `fix`, which begins at BLOCK_ID_X * BLOCK_NUM_X * kCount and
//     advances by BLOCK_NUM_X * GRID_NUM_X * kCount (one grid-wide step);
//   * [0, main_offset) is processed with unchecked loads/stores, so
//     main_offset must be a multiple of BLOCK_NUM_X * kCount and <= n;
//   * after the loop, a block whose cursor is still below n processes the
//     tail [fix, n) with bounds-checked loads/stores.
// Within a block, kps::ReadData/WriteData are assumed to give each thread
// kCount of the BLOCK_NUM_X * kCount elements — TODO confirm in kps.
//
// Each thread's Philox state is initialised with `seed`, subsequence
// BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X, and offset `increment`.
template <typename T, typename MaskType>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask, T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment,
                                          size_t main_offset) {
  // Widen before multiplying so the products are formed in 64 bits.
  size_t idx = static_cast<size_t>(BLOCK_ID_X) * BLOCK_NUM_X;
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  size_t stride = static_cast<size_t>(BLOCK_NUM_X) * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif
  // dst_mask[0, kCount): dropout output; dst_mask[kCount, 2 * kCount): mask.
  T dst_mask[kCount * 2];
  float rands[kCount];
  MaskType mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;

  size_t fix = idx * kCount;

  auto dst_functor =
      DstMaskFunctor<T, float>(1.0f - dropout_prob, is_upscale_in_train);
  // Main part: full block-sized chunks, no bounds checks needed.
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                          &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
                                                  deal_size);
    // The condition depends only on block-uniform values, so every thread of
    // the block reaches this barrier together.
    // NOTE(review): no shared memory is used in this kernel as written; the
    // purpose of this barrier is not evident here. Confirm before removing.
    if (fix > idx * kCount + 1) {
      __syncthreads();
    }
  }
  // Tail: compare before subtracting so the unsigned difference n - fix never
  // wraps. Given a main_offset that is n rounded down to a multiple of
  // BLOCK_NUM_X * kCount, n - fix is smaller than one block step here, so it
  // fits in an int.
  if (fix < n) {
    int remainder = static_cast<int>(n - fix);
    kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                          &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
                                                 remainder);
    // Block-uniform condition (fix and n are the same for all threads).
    __syncthreads();
  }
}

template <typename T>
// Host-side driver for the dropout forward pass on GPU. All work is enqueued
// on dev_ctx's stream.
//
// Training (is_test == false):
//   * dropout_prob == 1: y and mask are zero-filled; no kernel is launched;
//   * otherwise VectorizedRandomGenerator<T, uint8_t> is launched through
//     PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL to fill y and the uint8_t mask.
// Inference (is_test == true):
//   * upscale_in_train: y is a copy of x (skipped when y already aliases x);
//   * otherwise:        y = x * (1 - dropout_prob).
//
// Parameters:
//   dropout_implementation    not read by this function.
//   dropout_prob              probability of dropping an element.
//   upscale_in_train          in training, scale kept values by 1/(1-p).
//   is_fix_seed, seed_val, seed
//                             forwarded to GetSeedDataAndIncrement.
//   mask                      uint8_t keep mask; written only in training.
// Empty inputs (x.numel() == 0) return immediately without touching outputs.
void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
                              const std::string dropout_implementation,
                              float dropout_prob, bool upscale_in_train,
                              bool is_fix_seed, int seed_val,
                              const framework::Tensor& x,
                              const framework::Tensor* seed,
                              framework::Tensor* mask, framework::Tensor* y) {
  int64_t x_numel = x.numel();
  // Nothing to do for empty inputs. Returning here also keeps the
  // `x_numel - 1` seed-offset computation below from underflowing and avoids
  // launching a kernel with an empty grid.
  if (x_numel == 0) {
    return;
  }
  auto stream = dev_ctx.stream();
  auto* x_data = x.data<T>();
  auto* y_data = y->data<T>();

  if (!is_test) {
    auto* mask_data = mask->data<uint8_t>();
    size_t size = phi::product(mask->dims());

    if (dropout_prob == 1.0f) {
      // Every element is dropped: zero both outputs instead of sampling.
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#endif
      return;
    }

    uint64_t seed_data;
    uint64_t increment;
    // VectorizedRandomGenerator uses curand_uniform4, so kVecSize is 4.
    constexpr int kVecSize =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
    size_t grid_size = gpu_config.GetGridSize();
    size_t block_size = gpu_config.GetBlockSize();

    // Cap the grid at (max threads per SM * SM count) / block_size blocks;
    // the kernel loops over any work beyond one grid-wide step.
    int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
    const auto& prop = platform::GetDeviceProperties(device_id);
    size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
                           prop.multiProcessorCount / block_size;
    grid_size = std::min(grid_size, max_grid_size);

    // Random numbers drawn per thread: kVecSize per step, times
    // ceil(x_numel / (grid_size * block_size * kVecSize)) steps.
    auto offset =
        ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
    GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
                            &seed_data, &increment);
    // Largest multiple of one block's per-step work that does not exceed
    // `size`. The kernel processes [0, main_offset) without bounds checks.
    size_t main_offset =
        size / (block_size * kVecSize) * (block_size * kVecSize);

// Kernel parameters 1 (seed) and 7 (increment) are handed to the CUDA-graph
// recorder, presumably so they can be refreshed on replay.
#define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator<T, uint8_t>
    PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
        !is_fix_seed, PD_DROPOUT_KERNEL_NAME, grid_size, block_size, 0, stream,
        offset, KERNEL_PARAMS.As<uint64_t>(1), KERNEL_PARAMS.As<uint64_t>(7),
        size, seed_data, dropout_prob, x_data, mask_data, y_data,
        upscale_in_train, increment, main_offset);
#undef PD_DROPOUT_KERNEL_NAME
  } else {
    if (upscale_in_train) {
      // Inference with upscale_in_train is the identity. Skip the copy when y
      // already shares x's buffer: the memcpy APIs do not support
      // overlapping source and destination ranges.
      // TODO: could y share x's buffer instead of copying?
      if (y_data != x_data) {
#ifdef PADDLE_WITH_HIP
        PADDLE_ENFORCE_GPU_SUCCESS(
            hipMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
                           hipMemcpyDeviceToDevice, stream));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(
            cudaMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
                            cudaMemcpyDeviceToDevice, stream));
#endif
      }
    } else {
      // downscale-in-infer: y = x * (1 - dropout_prob).
      using MT = typename details::MPTypeTrait<T>::Type;
      MT factor = static_cast<MT>(1.0f - dropout_prob);
      std::vector<const framework::Tensor*> ins = {&x};
      std::vector<framework::Tensor*> outs = {y};
      auto functor = phi::funcs::ScaleFunctor<T>(factor);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                                &outs, functor);
    }
  }
}

// Elementwise backward functor for dropout: returns dout * mask * scale.
// The product is formed in MT (details::MPTypeTrait<T>::Type) and cast back
// to T.
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
  using MT = typename details::MPTypeTrait<T>::Type;

  explicit CudaDropoutGradFunctor(const MT grad_scale) : scale_(grad_scale) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const MaskType mask) const {
    // Apply the mask first, then the scale, in the same order as
    // dout * mask * scale.
    const MT masked_grad = static_cast<MT>(dout) * static_cast<MT>(mask);
    return static_cast<T>(masked_grad * scale_);
  }

 private:
  MT scale_;
};

// Host-side driver for the dropout backward pass on GPU. All work is enqueued
// on dev_ctx's stream.
//
// is_test == true:
//   grad_x = grad_y * factor, where factor is 1 for "upscale_in_train" and
//   (1 - dropout_prob) for any other implementation.
// is_test == false (mask is read as uint8_t):
//   * "upscale_in_train", dropout_prob == 1: grad_x's first `size` elements
//     are zero-filled;
//   * "upscale_in_train" otherwise: grad_x = grad_y * mask / (1 - dropout_prob);
//   * any other implementation:      grad_x = grad_y * mask.
template <typename T>
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                const std::string dropout_implementation,
                                float dropout_prob,
                                const framework::Tensor& grad_y,
                                const framework::Tensor& mask, int64_t size,
                                framework::Tensor* grad_x,
                                bool is_test = false) {
  using MT = typename details::MPTypeTrait<T>::Type;
  auto stream = dev_ctx.stream();
  MT factor;
  if (is_test) {
    if (dropout_implementation == "upscale_in_train") {
      factor = static_cast<MT>(1.0f);
    } else {
      factor = static_cast<MT>(1.0f - dropout_prob);
    }
    std::vector<const framework::Tensor*> ins = {&grad_y};
    std::vector<framework::Tensor*> outs = {grad_x};
    auto functor = phi::funcs::ScaleFunctor<T>(factor);
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  } else {
    std::vector<const framework::Tensor*> ins = {&grad_y, &mask};
    std::vector<framework::Tensor*> outs = {grad_x};
    if (dropout_implementation == "upscale_in_train") {
      if (dropout_prob == 1.0f) {
        // Every element was dropped, so the gradient is all zeros. Use a
        // stream-ordered memset on the context stream and check its status.
        // A plain cudaMemset would run on the default stream, unordered with
        // respect to the kernels queued on `stream`, and its error would be
        // ignored.
#ifdef PADDLE_WITH_HIP
        PADDLE_ENFORCE_GPU_SUCCESS(
            hipMemsetAsync(grad_x->data<T>(), 0, size * sizeof(T), stream));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(
            cudaMemsetAsync(grad_x->data<T>(), 0, size * sizeof(T), stream));
#endif
      } else {
        factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
        paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
            dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
      }
    } else {
      // downscale-in-infer training: the gradient is just masked.
      factor = static_cast<MT>(1.0f);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
          dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
    }
  }
}

}  // namespace operators
}  // namespace paddle