/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// EIGEN_USE_GPU is defined ahead of every include below so that it is visible
// to all of them. NOTE(review): presumably this enables Eigen's GPU device
// support in the headers pulled in through dropout_op.h - confirm.
#define EIGEN_USE_GPU
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/operators/dropout_op.h"

namespace paddle {
namespace operators {

// Functor that produces one dropout-mask element per index.
//
// Template parameters:
//   T        - type of the mask value returned (always 0 or 1).
//   AttrType - type of the dropout probability and of the uniform samples.
//
// For index n the functor seeds a fresh thrust::minstd_rand with `seed`,
// discards n values, and draws one sample from a uniform real distribution
// constructed with bounds (0, 1). The element is dropped (0) when the sample
// is below dropout_prob and kept (1) otherwise, so the result depends only on
// the pair (seed, n).
template <typename T, typename AttrType>
struct MaskGenerator {
  AttrType dropout_prob;  // samples below this threshold map to 0
  int seed;               // seed shared by every index

  __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
      : dropout_prob(dropout_prob), seed(seed) {}

  // Returns the mask value for element index n.
  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(seed);
    thrust::uniform_real_distribution<AttrType> dist(0, 1);
    // Skip ahead so that index n consumes the value that follows the first n
    // values of the seeded sequence.
    rng.discard(n);
    return dist(rng) < dropout_prob ? static_cast<T>(0) : static_cast<T>(1);
  }
};

// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random (thrust is a std library in CUDA) to
// implement uniform random.
//
// Forward dropout kernel.
//
// Template parameters:
//   Place    - device context type used to obtain the Eigen device.
//   T        - element type of the input, output and mask tensors.
//   AttrType - type of the "dropout_prob" attribute.
//
// Compute() reads input "X" and writes output "Out". When the "is_test"
// attribute is false it also fills output "Mask" with 0/1 values produced by
// MaskGenerator (using attributes "dropout_prob" and "seed") and sets
// Out = X * Mask. When "is_test" is true it sets Out = X * (1 - dropout_prob)
// and does not touch "Mask".
template <typename Place, typename T, typename AttrType>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Output<Tensor>("Out");
    // Allocate the output buffer on the kernel's place before it is viewed.
    y->mutable_data<T>(context.GetPlace());
    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");

    // Views of the input and output used by the Eigen expressions below.
    auto X = EigenMatrix<T>::Reshape(*x, 1);
    auto Y = EigenMatrix<T>::Reshape(*y, 1);

    auto& place = *context.template device_context<Place>().eigen_device();
    if (!context.Attr<bool>("is_test")) {
      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
      int size = framework::product(mask->dims());
      int seed = context.Attr<int>("seed");

      // Write one mask element per index in [0, size).
      // NOTE(review): thrust::transform is invoked without an execution
      // policy, while the Eigen expression below runs on `place`; confirm
      // that the two are ordered with respect to each other.
      thrust::counting_iterator<unsigned int> index_sequence_begin(0);
      thrust::transform(index_sequence_begin, index_sequence_begin + size,
                        thrust::device_ptr<T>(mask_data),
                        MaskGenerator<T, AttrType>(dropout_prob, seed));

      auto M = EigenMatrix<T>::Reshape(*mask, 1);
      Y.device(place) = X * M;
    } else {
      // Cast the scale factor to the tensor element type so the scalar
      // operand matches T even when AttrType differs from T.
      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
    }
  }
};

}  // namespace operators
}  // namespace paddle

// Register the CUDA kernels for the "dropout" and "dropout_grad" ops, both
// instantiated for float on CUDADeviceContext.
// NOTE(review): DropoutGradKernel is not defined in this file; presumably it
// comes from dropout_op.h - confirm.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    dropout,
    ops::GPUDropoutKernel<paddle::platform::CUDADeviceContext, float, float>);
REGISTER_OP_CUDA_KERNEL(
    dropout_grad,
    ops::DropoutGradKernel<paddle::platform::CUDADeviceContext, float>);