/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
namespace paddle {
namespace operators {

/// Declares the inputs (Logits, Label), outputs (Softmax, Loss), attributes
/// (soft_label, numeric_stable_mode, ignore_index) and user documentation of
/// the softmax_with_cross_entropy operator.
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The unscaled log probabilities "
             "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
             "and K is the class number.");
    AddInput("Label",
             "(Tensor) The ground truth which is a 2-D tensor. If soft_label "
             "is set to false, Label is a Tensor<int64> with shape [N x 1]. If "
             "soft_label is set to true, Label is a Tensor<float/double> with "
             "shape [N x K].");
    // Softmax is only kept around for the backward pass, hence intermediate.
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
        "The outputs value of softmax activation by given the input batch, "
        "which will be used in backward calculation.")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
              "entropy loss with shape [N x 1].");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: false), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used.")
        .SetDefault(false);
    // Adjacent string literals are concatenated verbatim, so every fragment
    // must carry its own trailing space to keep the words separated.
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and "
        "does not contribute to the input gradient. Only valid if soft_label "
        "is set to False")
        .SetDefault(-100);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.

When the attribute soft_label is set false, this operator expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.

The equation is as follows:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{j, Label_j} +
\log\left(\sum_{i=0}^{K-1}\exp(\text{Logit}_{j, i})\right),
j = 1,..., N$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K-1}\text{Label}_{j, i} \left(\text{Logit}_{j, i} -
\log\left(\sum_{k=0}^{K-1}\exp(\text{Logit}_{j, k})\right)\right),
j = 1,...,N$$

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

97
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
98 99 100 101 102 103 104 105 106 107
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");

    PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
                   "Output(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");

    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
C
caoying03 已提交
108
    PADDLE_ENFORCE_EQ(
Q
qiaolongfei 已提交
109
        logits_dims.size(), 2UL,
110
        "The input of softmax_with_cross_entropy should be a 2-D tensor.");
Q
qiaolongfei 已提交
111
    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
C
caoying03 已提交
112
                      "The labels should be a 2-D tensor.");
113

114
    if (ctx->Attrs().Get<bool>("soft_label")) {
Q
qiaolongfei 已提交
115
      PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
116
                        "If Attr(soft_label) == true, the 2nd dimension of "
117 118
                        "Input(X) and Input(Label) should be equal.");
    } else {
Q
qiaolongfei 已提交
119
      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
120
                        "If Attr(soft_label) == false, the 2nd dimension of "
121 122 123
                        "Input(Label) should be 1.");
    }

Q
qiaolongfei 已提交
124 125
    ctx->SetOutputDim("Softmax", logits_dims);
    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
126

Q
qiaolongfei 已提交
127 128
    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
C
caoying03 已提交
129
  }
Y
Yu Yang 已提交
130

131
 protected:
132
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
133
      const framework::ExecutionContext& ctx) const override {
M
minqiyang 已提交
134 135
    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
                                   ctx.device_context());
Y
Yu Yang 已提交
136
  }
C
caoying03 已提交
137 138 139 140 141 142
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

143
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
144 145 146 147 148 149 150 151 152 153 154
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Softmax"),
                   "Input(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@Grad) should be not null.");

    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
C
caoying03 已提交
155
                      "The labels should be a 2-D tensor.");
156

157
    if (ctx->Attrs().Get<bool>("soft_label")) {
Q
qiaolongfei 已提交
158
      PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
159
                        "When Attr(soft_label) == true, the 2nd dimension of "
160 161
                        "Input(X) and Input(Label) should be equal.");
    } else {
Q
qiaolongfei 已提交
162
      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
163
                        "When Attr(soft_label) == false, the 2nd dimension of "
164 165
                        "Input(Label) should be 1.");
    }
C
caoying03 已提交
166

Q
qiaolongfei 已提交
167 168
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
169
  }
Y
Yu Yang 已提交
170

171
 protected:
172
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
173
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
174
    return framework::OpKernelType(
M
minqiyang 已提交
175
        ctx.Input<Tensor>(framework::GradVarName("Loss"))->type(),
Y
Yu Yang 已提交
176
        ctx.device_context());
Y
Yu Yang 已提交
177
  }
178 179
};

180 181 182 183 184
class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
Y
Yu Yang 已提交
185 186
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* grad_op = new framework::OpDesc();
Y
Yu Yang 已提交
187 188 189 190 191 192 193 194
    grad_op->SetType("softmax_with_cross_entropy_grad");
    grad_op->SetInput("Label", Input("Label"));
    grad_op->SetInput("Softmax", Output("Softmax"));
    grad_op->SetInput("Loss", Output("Loss"));
    grad_op->SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax"));
    grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
    grad_op->SetAttrMap(Attrs());
Y
Yu Yang 已提交
195
    return std::unique_ptr<framework::OpDesc>(grad_op);
196 197 198
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

204
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
Y
Yu Yang 已提交
205
                  ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker);
206 207
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                  ops::SoftmaxWithCrossEntropyOpGrad);
208
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
C
caoying03 已提交
209 210
                       ops::SoftmaxWithCrossEntropyKernel<float>,
                       ops::SoftmaxWithCrossEntropyKernel<double>);
211
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
C
caoying03 已提交
212 213
                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
                       ops::SoftmaxWithCrossEntropyGradKernel<double>);