dropout_op.h 5.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
X
Xinghai Sun 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
X
Xinghai Sun 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
X
Xinghai Sun 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
X
Xinghai Sun 已提交
14
#pragma once
Y
Yi Wang 已提交
15

Z
Zeng Jinle 已提交
16
#include <cstring>
17
#include <random>
P
phlrain 已提交
18
#include <string>
Y
Yi Wang 已提交
19

Z
Zhang Ting 已提交
20
#include <algorithm>
Y
Yi Wang 已提交
21
#include "paddle/fluid/framework/eigen.h"
22
#include "paddle/fluid/framework/generator.h"
Y
Yi Wang 已提交
23
#include "paddle/fluid/framework/op_registry.h"
X
Xinghai Sun 已提交
24 25 26 27 28 29 30

namespace paddle {
namespace operators {

// Shorthand for the framework tensor type used by the kernels in this header.
using Tensor = framework::Tensor;

// Local alias forwarding to framework::EigenMatrix; defaults to row-major
// storage indexed by Eigen::DenseIndex.
template <typename Elem, int Layout = Eigen::RowMajor,
          typename Index = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<Elem, Layout, Index>;

// Local alias forwarding to framework::EigenVector; same defaults as the
// EigenMatrix alias above.
template <typename Elem, int Layout = Eigen::RowMajor,
          typename Index = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<Elem, Layout, Index>;

template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
38
class CPUDropoutKernel : public framework::OpKernel<T> {
39 40 41
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
M
mapingshuo 已提交
42 43
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
44
    auto* y = context.Output<Tensor>("Out");
45
    const auto* x_data = x->data<T>();
46
    auto* y_data = y->mutable_data<T>(context.GetPlace());
47
    float dropout_prob = context.Attr<float>("dropout_prob");
48

Z
Zeng Jinle 已提交
49
    auto& dropout_implementation =
P
phlrain 已提交
50
        context.Attr<std::string>("dropout_implementation");
Z
Zeng Jinle 已提交
51
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");
52
    if (!context.Attr<bool>("is_test")) {
53
      auto* mask = context.Output<Tensor>("Mask");
Z
Zeng Jinle 已提交
54
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
55
      size_t size = phi::product(mask->dims());
Z
Zeng Jinle 已提交
56 57 58 59 60 61 62

      // Special case when dropout_prob is 1.0
      if (dropout_prob == 1.0f) {
        std::memset(y_data, 0, size * sizeof(*y_data));        // NOLINT
        std::memset(mask_data, 0, size * sizeof(*mask_data));  // NOLINT
        return;
      }
L
Leo Chen 已提交
63
      // std::minstd_rand engine;
64 65
      // NOTE: fixed seed should only be used in unittest or for debug.
      // Guarantee to use random seed in training.
L
Leo Chen 已提交
66
      int seed_data = 0;
M
mapingshuo 已提交
67 68 69 70
      if (seed) {
        seed_data = *(seed->data<int>());
      } else {
        seed_data =
L
Leo Chen 已提交
71
            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : 0;
M
mapingshuo 已提交
72
      }
L
Leo Chen 已提交
73
      auto engine = framework::GetCPURandomEngine(seed_data);
74

75
      std::uniform_real_distribution<float> dist(0, 1);
P
phlrain 已提交
76

77
      for (size_t i = 0; i < size; ++i) {
L
Leo Chen 已提交
78
        if (dist(*engine) < dropout_prob) {
79 80 81
          mask_data[i] = 0;
          y_data[i] = 0;
        } else {
Z
Zeng Jinle 已提交
82 83
          mask_data[i] = 1;
          if (upscale_in_train) {
P
phlrain 已提交
84 85 86 87
            y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
          } else {
            y_data[i] = x_data[i];
          }
88
        }
89
      }
90
    } else {
Z
Zeng Jinle 已提交
91
      if (upscale_in_train) {
92 93 94 95 96 97 98 99
        const auto* X_data = x->data<T>();
        auto* Y_data = y->mutable_data<T>(context.GetPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
        for (int i = 0; i < x->numel(); i++) {
          Y_data[i] = X_data[i];
        }
P
phlrain 已提交
100
      } else {
101 102 103 104
        auto X = EigenMatrix<T>::Reshape(*x, 1);
        auto Y = EigenMatrix<T>::Reshape(*y, 1);
        auto& place =
            *context.template device_context<DeviceContext>().eigen_device();
P
phlrain 已提交
105 106
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
107 108 109
    }
  }
};
Q
QI JUN 已提交
110
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
111
class DropoutGradKernel : public framework::OpKernel<T> {
X
Xinghai Sun 已提交
112 113 114 115 116 117 118
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* mask = context.Input<Tensor>("Mask");
    grad_x->mutable_data<T>(context.GetPlace());

119 120
    auto dX = EigenVector<T>::Flatten(*grad_x);
    auto dY = EigenVector<T>::Flatten(*grad_y);
X
Xinghai Sun 已提交
121

Q
QI JUN 已提交
122 123
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
Z
Zeng Jinle 已提交
124 125
    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
126 127 128
    if (context.Attr<bool>("is_test") == true) {
      if (dropout_implementation == "upscale_in_train") {
        dX.device(place) = static_cast<T>(1) * dY;
Z
Zeng Jinle 已提交
129
      } else {
130 131 132 133 134 135 136 137 138 139
        float dropout_prob = context.Attr<float>("dropout_prob");
        dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
      }
    } else {
      auto M = EigenVector<uint8_t>::Flatten(*mask);
      if (dropout_implementation == "upscale_in_train") {
        float dropout_prob = context.Attr<float>("dropout_prob");
        if (dropout_prob == 1.0f) {
          dX.device(place) = static_cast<T>(0) * dY;
        } else {
140 141
          dX.device(place) =
              dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
Z
Zhang Ting 已提交
142
        }
143 144
      } else {
        dX.device(place) = dY * M.cast<T>();
Z
Zeng Jinle 已提交
145 146
      }
    }
X
Xinghai Sun 已提交
147 148 149 150 151
  }
};

}  // namespace operators
}  // namespace paddle