/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda.h>
#include <curand_kernel.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <random>
#include <string>

#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/dynload/curand.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Applies dropout to the first `n` elements of `src`.
//
// Outputs:
//   mask_data[i] = 1 if element i is kept, 0 if it is dropped.
//   dst[i]       = 0 when dropped. When kept, it is src[i] / (1 - dropout_prob)
//                  if `is_upscale_in_train` is true, otherwise src[i].
//
// An element is dropped when its uniform draw is < dropout_prob.
//
// Launch: 1-D grid of 1-D blocks. The loop advances by
// blockDim.x * gridDim.x (grid-stride), so any grid size covers all `n`
// elements.
//
// Randomness: each element re-initialises a Philox state with
// subsequence = element index. The offset is the element index on a
// thread's first element and the grid stride on later elements. This keeps
// the masks produced for a given seed unchanged.
//
// Indices are 64-bit so tensors with more than INT_MAX elements do not
// overflow the index arithmetic.
//
// With upscaling, dropout_prob == 1 would divide by zero. The host wrapper
// in this file returns early for dropout_prob == 1.0f before launching.
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, const int seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask_data, T* dst,
                                bool is_upscale_in_train) {
  curandStatePhilox4_32_10_t state;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  bool first_element = true;

  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < n; idx += stride) {
    const T s = src[idx];
    // Offset is idx for this thread's first element, the grid stride after.
    curand_init(seed, idx, first_element ? idx : stride, &state);
    first_element = false;

    if (curand_uniform(&state) < dropout_prob) {
      mask_data[idx] = 0;
      dst[idx] = 0;
    } else {
      mask_data[idx] = 1;
      if (is_upscale_in_train) {
        dst[idx] = s / static_cast<T>(1.0f - dropout_prob);
      } else {
        dst[idx] = s;
      }
    }
  }
}
// Same as RandomGenerator, but the seed is read from device memory at
// `seed[0]`. This lets a GPU-resident "Seed" tensor drive the mask without a
// device-to-host copy.
//
// Outputs:
//   mask_data[i] = 1 if element i is kept, 0 if it is dropped.
//   dst[i]       = 0 when dropped. When kept, it is src[i] / (1 - dropout_prob)
//                  if `is_upscale_in_train` is true, otherwise src[i].
//
// Launch: 1-D grid of 1-D blocks, grid-stride loop over `n` elements.
//
// Randomness: Philox state with subsequence = element index. The offset is
// the element index on a thread's first element and the grid stride on
// later elements.
//
// `seed[0]` is loaded once per thread, and only by threads that have at
// least one element to process. Indices are 64-bit.
template <typename T, typename MaskType>
__global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
                                        const float dropout_prob, const T* src,
                                        MaskType* mask_data, T* dst,
                                        bool is_upscale_in_train) {
  curandStatePhilox4_32_10_t state;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  bool first_element = true;
  int seed_value = 0;

  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < n; idx += stride) {
    const T s = src[idx];
    if (first_element) {
      // The seed does not change during the launch; load it only once.
      seed_value = seed[0];
      curand_init(seed_value, idx, idx, &state);
      first_element = false;
    } else {
      curand_init(seed_value, idx, stride, &state);
    }

    if (curand_uniform(&state) < dropout_prob) {
      mask_data[idx] = 0;
      dst[idx] = 0;
    } else {
      mask_data[idx] = 1;
      if (is_upscale_in_train) {
        dst[idx] = s / static_cast<T>(1.0f - dropout_prob);
      } else {
        dst[idx] = s;
      }
    }
  }
}

// Dropout kernel driven by a (seed, offset) pair supplied by the caller.
// The host wrapper in this file passes the pair returned by the default
// CUDA generator's IncrementOffset.
//
// Outputs:
//   mask_data[i] = 1 if element i is kept, 0 if it is dropped.
//   dst[i]       = 0 when dropped. When kept, it is src[i] / (1 - dropout_prob)
//                  if `is_upscale_in_train` is true, otherwise src[i].
//
// Launch: 1-D grid of 1-D blocks, grid-stride loop over `n` elements.
//
// Randomness: every element initialises a Philox state with
// subsequence = element index and offset = `increment`, then draws one
// uniform value. Indices are 64-bit.
template <typename T, typename MaskType>
__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
                                             const float dropout_prob,
                                             const T* src, MaskType* mask_data,
                                             T* dst, bool is_upscale_in_train,
                                             uint64_t increment) {
  curandStatePhilox4_32_10_t state;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < n; idx += stride) {
    const T s = src[idx];
    // Same offset for every element; the per-element subsequence keeps the
    // draws distinct.
    curand_init(seed, idx, increment, &state);

    if (curand_uniform(&state) < dropout_prob) {
      mask_data[idx] = 0;
      dst[idx] = 0;
    } else {
      mask_data[idx] = 1;
      if (is_upscale_in_train) {
        dst[idx] = s / static_cast<T>(1.0f - dropout_prob);
      } else {
        dst[idx] = s;
      }
    }
  }
}

// Eigen::Tensor::setRandom appears to segfault on GPU, so dropout masks are
// drawn with cuRAND (Philox) in the kernels above rather than through Eigen.
// std::random_device is only used on the host, to pick a seed when no other
// seed source applies.
// GPU forward kernel of the dropout operator.
//
// Training (attribute "is_test" == false) writes "Out" and "Mask"
// (uint8: 1 = kept, 0 = dropped). The random seed is taken from the first
// source that applies, in this order:
//   1. a GPU-resident "Seed" input, read on the device;
//   2. a CPU-resident "Seed" input;
//   3. the default CUDA generator, if it was initialised from Python and
//      "fix_seed" is false;
//   4. the "seed" attribute if "fix_seed" is true, else std::random_device.
// dropout_prob == 1 zero-fills both outputs without launching a kernel.
//
// Inference ("is_test" == true):
//   - "upscale_in_train": Out = X.
//   - otherwise:          Out = X * (1 - dropout_prob).
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    // Optional int tensor; only its first element is used as the seed.
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
    auto* y = context.Output<Tensor>("Out");
    y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");

    auto& place = *context.template device_context<Place>().eigen_device();
    if (!context.Attr<bool>("is_test")) {
      int64_t x_numel = x->numel();
      auto stream = context.cuda_device_context().stream();

      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
      size_t size = framework::product(mask->dims());
      auto* x_data = x->data<T>();
      auto* y_data = y->mutable_data<T>(context.GetPlace());

      // Empty input: nothing to compute. The grid size below would be 0,
      // which CUDA rejects as an invalid launch configuration.
      if (x_numel == 0) {
        return;
      }

      // Everything is dropped: zero-fill both outputs. This also avoids the
      // division by (1 - dropout_prob) == 0 in the upscale path.
      if (dropout_prob == 1.0f) {
        PADDLE_ENFORCE_CUDA_SUCCESS(
            cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
        PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
            mask_data, 0, x_numel * sizeof(*mask_data), stream));
        return;
      }

      int threads = 512;
      int grid = (x_numel + threads - 1) / threads;

      // 1. Seed lives on the GPU: let the kernel read it directly.
      if (seed && platform::is_gpu_place(seed->place())) {
        auto seed_gpu_data = seed->data<int>();
        RandomGeneratorWithSeed<T, uint8_t><<<grid, threads, 0, stream>>>(
            size, seed_gpu_data, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train);
        return;
      }

      int seed_data;
      if (seed) {
        // 2. An explicit CPU "Seed" input wins over the global generator.
        seed_data = *(seed->data<int>());
      } else {
        // 3. Python-initialised global generator, unless a fixed seed was
        // requested.
        int device_id =
            BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
                .GetDeviceId();
        auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
        if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
          auto seed_offset = gen_cuda->IncrementOffset(1);
          RandomGeneratorWithGenerator<T, uint8_t>
              <<<grid, threads, 0, stream>>>(
                  size, seed_offset.first, dropout_prob, x_data, mask_data,
                  y_data, upscale_in_train, seed_offset.second);
          return;
        }

        // 4. Fixed attribute seed, or a fresh non-deterministic one.
        if (context.Attr<bool>("fix_seed")) {
          seed_data = context.Attr<int>("seed");
        } else {
          std::random_device rnd;
          seed_data = rnd();
        }
      }

      RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train);
    } else {
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
      if (upscale_in_train) {
        Y.device(place) = X;
      } else {
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward dropout CUDA kernels for float, float16 and double.
REGISTER_OP_CUDA_KERNEL(
    dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);

// Backward kernels for the same element types. DropoutGradKernel is not
// defined in this file; presumably it comes from dropout_op.h (verify).
REGISTER_OP_CUDA_KERNEL(
    dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);