dropout_impl.cu.h 11.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#include "paddle/fluid/platform/dynload/curand.h"
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#include "paddle/fluid/platform/dynload/hiprand.h"
#endif

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
S
sneaxiy 已提交
33
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
34
#include "paddle/fluid/operators/dropout_impl_util.h"
35
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
H
hong 已提交
36
#include "paddle/fluid/platform/aligned_vector.h"
37
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
H
hong 已提交
38
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
39
#include "paddle/phi/kernels/funcs/distribution_helper.h"
40
#include "paddle/phi/kernels/funcs/functors.h"
41

42 43
namespace paddle {
namespace operators {
// Element-wise dropout functor applied per thread by VectorizedRandomGenerator.
//
// For each lane i in [0, kCount) it compares the uniform random value rand[i]
// with retain_prob_ and writes two results into `dst`:
//   dst[0 .. kCount-1]        : dropout output -- src_val[i] (multiplied by
//                               1 / retain_prob_ when is_upscale_in_train_),
//                               or 0 when the element is dropped;
//   dst[kCount .. 2*kCount-1] : keep mask -- 1 when kept, 0 when dropped.
// The upscale multiplication is done in MT = details::MPTypeTrait<T1>::Type
// (presumably a wider type for low-precision T1 -- confirm in the trait).
template <typename T1, typename T2 = T1, typename OutT = T1>
struct DstMaskFunctor {
  const float retain_prob_;
  const bool is_upscale_in_train_;
  using MT = typename details::MPTypeTrait<T1>::Type;
  // 1 / retain_prob_; only applied when is_upscale_in_train_ is true.
  MT factor;

  HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                   const bool is_upscale_in_train)
      : retain_prob_(retain_prob),
        is_upscale_in_train_(is_upscale_in_train),
        factor(static_cast<MT>(1.0f / retain_prob)) {}

  // `dst` must provide 2 * kCount slots; `src_val` and `rand` provide kCount.
  // `num` is not read by this functor.
  // Values are cast straight to OutT (the type `dst` points to) so no
  // intermediate conversion through T1 happens when OutT differs from T1.
  HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
                                    const T2* rand, int num) const {
    static constexpr int kCount =
        phi::funcs::uniform_distribution<T2>::kReturnsCount;
#pragma unroll
    for (int i = 0; i < kCount; i++) {
      if (rand[i] < retain_prob_) {
        dst[i] = is_upscale_in_train_
                     ? static_cast<OutT>(static_cast<MT>(src_val[i]) * factor)
                     : static_cast<OutT>(src_val[i]);
        dst[i + kCount] = static_cast<OutT>(1);
      } else {
        dst[i] = static_cast<OutT>(0);
        dst[i + kCount] = dst[i];
      }
    }
  }
};

// Dropout forward kernel: writes the dropout output to `dst` and the keep
// mask (1 = kept, 0 = dropped, converted to MaskType) to `mask`.
//
// Work split: each loop iteration hands kps::ReadData / kps::WriteData a tile
// of deal_size = BLOCK_NUM_X * kCount elements starting at `fix`. The block's
// first tile starts at BLOCK_ID_X * deal_size, and tiles advance by
// BLOCK_NUM_X * GRID_NUM_X * kCount. Tiles below `main_offset` use the
// unchecked (false) Read/Write variants; the tail uses the bounds-checked
// (true) variants.
//
// Preconditions (as arranged by DropoutFwGPUKernelDriver): n is the number of
// elements, and main_offset is n rounded down to a multiple of
// BLOCK_NUM_X * kCount.
//
// RNG: each thread owns a Philox4_32_10 state initialised with `seed`,
// subsequence = global thread index, and offset = `increment`.
template <typename T, typename MaskType>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask, T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment,
                                          size_t main_offset) {
  static constexpr int kCount =
      phi::funcs::uniform_distribution<float>::kReturnsCount;
  // Index of this block's first thread. Widen before multiplying so the
  // product is not formed in 32-bit arithmetic.
  size_t idx = static_cast<size_t>(BLOCK_ID_X) * BLOCK_NUM_X;
  // Elements advanced by the whole grid per loop iteration.
  size_t stride = static_cast<size_t>(BLOCK_NUM_X) * GRID_NUM_X * kCount;

#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, increment, &state);
  using SType = curandStatePhilox4_32_10_t;
#endif

  T dst_mask[kCount * 2];  // [0, kCount): dst; [kCount, 2 * kCount): mask
  float rands[kCount];
  MaskType mask_result[kCount];
  using Rand = phi::funcs::uniform_distribution<float>;
  using Cast = kps::IdentityFunctor<T>;
  int deal_size = BLOCK_NUM_X * kCount;

  size_t fix = idx * kCount;

  auto dst_functor =
      DstMaskFunctor<T, float>(1.0f - dropout_prob, is_upscale_in_train);
  // Full tiles: no bounds checks needed.
  for (; fix < main_offset; fix += stride) {
    kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                          &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
                                                  deal_size);
    // `fix` depends only on the block index, so this condition is uniform
    // across the block and every thread reaches the barrier.
    // NOTE(review): no shared memory is used here; the barrier's purpose is
    // not visible from this file.
    if (fix > idx * kCount + 1) {
      __syncthreads();
    }
  }

  // Tail: with main_offset set as described above, at most one block still
  // has fewer than deal_size elements left. Compare before subtracting so a
  // `fix` that has run past `n` never produces a wrapped size_t difference.
  if (fix < n) {
    int remainder = static_cast<int>(n - fix);
    kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
    kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                          &state);
    // dst
    kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
        &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
    kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
    // mask
    kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
        &mask_result[0], &dst_mask[kCount], Cast());
    kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
                                                 remainder);
    __syncthreads();
  }
}

// Host-side driver for the dropout forward pass on GPU.
//
// Training (!is_test):
//   - dropout_prob == 1: y and mask are zero-filled with stream-ordered
//     memsets, and the function returns without drawing random numbers.
//   - otherwise: VectorizedRandomGenerator<T, uint8_t> is launched on
//     dev_ctx.stream() to fill y and mask.
// Inference (is_test):
//   - upscale_in_train: y is a device-to-device copy of x.
//   - otherwise: y = x * (1 - dropout_prob).
// `dropout_implementation` is not read here; the mode is taken from
// `upscale_in_train`.
template <typename T>
void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
                              const std::string dropout_implementation,
                              float dropout_prob, bool upscale_in_train,
                              bool is_fix_seed, int seed_val,
                              const framework::Tensor& x,
                              const framework::Tensor* seed,
                              framework::Tensor* mask, framework::Tensor* y) {
  int64_t x_numel = x.numel();
  auto stream = dev_ctx.stream();
  auto* x_data = x.data<T>();
  auto* y_data = y->data<T>();

  if (!is_test) {
    auto* mask_data = mask->data<uint8_t>();
    size_t size = phi::product(mask->dims());

    if (dropout_prob == 1.0f) {
      // Everything is dropped: output and mask are all zeros.
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#endif
      return;
    }

    uint64_t seed_data;
    uint64_t increment;
    // VectorizedRandomGenerator uses curand_uniform4, so kVecSize is 4.
    constexpr int kVecSize =
        phi::funcs::uniform_distribution<float>::kReturnsCount;
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
    size_t grid_size = gpu_config.GetGridSize();
    size_t block_size = gpu_config.GetBlockSize();

    // Cap the grid at the number of threads the device can keep resident;
    // the kernel's loop advances by a grid-wide stride to cover the rest.
    int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
    const auto& prop = platform::GetDeviceProperties(device_id);
    size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
                           prop.multiProcessorCount / block_size;
    grid_size = std::min(grid_size, max_grid_size);

    // Random values each thread draws: ceil(x_numel / threads / kVecSize)
    // loop iterations, kVecSize values per iteration. Passed to
    // GetSeedDataAndIncrement, which presumably reserves that many Philox
    // offsets from the generator -- confirm in dropout_impl_util.h.
    auto offset =
        ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
    GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
                            &seed_data, &increment);
    // Elements covered by full (unchecked) tiles of block_size * kVecSize.
    size_t main_offset =
        size / (block_size * kVecSize) * (block_size * kVecSize);

    // KERNEL_PARAMS indices 1 and 7 are the kernel's `seed` and `increment`
    // parameters (see VectorizedRandomGenerator's signature); presumably the
    // macro refreshes them when a captured CUDA graph is replayed.
    PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
        !is_fix_seed, (VectorizedRandomGenerator<T, uint8_t>), grid_size,
        block_size, 0, stream, offset, KERNEL_PARAMS.As<uint64_t>(1),
        KERNEL_PARAMS.As<uint64_t>(7), size, seed_data, dropout_prob, x_data,
        mask_data, y_data, upscale_in_train, increment, main_offset);
  } else {
    if (upscale_in_train) {
// TODO: can y share its data with x directly instead of copying?
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(
          hipMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
                         hipMemcpyDeviceToDevice, stream));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemcpyAsync(y_data, x_data, sizeof(T) * x_numel,
                          cudaMemcpyDeviceToDevice, stream));
#endif
    } else {
      // Inference without upscaling: scale by the keep probability.
      using MT = typename details::MPTypeTrait<T>::Type;
      MT factor = static_cast<MT>(1.0f - dropout_prob);
      std::vector<const framework::Tensor*> ins = {&x};
      std::vector<framework::Tensor*> outs = {y};
      auto functor = phi::funcs::ScaleFunctor<T>(factor);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                                &outs, functor);
    }
  }
}

// Element-wise dropout backward functor:
//   grad = dout * mask * factor
// The product is formed in MT (details::MPTypeTrait<T>::Type) and cast back to
// T on return.
template <typename T, typename MaskType>
struct CudaDropoutGradFunctor {
  using MT = typename details::MPTypeTrait<T>::Type;

  // `factor` multiplies every element; DropoutGradGPUKernelDriver passes
  // 1 / (1 - dropout_prob) for upscale_in_train and 1 otherwise.
  explicit CudaDropoutGradFunctor(const MT factor) : scale_(factor) {}

  __device__ __forceinline__ T operator()(const T dout,
                                          const MaskType mask) const {
    const MT upstream = static_cast<MT>(dout);
    const MT kept = static_cast<MT>(mask);
    const MT result = upstream * kept * scale_;
    return static_cast<T>(result);
  }

 private:
  MT scale_;
};

// Host-side driver for the dropout backward pass on GPU.
//
// Inference (is_test):
//   grad_x = grad_y * factor, where factor is 1 for "upscale_in_train" and
//   (1 - dropout_prob) otherwise.
// Training:
//   - "upscale_in_train" with dropout_prob == 1: grad_x is zero-filled
//     (every element was dropped).
//   - "upscale_in_train" otherwise: grad_x = grad_y * mask / (1 - dropout_prob).
//   - any other implementation: grad_x = grad_y * mask.
// `size` is the element count of grad_x used for the zero fill.
// All work is enqueued on dev_ctx.stream().
template <typename T>
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                const std::string dropout_implementation,
                                float dropout_prob,
                                const framework::Tensor& grad_y,
                                const framework::Tensor& mask, int64_t size,
                                framework::Tensor* grad_x,
                                bool is_test = false) {
  using MT = typename details::MPTypeTrait<T>::Type;
  auto stream = dev_ctx.stream();
  MT factor;
  if (is_test) {
    if (dropout_implementation == "upscale_in_train") {
      factor = static_cast<MT>(1.0f);
    } else {
      factor = static_cast<MT>(1.0f - dropout_prob);
    }
    std::vector<const framework::Tensor*> ins = {&grad_y};
    std::vector<framework::Tensor*> outs = {grad_x};
    auto functor = phi::funcs::ScaleFunctor<T>(factor);
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  } else {
    std::vector<const framework::Tensor*> ins = {&grad_y, &mask};
    std::vector<framework::Tensor*> outs = {grad_x};
    if (dropout_implementation == "upscale_in_train") {
      if (dropout_prob == 1.0f) {
        // Zero-fill on the context's stream so the memset is ordered with
        // the surrounding kernels; a plain cudaMemset runs on the legacy
        // default stream, which does not synchronise with non-blocking
        // streams. Errors are checked like the forward driver's memsets.
#ifdef PADDLE_WITH_HIP
        PADDLE_ENFORCE_GPU_SUCCESS(
            hipMemsetAsync(grad_x->data<T>(), 0, size * sizeof(T), stream));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(
            cudaMemsetAsync(grad_x->data<T>(), 0, size * sizeof(T), stream));
#endif
      } else {
        factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
        paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
            dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
      }
    } else {
      factor = static_cast<MT>(1.0f);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
          dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
    }
  }
}

}  // namespace operators
}  // namespace paddle