dropout_op.cu 4.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15 16 17
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
P
phlrain 已提交
18
#include <string>
Y
Yi Wang 已提交
19
#include "paddle/fluid/operators/dropout_op.h"
K
Kexin Zhao 已提交
20
#include "paddle/fluid/platform/float16.h"
X
Xinghai Sun 已提交
21

22 23 24
namespace paddle {
namespace operators {

K
Kexin Zhao 已提交
25
// Applies dropout element-wise: writes the per-element scale factor to
// `mask_data[i]` and `src[i] * mask_data[i]` to `dst[i]` for i in [0, n).
//
// Each element draws one sample from thrust::uniform_real_distribution<float>
// over [0, 1) using a thrust::minstd_rand engine seeded with `seed`. The
// element is dropped (mask = 0) when the sample is below `dropout_prob`.
// Kept elements get mask 1 / (1 - dropout_prob) when `is_upscale_in_train`
// is true, and mask 1 otherwise. A `dropout_prob` of 1 or more drops every
// element.
//
// Launch: 1-D grid of 1-D blocks, no shared memory. The loop advances by
// blockDim.x * gridDim.x, so any grid size covers all `n` elements.
//
// Random stream layout: every thread seeds an identical engine, skips `idx`
// values before its first draw and `stride` values between later draws. The
// sample for an element therefore depends only on the seed, the launch shape
// and the element index, which keeps fixed-seed runs reproducible.
template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed,
                                const float dropout_prob, const T* src,
                                T* mask_data, T* dst,
                                bool is_upscale_in_train) {
  thrust::minstd_rand rng;
  rng.seed(seed);
  thrust::uniform_real_distribution<float> dist(0, 1);

  // Checked explicitly because a float sample can round up to exactly 1.0f;
  // without this, p == 1 could occasionally keep an element with scale 1/0.
  const bool drop_all = dropout_prob >= 1.0f;

  // Scale for kept elements is loop-invariant, so compute it once. When
  // drop_all is set it is never applied, so avoid dividing by zero.
  const T keep_scale = (is_upscale_in_train && !drop_all)
                           ? static_cast<T>(1.0f / (1.0f - dropout_prob))
                           : static_cast<T>(1);

  // size_t indexing avoids 32-bit overflow for tensors with >= 2^31 elements.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  bool first_draw = true;
  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < n; idx += stride) {
    const T s = src[idx];
    // Position the engine on the value reserved for `idx`.
    if (first_draw) {
      rng.discard(idx);
      first_draw = false;
    } else {
      rng.discard(stride);
    }
    const float sample = dist(rng);
    const T mask =
        (drop_all || sample < dropout_prob) ? static_cast<T>(0) : keep_scale;
    mask_data[idx] = mask;
    dst[idx] = s * mask;
  }
}
61 62 63 64

// Eigen::Tensor::setRandom appears to segfault on the GPU, so uniform random
// numbers are generated with std::random_device (for the seed) and
// thrust::random (Thrust ships with the CUDA toolkit) instead.
// CUDA forward kernel for the dropout operator.
//
// Training ("is_test" is false): launches RandomGenerator to sample a
// keep/drop mask on the device. The mask is written to output "Mask" and
// X * Mask to "Out". The seed is the "seed" attribute when "fix_seed" is set,
// otherwise a value drawn from std::random_device.
//
// Inference ("is_test" is true): no mask is produced. With
// "dropout_implementation" == "upscale_in_train" the input is copied
// unchanged; otherwise it is scaled by (1 - dropout_prob).
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Output<Tensor>("Out");
    auto* y_data = y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

    const bool upscale_in_train =
        context.Attr<std::string>("dropout_implementation") ==
        "upscale_in_train";

    if (!context.Attr<bool>("is_test")) {
      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
      size_t size = framework::product(mask->dims());
      auto* x_data = x->data<T>();

      std::random_device rnd;
      int seed =
          context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();

      // A launch with an empty grid is an invalid configuration; outputs are
      // already allocated, so there is nothing left to do for empty input.
      if (size == 0) {
        return;
      }

      int threads = 512;
      // Size the grid from the same element count the kernel iterates over.
      int grid = static_cast<int>((size + threads - 1) / threads);
      RandomGenerator<
          T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
          size, seed, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train);
    } else {
      auto& place = *context.template device_context<Place>().eigen_device();
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
      if (upscale_in_train) {
        // Training already rescaled kept activations; inference is identity.
        Y.device(place) = X;
      } else {
        // Match the expected value of the training-time mask.
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

X
Xinghai Sun 已提交
109
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward "dropout" CUDA kernels, one instantiation per supported element
// type: float, float16 and double.
REGISTER_OP_CUDA_KERNEL(
    dropout,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);

// Backward "dropout_grad" CUDA kernels for float and double. DropoutGradKernel
// is not defined in this file (presumably in dropout_op.h).
// NOTE(review): no float16 grad kernel is registered although the forward op
// registers one — confirm whether that is intentional.
REGISTER_OP_CUDA_KERNEL(
    dropout_grad,
    ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);