dropout_op.cu 4.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
X
Xinghai Sun 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
X
Xinghai Sun 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
X
Xinghai Sun 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
X
Xinghai Sun 已提交
14 15

#define EIGEN_USE_GPU
16 17 18 19
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
Y
Yi Wang 已提交
20
#include "paddle/fluid/operators/dropout_op.h"
K
Kexin Zhao 已提交
21
#include "paddle/fluid/platform/float16.h"
X
Xinghai Sun 已提交
22

23 24 25
namespace paddle {
namespace operators {

K
Kexin Zhao 已提交
26
/**
 * Applies dropout element-wise on the device.
 *
 * For each element idx in [0, n) a float is drawn from
 * thrust::uniform_real_distribution<float>(0, 1) using a thrust::minstd_rand
 * engine seeded with `seed` and positioned with discard(). If the draw is
 * below `dropout_prob` the mask is 0, otherwise it is the "keep" value.
 * Writes mask_data[idx] = mask and dst[idx] = src[idx] * mask.
 *
 * @param n                       number of elements in src, mask_data, dst.
 * @param seed                    seed for every thread's random engine.
 * @param dropout_prob            probability that an element is dropped.
 * @param src                     input values, n elements.
 * @param mask_data               output mask, n elements.
 * @param dst                     output values, n elements.
 * @param dropout_implementation  when true the keep value is
 *                                1 / (1 - dropout_prob); otherwise 1.
 *
 * Launch: 1-D grid of 1-D blocks. Any grid size is correct because each
 * thread handles idx, idx + stride, idx + 2 * stride, ... with
 * stride = blockDim.x * gridDim.x. The random values drawn depend on the
 * launch configuration because the stride is passed to rng.discard().
 */
template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed,
                                const float dropout_prob, const T* src,
                                T* mask_data, T* dst,
                                bool dropout_implementation) {
  thrust::minstd_rand rng;
  rng.seed(seed);
  thrust::uniform_real_distribution<float> dist(0, 1);

  // Mask value for kept elements; it does not depend on the element.
  const T keep_mask = dropout_implementation
                          ? static_cast<T>(1.0f / (1.0f - dropout_prob))
                          : static_cast<T>(1);

  // 64-bit index and stride: an int index overflows for n > INT_MAX and the
  // comparison against the size_t bound would mix signedness.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  bool first_iteration = true;
  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < n; idx += stride) {
    // Position the engine: skip `idx` values before the first draw, then
    // `stride` values before each later draw.
    if (first_iteration) {
      rng.discard(idx);
      first_iteration = false;
    } else {
      rng.discard(stride);
    }
    const T mask = dist(rng) < dropout_prob ? static_cast<T>(0) : keep_mask;
    mask_data[idx] = mask;
    dst[idx] = src[idx] * mask;
  }
}
62 63 64 65

// Eigen::Tensor::setRandom appears to segfault on the GPU, so uniform random
// numbers are generated instead with std::random_device (for the seed) and
// thrust::random (Thrust ships with the CUDA toolkit).
K
Kexin Zhao 已提交
66
/**
 * GPU implementation of the dropout forward operator.
 *
 * Reads input "X", writes output "Out" (and "Mask" when training), and reads
 * the attributes "dropout_prob", "dropout_implementation", "is_test",
 * "fix_seed" and "seed".
 *
 * Training (is_test == false): Mask and Out are filled by RandomGenerator on
 * the CUDA device context's stream. The seed is the "seed" attribute when
 * "fix_seed" is set, otherwise a value from std::random_device.
 * Inference (is_test == true): Out = X when dropout_implementation is set,
 * otherwise Out = X * (1 - dropout_prob); Mask is not touched.
 */
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Output<Tensor>("Out");
    // Allocate Out once; the training branch writes through this pointer.
    auto* y_data = y->mutable_data<T>(context.GetPlace());

    float dropout_prob = context.Attr<float>("dropout_prob");
    auto dropout_implementation = context.Attr<bool>("dropout_implementation");

    if (!context.Attr<bool>("is_test")) {
      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
      size_t size = framework::product(mask->dims());
      auto* x_data = x->data<T>();

      // An empty tensor needs no work, and launching a kernel with a grid of
      // zero blocks is an invalid CUDA launch configuration.
      if (size == 0) {
        return;
      }

      std::random_device rnd;
      int seed =
          context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();

      // The grid is derived from `size`, the same element count handed to the
      // kernel, so the launch and the kernel's bound agree.
      const int threads = 512;
      const int grid = static_cast<int>((size + threads - 1) / threads);
      RandomGenerator<
          T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
          size, seed, dropout_prob, x_data, mask_data, y_data,
          dropout_implementation);
    } else {
      auto& place = *context.template device_context<Place>().eigen_device();
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
      if (dropout_implementation) {
        // Kept elements were already scaled by 1 / (1 - p) during training.
        Y.device(place) = X;
      } else {
        // Training kept elements unscaled; scale by the keep probability.
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

X
Xinghai Sun 已提交
109
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward dropout kernels for CUDA devices: float, float16 and double.
REGISTER_OP_CUDA_KERNEL(
    dropout,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);

// Backward dropout kernels for CUDA devices: float and double.
// NOTE(review): unlike the forward op, no float16 gradient kernel is
// registered here — confirm whether that is intentional.
REGISTER_OP_CUDA_KERNEL(
    dropout_grad,
    ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);