/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>
#include <cstring>
#include <random>
#include <string>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Local shorthand for framework::EigenMatrix. Unless overridden, the view is
// row-major and indexed with Eigen::DenseIndex.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
33
class CPUDropoutKernel : public framework::OpKernel<T> {
34 35 36
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
M
mapingshuo 已提交
37 38
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
39
    auto* y = context.Output<Tensor>("Out");
40
    const auto* x_data = x->data<T>();
41
    auto* y_data = y->mutable_data<T>(context.GetPlace());
42
    float dropout_prob = context.Attr<float>("dropout_prob");
43

Z
Zeng Jinle 已提交
44
    auto& dropout_implementation =
P
phlrain 已提交
45
        context.Attr<std::string>("dropout_implementation");
Z
Zeng Jinle 已提交
46
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");
47
    if (!context.Attr<bool>("is_test")) {
48
      auto* mask = context.Output<Tensor>("Mask");
Z
Zeng Jinle 已提交
49 50 51 52 53 54 55 56 57
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
      size_t size = framework::product(mask->dims());

      // Special case when dropout_prob is 1.0
      if (dropout_prob == 1.0f) {
        std::memset(y_data, 0, size * sizeof(*y_data));        // NOLINT
        std::memset(mask_data, 0, size * sizeof(*mask_data));  // NOLINT
        return;
      }
58

59 60
      bool init_generator_py = framework::Generator::GetInstance()->is_init_py;

61 62 63
      // NOTE: fixed seed should only be used in unittest or for debug.
      // Guarantee to use random seed in training.
      std::random_device rnd;
64
      std::minstd_rand engine;
M
mapingshuo 已提交
65 66 67 68 69 70 71 72
      int seed_data;
      if (seed) {
        seed_data = *(seed->data<int>());
      } else {
        seed_data =
            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
      }
      engine.seed(seed_data);
73

74
      std::uniform_real_distribution<float> dist(0, 1);
P
phlrain 已提交
75

76
      for (size_t i = 0; i < size; ++i) {
77 78 79 80 81
        float cur_random =
            init_generator_py
                ? dist(framework::Generator::GetInstance()->GetCPUEngine())
                : dist(engine);
        if (cur_random < dropout_prob) {
82 83 84
          mask_data[i] = 0;
          y_data[i] = 0;
        } else {
Z
Zeng Jinle 已提交
85 86
          mask_data[i] = 1;
          if (upscale_in_train) {
P
phlrain 已提交
87 88 89 90
            y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
          } else {
            y_data[i] = x_data[i];
          }
91
        }
92
      }
93
    } else {
Z
Zeng Jinle 已提交
94
      if (upscale_in_train) {
95 96 97 98 99 100 101 102
        const auto* X_data = x->data<T>();
        auto* Y_data = y->mutable_data<T>(context.GetPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
        for (int i = 0; i < x->numel(); i++) {
          Y_data[i] = X_data[i];
        }
P
phlrain 已提交
103
      } else {
104 105 106 107
        auto X = EigenMatrix<T>::Reshape(*x, 1);
        auto Y = EigenMatrix<T>::Reshape(*y, 1);
        auto& place =
            *context.template device_context<DeviceContext>().eigen_device();
P
phlrain 已提交
108 109
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
110 111 112 113
    }
  }
};

template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
115
class DropoutGradKernel : public framework::OpKernel<T> {
X
Xinghai Sun 已提交
116 117
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
ceci3 已提交
118 119 120
    PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
                      platform::errors::PreconditionNotMet(
                          "GradOp is only callable when is_test is false"));
121

X
Xinghai Sun 已提交
122 123 124 125 126
    auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* mask = context.Input<Tensor>("Mask");
    grad_x->mutable_data<T>(context.GetPlace());

Z
Zeng Jinle 已提交
127
    auto M = EigenMatrix<uint8_t>::Reshape(*mask, 1);
128 129
    auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
    auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
X
Xinghai Sun 已提交
130

Q
QI JUN 已提交
131 132
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
Z
Zeng Jinle 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145 146

    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    if (dropout_implementation == "upscale_in_train") {
      float dropout_prob = context.Attr<float>("dropout_prob");
      if (dropout_prob == 1.0f) {
        dX.device(place) = static_cast<T>(0) * dY;
      } else {
        dX.device(place) =
            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
      }
    } else {
      dX.device(place) = dY * M.cast<T>();
    }
X
Xinghai Sun 已提交
147 148 149 150 151
  }
};

}  // namespace operators
}  // namespace paddle