dropout_op.h 7.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
X
Xinghai Sun 已提交
14
#pragma once
Y
Yi Wang 已提交
15

Z
Zeng Jinle 已提交
16
#include <cstring>
17
#include <random>
P
phlrain 已提交
18
#include <string>
Y
Yi Wang 已提交
19

Z
Zhang Ting 已提交
20
#include <algorithm>
Y
Yi Wang 已提交
21
#include "paddle/fluid/framework/eigen.h"
22
#include "paddle/fluid/framework/generator.h"
Y
Yi Wang 已提交
23
#include "paddle/fluid/framework/op_registry.h"
Z
Zhang Ting 已提交
24
#include "paddle/fluid/platform/gpu_launch_config.h"
X
Xinghai Sun 已提交
25 26 27 28

namespace paddle {
namespace operators {

Z
Zhang Ting 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
// AlignedVector<T, Size> bundles `Size` values of T into one object whose
// alignment is sizeof(T) * Size bytes. Copying the whole struct at once is
// what lets CUDA emit a single vectorized load/store for the bundle.
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];
};

// Reports the vector width usable for memory accesses starting at `pointer`.
// Returns 4 when `pointer` meets the alignment of AlignedVector<T, 4>,
// otherwise 1 (scalar access only).
template <typename T>
inline int VectorizedSize(const T* pointer) {
  constexpr uint64_t kVec4Alignment = alignof(AlignedVector<T, 4>);
  const uint64_t address = reinterpret_cast<uint64_t>(pointer);
  return (address % kVec4Alignment == 0) ? 4 : 1;
}

45
#if defined(__NVCC__) || defined(__HIPCC__)
Z
Zhang Ting 已提交
46 47 48 49 50 51 52 53 54 55 56
// Vectorized backward kernel for dropout ("upscale_in_train" mode):
//   dx[i] = dout[i] * static_cast<T>(mask[i]) * factor
//
// Launch: 1-D grid of 1-D blocks, any size. Each thread handles VecSize
// consecutive elements per iteration and advances by
// (total threads * VecSize), so every element is covered for any grid.
//
// Preconditions (not checked here):
//   * dout and dx are aligned to alignof(AlignedVector<T, VecSize>);
//   * mask is aligned to alignof(AlignedVector<MaskType, VecSize>).
// `size` need not be a multiple of VecSize: the last size % VecSize elements
// are processed with scalar accesses.
template <typename T, typename MaskType, int VecSize>
__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
                                      const T factor, const int64_t size,
                                      T* dx) {
  using LoadT = AlignedVector<T, VecSize>;
  using MaskLoadT = AlignedVector<MaskType, VecSize>;

  // Widen before multiplying so index/stride cannot wrap for huge tensors.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t total_threads = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t stride = total_threads * VecSize;
  // Largest multiple of VecSize not exceeding `size`: the vectorized range.
  const int64_t main_size = size - size % VecSize;

  for (int64_t i = idx * VecSize; i < main_size; i += stride) {
    // Use AlignedVector locals directly: a plain T[VecSize] array only has
    // alignof(T), so accessing it through a LoadT* would be misaligned.
    const LoadT dout_vec = *reinterpret_cast<const LoadT*>(&dout[i]);
    const MaskLoadT mask_vec = *reinterpret_cast<const MaskLoadT*>(&mask[i]);
    LoadT dx_vec;
#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      dx_vec.val[ii] =
          dout_vec.val[ii] * static_cast<T>(mask_vec.val[ii]) * factor;
    }
    *reinterpret_cast<LoadT*>(&dx[i]) = dx_vec;
  }

  // Scalar tail for the final size % VecSize elements (empty when size is a
  // multiple of VecSize).
  for (int64_t j = main_size + idx; j < size; j += total_threads) {
    dx[j] = dout[j] * static_cast<T>(mask[j]) * factor;
  }
}
#endif

X
Xinghai Sun 已提交
76 77 78
// Short alias for the framework tensor type used by the kernels below.
typedef framework::Tensor Tensor;

// Aliases for the framework's Eigen adaptors, defaulting to row-major layout
// and Eigen's default index type. The kernels below use them through
// Reshape (EigenMatrix) and Flatten (EigenVector).
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

K
Kexin Zhao 已提交
85
template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
 public:
  // Forward pass of dropout on CPU.
  //
  // Reads input "X" and, when present, the "Seed" tensor (its first int is
  // used as the RNG seed, overriding the fix_seed/seed attributes). Writes
  // "Out" and, in training mode, "Mask" (uint8_t: 0 = element dropped,
  // 1 = element kept).
  //
  // Training (attr is_test == false): each element is dropped with
  // probability dropout_prob. Kept elements are divided by
  // (1 - dropout_prob) when dropout_implementation == "upscale_in_train",
  // otherwise copied unchanged. dropout_prob == 1 zeroes Out and Mask.
  //
  // Inference (is_test == true): Out = X for "upscale_in_train",
  // otherwise Out = X * (1 - dropout_prob).
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
    auto* y = context.Output<Tensor>("Out");
    const auto* x_data = x->data<T>();
    auto* y_data = y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");
    if (!context.Attr<bool>("is_test")) {
      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
      // NOTE(review): the element count comes from Mask but X/Out are indexed
      // with the same index; assumes Mask has the same shape as X.
      size_t size = framework::product(mask->dims());

      // Special case when dropout_prob is 1.0: every element is dropped.
      if (dropout_prob == 1.0f) {
        std::memset(y_data, 0, size * sizeof(*y_data));        // NOLINT
        std::memset(mask_data, 0, size * sizeof(*mask_data));  // NOLINT
        return;
      }

      // NOTE: fixed seed should only be used in unittest or for debug.
      // Guarantee to use random seed in training.
      int seed_data = 0;
      if (seed) {
        seed_data = *(seed->data<int>());
      } else {
        seed_data =
            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : 0;
      }
      auto engine = framework::GetCPURandomEngine(seed_data);

      std::uniform_real_distribution<float> dist(0, 1);

      for (size_t i = 0; i < size; ++i) {
        if (dist(*engine) < dropout_prob) {
          mask_data[i] = 0;
          y_data[i] = 0;
        } else {
          mask_data[i] = 1;
          if (upscale_in_train) {
            y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
          } else {
            y_data[i] = x_data[i];
          }
        }
      }
    } else {
      if (upscale_in_train) {
        // Inference with upscale_in_train is an identity copy. Use a 64-bit
        // index (element counts may exceed INT_MAX) and read numel() once.
        const int64_t numel = x->numel();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
        for (int64_t i = 0; i < numel; ++i) {
          y_data[i] = x_data[i];
        }
      } else {
        auto X = EigenMatrix<T>::Reshape(*x, 1);
        auto Y = EigenMatrix<T>::Reshape(*y, 1);
        auto& place =
            *context.template device_context<DeviceContext>().eigen_device();
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};

Q
QI JUN 已提交
159
template <typename DeviceContext, typename T>
class DropoutGradKernel : public framework::OpKernel<T> {
 public:
  // Backward pass of dropout (training mode only).
  //
  //   "upscale_in_train": dX = dOut * Mask / (1 - dropout_prob)
  //                       (dX = 0 * dOut when dropout_prob == 1)
  //   otherwise:          dX = dOut * Mask
  //
  // When built by NVCC/HIPCC and running on a GPU place, the upscale case uses
  // DropoutGradCUDAKernel with 4-wide vector accesses if dOut, Mask and dX are
  // all suitably aligned and the element count is a positive multiple of 4.
  // Every other case uses the Eigen expression path.
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
                      platform::errors::PreconditionNotMet(
                          "GradOp is only callable when is_test is false"));

    auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* mask = context.Input<Tensor>("Mask");
    grad_x->mutable_data<T>(context.GetPlace());
    auto size = grad_x->numel();

    auto M = EigenVector<uint8_t>::Flatten(*mask);
    auto dX = EigenVector<T>::Flatten(*grad_x);
    auto dY = EigenVector<T>::Flatten(*grad_y);

    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");

    if (dropout_implementation != "upscale_in_train") {
      dX.device(place) = dY * M.cast<T>();
      return;
    }

    float dropout_prob = context.Attr<float>("dropout_prob");
    if (dropout_prob == 1.0f) {
      // Avoid dividing by (1 - dropout_prob) == 0.
      dX.device(place) = static_cast<T>(0) * dY;
      return;
    }

#if defined(__NVCC__) || defined(__HIPCC__)
    // The kernel does 4-wide loads from dOut and Mask and 4-wide stores into
    // dX, so all three buffers must satisfy the vector alignment.
    const bool use_vectorized_kernel =
        platform::is_gpu_place(context.GetPlace()) && size > 0 &&
        size % 4 == 0 && VectorizedSize<T>(grad_y->data<T>()) == 4 &&
        VectorizedSize<uint8_t>(mask->data<uint8_t>()) == 4 &&
        VectorizedSize<T>(grad_x->data<T>()) == 4;
    if (use_vectorized_kernel) {
      auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
      auto stream = context.cuda_device_context().stream();
      // Each thread consumes 4 elements per iteration, so size the launch by
      // the number of 4-element vectors rather than by elements.
      platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
          context.cuda_device_context(), size / 4);
      DropoutGradCUDAKernel<T, uint8_t, 4>
          <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
              grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
              grad_x->data<T>());
      return;
    }
#endif
    // Generic path: also the fallback when the fast path is not compiled in,
    // so dX is always written.
    dX.device(place) = dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
  }
};

}  // namespace operators
}  // namespace paddle