/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
P
phlrain 已提交
18
#include <string>
Y
Yi Wang 已提交
19
#include "paddle/fluid/operators/dropout_op.h"
K
Kexin Zhao 已提交
20
#include "paddle/fluid/platform/float16.h"
X
Xinghai Sun 已提交
21

22 23 24
namespace paddle {
namespace operators {

// Device kernel that applies dropout to the first `n` elements of `src`.
//
// For every index i in [0, n), a uniform sample u in [0, 1) is drawn from a
// thrust::minstd_rand engine seeded with `seed`:
//   * u <  dropout_prob: the element is dropped; mask_data[i] = 0, dst[i] = 0.
//   * u >= dropout_prob: the element is kept; mask_data[i] = 1, and
//     dst[i] = src[i] / (1 - dropout_prob) if is_upscale_in_train,
//     otherwise dst[i] = src[i].
//
// Launch layout: a 1-D grid of 1-D blocks, with no shared memory. Each thread
// walks the index space in steps of blockDim.x * gridDim.x, so any grid with
// at least one block covers all n elements. Every thread owns a private
// engine and positions it with discard() before each draw. The mask therefore
// depends only on the seed, the launch configuration and the element index.
//
// Preconditions:
//   * src, mask_data and dst must each point to at least n elements of
//     device-accessible memory.
//   * When is_upscale_in_train is set, dropout_prob must not be 1, because
//     the division would then be by zero. GPUDropoutKernel handles that case
//     on the host without launching this kernel.
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, const int seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask_data, T* dst,
                                bool is_upscale_in_train) {
  thrust::minstd_rand rng;
  rng.seed(seed);
  thrust::uniform_real_distribution<float> dist(0, 1);

  // Index arithmetic is done in 64 bits. A 32-bit int index wraps negative
  // once n exceeds INT_MAX; the `idx < n` test then fails and the remaining
  // elements are silently left unwritten.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  // Loop-invariant keep-scale, hoisted out of the per-element loop.
  const T keep_scale = static_cast<T>(1.0f - dropout_prob);

  bool first_iteration = true;
  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < n; idx += stride) {
    const T s = src[idx];

    // Position the engine before drawing. On the first iteration, skip ahead
    // to this thread's starting index; afterwards, skip one full stride.
    // This sequence must not change, or masks produced with a fixed seed
    // would differ from earlier runs.
    if (first_iteration) {
      rng.discard(idx);
      first_iteration = false;
    } else {
      rng.discard(stride);
    }

    MaskType mask;
    T dest;
    if (dist(rng) < dropout_prob) {
      mask = 0;
      dest = 0;
    } else {
      mask = 1;
      if (is_upscale_in_train) {
        dest = s / keep_scale;
      } else {
        dest = s;
      }
    }
    mask_data[idx] = mask;
    dst[idx] = dest;
  }
}

// Eigen::Tensor::setRandom appears to segfault when evaluated on the GPU.
// Uniform random numbers are therefore generated with std::random (for the
// host-side seed) and thrust::random (for the per-thread device engine),
// which ships with the CUDA toolkit.
// GPU forward kernel of the dropout operator.
//
// Reads input "X" and the attributes "dropout_prob", "dropout_implementation",
// "is_test", "fix_seed" and "seed". Writes output "Out", and also "Mask" in
// training mode.
//
// Training (is_test == false):
//   * Each element of X is dropped with probability dropout_prob.
//   * "Mask" receives 1 (kept) or 0 (dropped), stored as uint8_t.
//   * If dropout_implementation == "upscale_in_train", kept values are divided
//     by (1 - dropout_prob); otherwise they are copied unchanged.
//
// Inference (is_test == true):
//   * Out = X under "upscale_in_train", otherwise Out = X * (1 - dropout_prob).
//   * "Mask" is not written.
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Output<Tensor>("Out");
    auto* y_data = y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");

    auto& place = *context.template device_context<Place>().eigen_device();
    if (!context.Attr<bool>("is_test")) {
      int64_t x_numel = x->numel();
      auto stream = context.cuda_device_context().stream();

      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());

      // Empty input: the outputs are allocated above, and there is nothing
      // to compute. Without this guard, `grid` below would be 0, and a
      // zero-block launch fails with an invalid-configuration error.
      if (x_numel == 0) {
        return;
      }

      size_t size = framework::product(mask->dims());
      auto* x_data = x->data<T>();

      // Everything is dropped, so zero both outputs directly. This also
      // avoids the division by (1 - dropout_prob) == 0 in the kernel.
      if (dropout_prob == 1.0f) {
        PADDLE_ENFORCE(cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
        PADDLE_ENFORCE(cudaMemsetAsync(mask_data, 0,
                                       x_numel * sizeof(*mask_data), stream));
        return;
      }

      // With fix_seed the mask is reproducible for a given launch
      // configuration; otherwise a fresh seed is drawn on every call.
      std::random_device rnd;
      int seed =
          context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();

      // One thread per element (ceil-division). The kernel's strided loop
      // keeps it correct for any grid size.
      const int threads = 512;
      const int grid = static_cast<int>((x_numel + threads - 1) / threads);
      RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
          size, seed, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train);
    } else {
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
      if (upscale_in_train) {
        // Scaling was already applied during training.
        Y.device(place) = X;
      } else {
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward pass of "dropout" on CUDA devices for float, float16 and double.
REGISTER_OP_CUDA_KERNEL(
    dropout,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);

// Backward pass of "dropout". DropoutGradKernel is not defined in this file;
// presumably it comes from dropout_op.h.
REGISTER_OP_CUDA_KERNEL(
    dropout_grad,
    ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);