dropout_op.cu 8.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15
#include <cuda.h>
#include <curand_kernel.h>
16 17 18 19
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
Z
Zhang Ting 已提交
20
#include <algorithm>
P
phlrain 已提交
21
#include <string>
22
#include "paddle/fluid/memory/memcpy.h"
Y
Yi Wang 已提交
23
#include "paddle/fluid/operators/dropout_op.h"
24
#include "paddle/fluid/platform/dynload/curand.h"
K
Kexin Zhao 已提交
25
#include "paddle/fluid/platform/float16.h"
26

27 28 29
namespace paddle {
namespace operators {

Z
Zhang Ting 已提交
30 31 32 33 34
// Fixed-size array of `Size` elements of T whose alignment equals its total
// byte size. Reinterpreting a suitably aligned T* as AlignedVector<T, Size>*
// lets the compiler emit a single wide load/store for all `Size` elements
// (e.g. 16 bytes for float x 4).
template <typename T, int Size>
struct alignas(Size * sizeof(T)) AlignedVector {
  // Element storage; val[k] is the k-th element of the vector.
  T val[Size];
};
35

Z
Zhang Ting 已提交
36 37 38 39 40 41
// Picks the vector width usable for loads/stores through `pointer`.
//
// Returns 4 when the address is a multiple of alignof(AlignedVector<T, 4>)
// (i.e. it can be reinterpreted as a 4-wide aligned vector), otherwise 1.
// Only the pointer's alignment is inspected; callers must separately ensure
// the element count is compatible with the chosen width.
template <typename T>
inline int VectorizedSize(const T* pointer) {
  constexpr int kVec4Alignment = alignof(AlignedVector<T, 4>);  // NOLINT
  const uint64_t addr = reinterpret_cast<uint64_t>(pointer);
  return (addr % kVec4Alignment == 0) ? 4 : 1;
}
45

46
// Scalar dropout kernel: one element per loop iteration.
//
// Each thread seeds a Philox cuRAND state with (seed, subsequence = global
// thread index, offset = increment). It then visits elements
// thread_id, thread_id + stride, ... (stride = blockDim.x * gridDim.x),
// drawing one uniform value per element:
//   - draw < dropout_prob  -> mask = 0, dst = 0
//   - otherwise            -> mask = 1, dst = src, divided by
//                             (1 - dropout_prob) when is_upscale_in_train.
// Launch: 1-D grid of 1-D blocks; any grid size is valid because of the
// strided loop. Indices are 64-bit so n may exceed INT_MAX.
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, uint64_t seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask_data, T* dst,
                                bool is_upscale_in_train, uint64_t increment) {
  // Compute in 64 bits: a 32-bit int index (and idx += stride) would overflow
  // for inputs with more than INT_MAX elements.
  const size_t thread_id =
      static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, thread_id, increment, &state);

  for (size_t idx = thread_id; idx < n; idx += stride) {
    T s = src[idx];
    MaskType mask;
    T dest;
    if (curand_uniform(&state) < dropout_prob) {
      mask = 0;
      dest = 0;
    } else {
      mask = 1;
      if (is_upscale_in_train) {
        dest = s / static_cast<T>(1.0f - dropout_prob);
      } else {
        dest = s;
      }
    }
    mask_data[idx] = mask;
    dst[idx] = dest;
  }
}

Z
Zhang Ting 已提交
75 76 77 78 79 80 81
// Vectorized dropout kernel: each loop iteration handles VecSize consecutive
// elements using a single curand_uniform4 draw.
//
// Each thread seeds a Philox cuRAND state with (seed, subsequence = global
// thread index, offset = increment) and walks the input in chunks of VecSize,
// advancing by blockDim.x * gridDim.x * VecSize elements per iteration.
// Per element: draw < dropout_prob -> mask = 0, dst = 0; otherwise mask = 1
// and dst = src (multiplied by 1 / (1 - dropout_prob) when
// is_upscale_in_train).
//
// Preconditions: src, dst and mask_data must be aligned for
// AlignedVector<T, VecSize> / AlignedVector<MaskType, VecSize> accesses.
// If n is not a multiple of VecSize, the trailing partial chunk is processed
// element by element, so no memory at or beyond index n is touched.
template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask_data,
                                          T* dst, bool is_upscale_in_train,
                                          uint64_t increment) {
  static_assert(VecSize >= 1 && VecSize <= 4,
                "one curand_uniform4 draw supplies at most 4 random values");
  using LoadT = AlignedVector<T, VecSize>;
  using MaskLoadT = AlignedVector<MaskType, VecSize>;

  // 64-bit arithmetic: idx * VecSize and the stride can exceed INT_MAX for
  // large inputs.
  const int64_t idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x * VecSize;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);

  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
  for (size_t i = static_cast<size_t>(idx) * VecSize; i < n; i += stride) {
    float4 rand = curand_uniform4(&state);
    const float* rand_vals = &rand.x;

    if (i + VecSize <= n) {
      // Full chunk: one aligned vector load, VecSize decisions made in
      // registers (AlignedVector packs are properly aligned), then one
      // aligned vector store each for dst and mask.
      LoadT src_pack = *reinterpret_cast<const LoadT*>(&src[i]);
      LoadT dst_pack;
      MaskLoadT mask_pack;
#pragma unroll
      for (int ii = 0; ii < VecSize; ii++) {
        if (rand_vals[ii] < dropout_prob) {
          dst_pack.val[ii] = 0;
          mask_pack.val[ii] = 0;
        } else {
          if (is_upscale_in_train) {
            dst_pack.val[ii] = src_pack.val[ii] * factor;
          } else {
            dst_pack.val[ii] = src_pack.val[ii];
          }
          mask_pack.val[ii] = 1;
        }
      }
      *reinterpret_cast<LoadT*>(&dst[i]) = dst_pack;
      *reinterpret_cast<MaskLoadT*>(&mask_data[i]) = mask_pack;
    } else {
      // Trailing partial chunk (n % VecSize elements): scalar accesses so the
      // kernel never reads or writes past n.
      const int remain = static_cast<int>(n - i);
      for (int ii = 0; ii < remain; ii++) {
        if (rand_vals[ii] < dropout_prob) {
          dst[i + ii] = 0;
          mask_data[i + ii] = 0;
        } else {
          if (is_upscale_in_train) {
            dst[i + ii] = src[i + ii] * factor;
          } else {
            dst[i + ii] = src[i + ii];
          }
          mask_data[i + ii] = 1;
        }
      }
    }
  }
}

121 122 123
// Eigen::Tensor::setRandom appears to segfault on GPU, so GPUDropoutKernel
// generates its dropout decisions with cuRAND device kernels
// (RandomGenerator / VectorizedRandomGenerator) instead of Eigen.
K
Kexin Zhao 已提交
124
// Forward dropout operator kernel for CUDA devices.
//
// Inputs: "X" and optional "Seed". Outputs: "Out" and, in training mode,
// "Mask" (uint8_t; 1 = element kept, 0 = element dropped).
// Attributes read: "dropout_prob", "dropout_implementation", "is_test",
// "fix_seed", "seed".
//
// Training (is_test == false): uniform random numbers are drawn on the device
// with cuRAND; elements whose draw is < dropout_prob are zeroed. When
// dropout_implementation == "upscale_in_train", kept elements are scaled by
// 1 / (1 - dropout_prob). dropout_prob == 1 zero-fills Out and Mask directly.
// Inference (is_test == true): Out = X for "upscale_in_train", otherwise
// Out = X * (1 - dropout_prob).
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
    auto* y = context.Output<Tensor>("Out");
    auto* y_data = y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");

    auto& place = *context.template device_context<Place>().eigen_device();
    if (!context.Attr<bool>("is_test")) {
      int64_t x_numel = x->numel();
      auto stream = context.cuda_device_context().stream();

      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
      size_t size = framework::product(mask->dims());
      auto* x_data = x->data<T>();

      // Empty input: nothing to compute. Returning here also avoids a
      // zero-sized grid (invalid launch) and a division by zero in the
      // offset computation below.
      if (x_numel == 0) {
        return;
      }

      if (dropout_prob == 1.0f) {
        PADDLE_ENFORCE_CUDA_SUCCESS(
            cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
        PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
            mask_data, 0, x_numel * sizeof(*mask_data), stream));
        return;
      }

      const int threads = 512;
      const auto& dev_ctx = context.cuda_device_context();
      int blocks_per_sm =
          dev_ctx.GetMaxPhysicalThreadCount() / dev_ctx.GetSMCount() / threads;
      // Guard against devices whose per-SM thread capacity is below
      // `threads`, which would otherwise produce a zero-block grid.
      if (blocks_per_sm < 1) {
        blocks_per_sm = 1;
      }
      const int64_t needed_blocks = (x_numel + threads - 1) / threads;
      const int64_t max_resident_blocks =
          static_cast<int64_t>(dev_ctx.GetSMCount()) * blocks_per_sm;
      const int grid =
          static_cast<int>(std::min(needed_blocks, max_resident_blocks));

      // The vectorized kernel performs 4-wide aligned loads/stores on x, y
      // and mask, so all three pointers must be suitably aligned. The element
      // count must also be a multiple of 4, which keeps every vector access
      // within bounds.
      int vec_size = 1;
      if (size % 4 == 0 && VectorizedSize<T>(x_data) == 4 &&
          VectorizedSize<T>(y_data) == 4 &&
          VectorizedSize<uint8_t>(mask_data) == 4) {
        vec_size = 4;
      }

      // increment is used to set the args(offset) of curand_init, which defines
      // offset in subsequence.
      // The detail:
      // https://docs.nvidia.com/cuda/curand/device-api-overview.html
      // Increment should be at least the number of curand() random numbers used
      // in each thread to avoid the random number generated this time being the
      // same as the previous calls.
      uint64_t seed_data;
      uint64_t increment;
      auto offset =
          ((x_numel - 1) / (static_cast<int64_t>(threads) * grid * vec_size) +
           1) *
          vec_size;
      int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
                          .GetDeviceId();
      auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

      if (seed && platform::is_gpu_place(seed->place())) {
        // Seed tensor lives on the GPU: copy it to the host to read it.
        framework::Tensor seed_cpu_tensor;
        TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
        seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
        increment = offset;
      } else if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
        // Use the device's default generator and advance its offset by the
        // number of random values this launch consumes per thread.
        auto seed_offset = gen_cuda->IncrementOffset(offset);
        seed_data = seed_offset.first;
        increment = seed_offset.second;
      } else {
        if (seed) {
          seed_data = *(seed->data<int>());
        } else {
          std::random_device rnd;
          seed_data = context.Attr<bool>("fix_seed") ? context.Attr<int>("seed")
                                                     : rnd();
        }
        increment = offset;
      }

      if (vec_size == 4) {
        VectorizedRandomGenerator<T, uint8_t, 4><<<grid, threads, 0, stream>>>(
            size, seed_data, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train, increment);
      } else {
        RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
            size, seed_data, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train, increment);
      }
      // Kernel launches do not return errors; surface launch-configuration
      // failures here instead of at some later, unrelated CUDA call.
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError());
    } else {
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
      if (upscale_in_train) {
        Y.device(place) = X;
      } else {
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

X
Xinghai Sun 已提交
225
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward "dropout" op: CUDA kernels for float, float16 and double.
REGISTER_OP_CUDA_KERNEL(
    dropout,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);

// Backward "dropout_grad" op for the same element types. DropoutGradKernel is
// not defined in this file (presumably it comes from dropout_op.h).
REGISTER_OP_CUDA_KERNEL(
    dropout_grad,
    ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);