/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/dropout_op.h"
#include <memory>
#include <string>

namespace paddle {
namespace operators {

using framework::Tensor;

class DropoutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

28
  void InferShape(framework::InferShapeContext* ctx) const override {
29
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Dropout");
Q
Qiao Longfei 已提交
30 31 32

    auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", x_dims);
33
    if (ctx->Attrs().Get<bool>("is_test") == false) {
Q
Qiao Longfei 已提交
34
      ctx->SetOutputDim("Mask", x_dims);
35
    }
Q
Qiao Longfei 已提交
36
    ctx->ShareLoD("X", /*->*/ "Out");
X
Xinghai Sun 已提交
37
  }
M
mapingshuo 已提交
38 39 40 41 42 43 44

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
45 46 47 48 49 50 51 52 53 54 55 56 57

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "Seed") {
      VLOG(10) << "var_name:" << var_name
               << " does not need to transform in dropout op";
      return expected_kernel_type;
    }

    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
X
Xinghai Sun 已提交
58 59 60 61
};

// Declares the inputs, outputs and attributes of the dropout op, together
// with the checkers that validate attribute values.
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input of dropout op.");
    AddInput("Seed",
             "The seed of dropout op, it has higher priority than the attr "
             "fix_seed and seed")
        .AsDispensable()
        .AsExtra();
    AddOutput("Out", "The output of dropout op.");
    AddOutput("Mask", "The random sampled dropout mask.")
        .AsIntermediate()
        .AsExtra();

    AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
        .SetDefault(.5f)
        .AddCustomChecker([](const float& drop_p) {
          // Written as a positive range test so that NaN is rejected too
          // (every comparison with NaN is false).
          PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true,
                            platform::errors::InvalidArgument(
                                "'dropout_prob' must be between 0.0 and 1.0, "
                                "but received %f.",
                                drop_p));
        });
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddAttr<bool>("fix_seed",
                  "A flag indicating whether to use a fixed seed to generate "
                  "random mask. NOTE: DO NOT set this flag to true in "
                  "training. Setting this flag to true is only useful in "
                  "unit tests or for debugging, so that the same output units "
                  "are always dropped.")
        .SetDefault(false)
        .AsExtra();
    AddAttr<int>("seed", "Dropout random seed.").SetDefault(0).AsExtra();
    // Adjacent literals are concatenated by the compiler, so every piece
    // carries its own trailing space or newline.
    AddAttr<std::string>(
        "dropout_implementation",
        "[\"downgrade_in_infer\"|\"upscale_in_train\"] "
        "There are two kinds of ways to implement dropout "
        "(the mask below is a tensor that has the same shape as the input; "
        "the value of mask is 0 or 1, and the ratio of 0 is dropout_prob).\n"
        "1. downgrade_in_infer (default): downgrade the outcome at inference "
        "time.\n"
        "   train: out = input * mask\n"
        "   inference: out = input * (1.0 - dropout_prob)\n"
        "2. upscale_in_train: upscale the outcome at training time and do "
        "nothing at inference.\n"
        "   train: out = input * mask / (1.0 - dropout_prob)\n"
        "   inference: out = input\n"
        "   The dropout op can then be removed from the inference program, "
        "which makes the program more efficient.")
        .SetDefault("downgrade_in_infer")
        .AddCustomChecker([](const std::string& type) {
          PADDLE_ENFORCE_EQ(
              type == "downgrade_in_infer" || type == "upscale_in_train", true,
              platform::errors::InvalidArgument(
                  "dropout_implementation can only be downgrade_in_infer or "
                  "upscale_in_train, but received %s.",
                  type));
        });

    AddComment(R"DOC(
Dropout Operator.

Dropout refers to randomly dropping out units in a neural network. It is a
regularization technique for reducing overfitting by preventing neuron
co-adaptation during training. The dropout operator randomly sets (according to
the given dropout probability) the outputs of some units to zero, while others
are set equal to their corresponding inputs.

)DOC");
  }
};

class DropoutOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

136
  void InferShape(framework::InferShapeContext* ctx) const override {
137 138 139
    OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "DropoutGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   framework::GradVarName("Out"), "DropoutGrad");
Q
Qiao Longfei 已提交
140 141

    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
S
sneaxiy 已提交
142 143 144 145 146

    ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
    ctx->ShareLoD(framework::GradVarName("Out"),
                  /*->*/ framework::GradVarName("X"));
  }
Z
Zeng Jinle 已提交
147 148 149 150

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
151 152 153
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
Z
Zeng Jinle 已提交
154
  }
S
sneaxiy 已提交
155 156
};

H
hong 已提交
157 158
template <typename T>
class DropoutGradOpMaker : public framework::SingleGradOpMaker<T> {
S
sneaxiy 已提交
159
 public:
H
hong 已提交
160
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
S
sneaxiy 已提交
161 162

 protected:
163
  void Apply(GradOpPtr<T> op) const override {
S
sneaxiy 已提交
164
    op->SetType("dropout_grad");
H
hong 已提交
165 166 167 168
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetInput("Mask", this->Output("Mask"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
X
Xinghai Sun 已提交
169 170 171 172 173 174 175
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Y
Yang Yang 已提交
176
REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
H
hong 已提交
177 178
                  ops::DropoutGradOpMaker<paddle::framework::OpDesc>,
                  ops::DropoutGradOpMaker<paddle::imperative::OpBase>);
179
REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
180
REGISTER_OP_CPU_KERNEL(
P
phlrain 已提交
181
    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
182 183 184
    ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, double>,
    ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext,
                          paddle::platform::bfloat16>);
X
Xinghai Sun 已提交
185
REGISTER_OP_CPU_KERNEL(
Q
QI JUN 已提交
186
    dropout_grad,
P
phlrain 已提交
187
    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>,
188 189 190
    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext,
                           paddle::platform::bfloat16>);