/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#include "paddle/fluid/platform/dynload/curand.h"
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#include "paddle/fluid/platform/dynload/hiprand.h"
#endif
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <algorithm>
#include <cstdint>
#include <string>

#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Scalar dropout kernel.
//
// For every element idx in [0, n) one uniform random number u is drawn from a
// per-thread Philox generator:
//   * u <  dropout_prob: mask_data[idx] = 0, dst[idx] = 0;
//   * u >= dropout_prob: mask_data[idx] = 1, dst[idx] = src[idx], divided by
//                        (1 - dropout_prob) when is_upscale_in_train is true.
//
// Each thread seeds its generator with (seed, subsequence = global thread id,
// offset = increment). The caller must therefore advance `increment` between
// launches that share a seed, or the same random numbers are reused.
//
// Launch: any 1-D grid. Each thread steps through the data with a stride of
// blockDim.x * gridDim.x, so all n elements are covered for any grid size.
// Precondition: dropout_prob < 1 when is_upscale_in_train, because the kept
// branch divides by (1 - dropout_prob). The host code in this file returns
// before launching when dropout_prob == 1.
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, uint64_t seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask_data, T* dst,
                                bool is_upscale_in_train, uint64_t increment) {
  // Use 64-bit indexing. A 32-bit `int` index would overflow (undefined
  // behavior) for tensors with more than INT_MAX elements, and comparing it
  // against the unsigned `n` mixes signedness.
  const int64_t thread_id =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, thread_id, increment, &state);
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, thread_id, increment, &state);
#endif

  const int64_t numel = static_cast<int64_t>(n);
  for (int64_t idx = thread_id; idx < numel; idx += stride) {
    const T s = src[idx];
#ifdef PADDLE_WITH_HIP
    const float rand = hiprand_uniform(&state);
#else
    const float rand = curand_uniform(&state);
#endif
    MaskType mask;
    T dest;
    if (rand < dropout_prob) {
      // Dropped element.
      mask = 0;
      dest = 0;
    } else {
      // Kept element, optionally rescaled so the expected value is unchanged.
      mask = 1;
      if (is_upscale_in_train) {
        dest = s / static_cast<T>(1.0f - dropout_prob);
      } else {
        dest = s;
      }
    }
    mask_data[idx] = mask;
    dst[idx] = dest;
  }
}

// Vectorized dropout kernel.
//
// Element-wise semantics match the scalar RandomGenerator kernel: dropped
// elements get mask 0 and output 0, kept elements get mask 1 and the input
// value, scaled by 1 / (1 - dropout_prob) when is_upscale_in_train is true.
// Each loop iteration handles VecSize consecutive elements with:
//   * one AlignedVector load from `src`,
//   * one AlignedVector store to `dst` and one to `mask_data`,
//   * one float4 of uniforms from curand_uniform4 / hiprand_uniform4.
//
// Preconditions, not checked here (the host only launches this kernel when
// size % 4 == 0 and VectorizedSize<T>(x_data) == 4):
//   * n % VecSize == 0, since there is no scalar tail loop;
//   * src, dst and mask_data are aligned for AlignedVector<., VecSize>.
// VecSize must be in [1, 4] because only four random numbers are drawn per
// iteration; the static_assert below enforces this.
template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask_data,
                                          T* dst, bool is_upscale_in_train,
                                          uint64_t increment) {
  static_assert(VecSize >= 1 && VecSize <= 4,
                "VectorizedRandomGenerator draws one float4 of random numbers "
                "per iteration, so VecSize must be in [1, 4]");

  // blockIdx/blockDim/threadIdx are available under both nvcc and hipcc.
  // Widen to 64 bits before multiplying so the product cannot wrap.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif

  using LoadT = AlignedVector<T, VecSize>;
  using MaskLoadT = AlignedVector<MaskType, VecSize>;
  const T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));

  // Use a 64-bit loop index and stride so that neither the start offset
  // (idx * VecSize) nor the per-iteration step is truncated or overflows for
  // very large tensors.
  const int64_t numel = static_cast<int64_t>(n);
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  for (int64_t i = idx * VecSize; i < numel; i += stride) {
    // Give the local staging buffers the same alignment as the vector types
    // they are reinterpreted as, so the wide load/store is well-aligned.
    alignas(LoadT) T src_vec[VecSize];
    *reinterpret_cast<LoadT*>(&src_vec[0]) =
        *reinterpret_cast<const LoadT*>(&src[i]);

#ifdef PADDLE_WITH_HIP
    const float4 rand = hiprand_uniform4(&state);
#else
    const float4 rand = curand_uniform4(&state);
#endif
    // Copy the four lanes into an array. The language does not guarantee
    // that indexing past `rand.x` through a pointer reaches y/z/w.
    const float rand_vals[4] = {rand.x, rand.y, rand.z, rand.w};

    alignas(LoadT) T dest_vec[VecSize];
    alignas(MaskLoadT) MaskType mask_vec[VecSize];

#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      if (rand_vals[ii] < dropout_prob) {
        // Dropped element.
        dest_vec[ii] = 0;
        mask_vec[ii] = 0;
      } else {
        // Kept element, optionally rescaled.
        if (is_upscale_in_train) {
          dest_vec[ii] = src_vec[ii] * factor;
        } else {
          dest_vec[ii] = src_vec[ii];
        }
        mask_vec[ii] = 1;
      }
    }

    *reinterpret_cast<LoadT*>(&dst[i]) =
        *reinterpret_cast<const LoadT*>(&dest_vec[0]);
    *reinterpret_cast<MaskLoadT*>(&mask_data[i]) =
        *reinterpret_cast<const MaskLoadT*>(&mask_vec[0]);
  }
}

// Eigen::Tensor::setRandom appears to segfault on the GPU. The uniform random
// numbers are therefore generated directly inside the device kernels above,
// using the cuRAND/hipRAND Philox generator. On the host, std::random_device
// may be used only to pick a seed when no seed is supplied.
// GPU forward kernel of the dropout operator.
//
// Training (attribute `is_test` == false):
//   * each element of X is zeroed with probability `dropout_prob`;
//   * Out receives the result;
//   * Mask (uint8) records 1 for kept elements and 0 for dropped ones.
//   With dropout_implementation == "upscale_in_train", kept elements are
//   divided by (1 - dropout_prob).
// Inference (`is_test` == true): no random numbers are drawn and Mask is not
// written. Out = X for "upscale_in_train"; otherwise Out = X * (1 - p).
//
// Seed selection in training, checked in this order:
//   1. a "Seed" input tensor on the GPU: its first int, copied to the host;
//   2. the default CUDA generator of this device, when it was initialized
//      from Python and `fix_seed` is false (it also supplies the offset);
//   3. otherwise: a host "Seed" input tensor; else the `seed` attribute if
//      `fix_seed` is true; else std::random_device.
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
    auto* y = context.Output<Tensor>("Out");
    y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");

    auto& place = *context.template device_context<Place>().eigen_device();
    if (!context.Attr<bool>("is_test")) {
      int64_t x_numel = x->numel();
      auto stream = context.cuda_device_context().stream();

      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
      size_t size = framework::product(mask->dims());
      auto* x_data = x->data<T>();
      auto* y_data = y->mutable_data<T>(context.GetPlace());

      // Empty input: both outputs are already allocated and there is nothing
      // to fill. Returning here also avoids requesting a 1-D launch config,
      // and launching kernels, for zero elements.
      // NOTE(review): assumes GetGpuLaunchConfig1D rejects a zero element
      // count; confirm against its definition.
      if (x_numel == 0) {
        return;
      }

      // Every element is dropped: zero both outputs without drawing random
      // numbers. This also keeps the kernels from dividing by
      // (1 - dropout_prob) == 0.
      if (dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
        PADDLE_ENFORCE_CUDA_SUCCESS(
            hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
        PADDLE_ENFORCE_CUDA_SUCCESS(
            hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#else
        PADDLE_ENFORCE_CUDA_SUCCESS(
            cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
        PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
            mask_data, 0, x_numel * sizeof(*mask_data), stream));
#endif
        return;
      }

      const auto& dev_ctx = context.cuda_device_context();
      platform::GpuLaunchConfig config =
          platform::GetGpuLaunchConfig1D(dev_ctx, size);

      // increment is used to set the args(offset) of curand_init, which defines
      // offset in subsequence.
      // The detail:
      // https://docs.nvidia.com/cuda/curand/device-api-overview.html
      // Increment should be at least the number of curand() random numbers used
      // in each thread to avoid the random number generated this time being the
      // same as the previous calls.
      uint64_t seed_data;
      uint64_t increment;
      int vec_size = VectorizedSize<T>(x_data);
      // Per-thread count of random values reserved through `increment`:
      // ceil(x_numel / (threads_in_grid * vec_size)) * vec_size.
      auto offset = ((x_numel - 1) / (config.block_per_grid.x *
                                      config.thread_per_block.x * vec_size) +
                     1) *
                    vec_size;
      int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
                          .GetDeviceId();
      auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

      if (seed && platform::is_gpu_place(seed->place())) {
        // A GPU-resident seed tensor must be copied to the host to be read.
        framework::Tensor seed_cpu_tensor;
        TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
        seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
        increment = offset;
      } else if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
        // The global generator hands out (seed, offset) and advances its own
        // offset by `offset`, so successive calls draw fresh numbers.
        auto seed_offset = gen_cuda->IncrementOffset(offset);
        seed_data = seed_offset.first;
        increment = seed_offset.second;
      } else {
        if (seed) {
          seed_data = *(seed->data<int>());
        } else {
          std::random_device rnd;
          seed_data = context.Attr<bool>("fix_seed") ? context.Attr<int>("seed")
                                                     : rnd();
        }
        increment = offset;
      }

      // Use the vectorized kernel only when x_data supports 4-wide access and
      // the element count is a multiple of 4. That kernel has no scalar tail.
#ifdef __HIPCC__
      if (vec_size == 4 && size % 4 == 0) {
        hipLaunchKernelGGL(
            HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
            config.block_per_grid, config.thread_per_block, 0, stream, size,
            seed_data, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train, increment);
      } else {
        hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
                           config.block_per_grid, config.thread_per_block, 0,
                           stream, size, seed_data, dropout_prob, x_data,
                           mask_data, y_data, upscale_in_train, increment);
      }
#else
      if (vec_size == 4 && size % 4 == 0) {
        VectorizedRandomGenerator<
            T, uint8_t,
            4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
            size, seed_data, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train, increment);
      } else {
        RandomGenerator<T, uint8_t><<<config.block_per_grid,
                                      config.thread_per_block, 0, stream>>>(
            size, seed_data, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train, increment);
      }
#endif
    } else {
      // Inference: deterministic pass-through, or scaling by the keep
      // probability, depending on the implementation mode.
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
      if (upscale_in_train) {
        Y.device(place) = X;
      } else {
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward kernels of the `dropout` op for float, float16 and double.
REGISTER_OP_CUDA_KERNEL(
    dropout,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);

// Backward kernels of `dropout_grad` for the same element types.
// DropoutGradKernel is not defined in this file; presumably it comes from
// dropout_op.h.
REGISTER_OP_CUDA_KERNEL(
    dropout_grad,
    ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);