dropout_op.cu 5.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15
#include <cuda.h>
#include <curand_kernel.h>
16 17 18 19
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
P
phlrain 已提交
20
#include <string>
Y
Yi Wang 已提交
21
#include "paddle/fluid/operators/dropout_op.h"
22
#include "paddle/fluid/platform/dynload/curand.h"
K
Kexin Zhao 已提交
23
#include "paddle/fluid/platform/float16.h"
24 25 26
namespace paddle {
namespace operators {

Z
Zeng Jinle 已提交
27
// Dropout forward kernel (training mode).
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Each thread starts at its
// global index and advances by the total number of launched threads
// (blockDim.x * gridDim.x) until all `n` elements are covered, so any grid
// size yields a complete result.
//
// For element i a Philox state is initialised with (seed, subsequence = i) and
// one uniform sample u is drawn:
//   u <  dropout_prob -> mask_data[i] = 0, dst[i] = 0
//   otherwise         -> mask_data[i] = 1,
//                        dst[i] = src[i] / (1 - dropout_prob) if
//                                 is_upscale_in_train, else src[i]
// The curand offset is i for the first element a thread visits and the grid
// stride for every later one; this sequence is preserved so outputs for a
// given seed do not change.
//
// Preconditions: src, mask_data and dst each hold at least n elements, and
// dropout_prob != 1 when is_upscale_in_train (the rescale would otherwise
// divide by zero).
//
// Indices are 64-bit: a 32-bit int index overflows once n exceeds INT_MAX.
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, const int seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask_data, T* dst,
                                bool is_upscale_in_train) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t total = static_cast<int64_t>(n);
  // Loop-invariant rescale factor used for kept elements in upscale mode.
  const T keep_scale = static_cast<T>(1.0f - dropout_prob);

  curandStatePhilox4_32_10_t state;
  bool first_visit = true;
  for (int64_t idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < total; idx += stride) {
    const T s = src[idx];
    // Offset is idx on the first visit and the grid stride afterwards,
    // matching the historical random stream.
    curand_init(seed, idx, first_visit ? idx : stride, &state);
    first_visit = false;

    MaskType mask;
    T dest;
    if (curand_uniform(&state) < dropout_prob) {
      mask = 0;
      dest = 0;
    } else {
      mask = 1;
      if (is_upscale_in_train) {
        dest = s / keep_scale;
      } else {
        dest = s;
      }
    }
    mask_data[idx] = mask;
    dst[idx] = dest;
  }
}
61 62 63 64

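// Eigen::Tensor::setRandom appears to segfault on GPU, so it is not used here.
// In training mode, uniform random numbers are drawn on the device with cuRAND
// (Philox) inside RandomGenerator. The host-side seed comes from the optional
// Seed input, from the "seed" attribute when "fix_seed" is set, or otherwise
// from std::random_device.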
K
Kexin Zhao 已提交
65
// CUDA forward kernel of the dropout operator.
//
// Reads input X, the optional Seed tensor, and the attributes dropout_prob,
// dropout_implementation, is_test, fix_seed and seed.
//
// Training (is_test == false): RandomGenerator decides per element whether it
// is dropped (probability dropout_prob). Out holds 0 for dropped elements and
// X for kept ones (divided by 1 - dropout_prob when dropout_implementation is
// "upscale_in_train"); Mask (uint8) holds 0/1 accordingly. The seed is taken,
// in priority order, from the Seed input (copied to host first if it lives on
// GPU), from the "seed" attribute when "fix_seed" is true, or from
// std::random_device.
//
// Inference (is_test == true): Out = X for "upscale_in_train", otherwise
// Out = X * (1 - dropout_prob). Mask is not written.
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
    auto* y = context.Output<Tensor>("Out");
    // Allocate Out once; the pointer is reused by the training path.
    auto* y_data = y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");

    auto& place = *context.template device_context<Place>().eigen_device();
    if (!context.Attr<bool>("is_test")) {
      int64_t x_numel = x->numel();
      auto stream = context.cuda_device_context().stream();

      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());

      // Empty input: outputs are allocated and there is nothing to compute.
      // Returning here also avoids launching RandomGenerator with a zero-sized
      // grid, which CUDA rejects as an invalid launch configuration.
      if (x_numel == 0) {
        return;
      }
      auto* x_data = x->data<T>();

      int seed_data;
      if (seed) {
        if (platform::is_gpu_place(seed->place())) {
          framework::Tensor temp;
          TensorCopySync(*seed, platform::CPUPlace(), &temp);
          seed_data = *(temp.data<int>());
        } else {
          seed_data = *(seed->data<int>());
        }
      } else if (context.Attr<bool>("fix_seed")) {
        seed_data = context.Attr<int>("seed");
      } else {
        // Only construct random_device when its entropy is actually needed;
        // construction can be comparatively expensive.
        std::random_device rnd;
        seed_data = static_cast<int>(rnd());
      }

      if (dropout_prob == 1.0f) {
        // Everything is dropped: zero both outputs without launching the
        // kernel (which would otherwise divide by 1 - dropout_prob == 0).
        PADDLE_ENFORCE_CUDA_SUCCESS(
            cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
        PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
            mask_data, 0, x_numel * sizeof(*mask_data), stream));
        return;
      }

      constexpr int kThreads = 512;
      const int grid = static_cast<int>((x_numel + kThreads - 1) / kThreads);
      // Mask has the same number of elements as X, so x_numel bounds both the
      // grid and every buffer the kernel touches.
      RandomGenerator<T, uint8_t><<<grid, kThreads, 0, stream>>>(
          static_cast<size_t>(x_numel), seed_data, dropout_prob, x_data,
          mask_data, y_data, upscale_in_train);
      // Kernel launches return no status; report launch failures here rather
      // than at a later, unrelated CUDA call.
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError());
    } else {
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
      if (upscale_in_train) {
        Y.device(place) = X;
      } else {
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

X
Xinghai Sun 已提交
132
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward dropout on CUDA, one kernel instantiation per element type
// (float, float16, double).
REGISTER_OP_CUDA_KERNEL(
    dropout,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);

// Backward dropout on CUDA for the same element types. DropoutGradKernel is
// defined outside this file (not visible here).
REGISTER_OP_CUDA_KERNEL(
    dropout_grad,
    ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);