/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include <random>

#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>

#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Training-mode dropout kernel.
//
// For every element idx in [0, n): draws u from
// thrust::uniform_real_distribution<float>(0, 1) using a thrust::minstd_rand
// engine seeded with `seed` and advanced by `idx` positions, then writes
//   mask_data[idx] = (u < dropout_prob) ? 0 : 1
//   dst[idx]       = mask_data[idx] * src[idx]
//
// The engine is re-seeded for every element, so each element's draw depends
// only on (seed, idx). The loop advances by the total number of launched
// threads, so every grid size covers all n elements and yields the same result.
//
// Params:
//   n            - number of elements in src, mask_data and dst
//   seed         - RNG seed shared by all threads
//   dropout_prob - threshold below which an element is dropped (zeroed)
//   src          - input values (device pointer)
//   mask_data    - output 0/1 mask (device pointer)
//   dst          - output values (device pointer)
template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed,
                                const float dropout_prob, const T* src,
                                T* mask_data, T* dst) {
  thrust::minstd_rand rng;
  thrust::uniform_real_distribution<float> dist(0, 1);

  // 64-bit indexing: n is size_t and may exceed the range of int.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < n; idx += stride) {
    // discard() advances relative to the engine's current state. Re-seeding
    // per element keeps element idx at stream position idx instead of an
    // offset accumulated over earlier loop iterations.
    rng.seed(seed);
    rng.discard(idx);

    const T mask =
        dist(rng) < dropout_prob ? static_cast<T>(0) : static_cast<T>(1);
    mask_data[idx] = mask;
    dst[idx] = mask * src[idx];
  }
}

// Eigen::Tensor::setRandom was observed to SEGFAULT on GPU, so the training
// mask is generated with std::random_device (for an unfixed seed) and
// thrust::random (the RNG facilities of the Thrust library shipped with the
// CUDA toolkit) instead.
//
// GPU kernel of the dropout operator.
//   is_test == false: Out = Mask * X, where Mask[i] is 0 or 1 as produced by
//                     RandomGenerator with the given dropout_prob and seed.
//   is_test == true:  Out = X * (1 - dropout_prob); Mask is not written.
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Output<Tensor>("Out");
    auto* y_data = y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

    if (!context.Attr<bool>("is_test")) {
      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
      size_t size = framework::product(mask->dims());

      // A launch with zero blocks is an invalid configuration, and an empty
      // tensor needs no work.
      if (size == 0) return;

      auto* x_data = x->data<T>();

      // Only touch std::random_device when a random seed is actually needed.
      int seed;
      if (context.Attr<bool>("fix_seed")) {
        seed = context.Attr<int>("seed");
      } else {
        std::random_device rnd;
        seed = rnd();
      }

      // One element per thread. The grid is derived from the same `size`
      // that bounds the kernel's loop, computed in size_t arithmetic before
      // narrowing to the launch parameter.
      const int threads = 512;
      const int grid = static_cast<int>((size + threads - 1) / threads);
      RandomGenerator<
          T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
          size, seed, dropout_prob, x_data, mask_data, y_data);
    } else {
      auto& place = *context.template device_context<Place>().eigen_device();
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
      // Inference: scale by the keep probability instead of sampling a mask.
      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward dropout on CUDA devices: float and float16 element types.
REGISTER_OP_CUDA_KERNEL(
    dropout,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);

// Backward dropout on CUDA devices: float element type only.
REGISTER_OP_CUDA_KERNEL(
    dropout_grad,
    ops::DropoutGradKernel<plat::CUDADeviceContext, float>);