softmax_with_cross_entropy_op.cc 8.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
S
sneaxiy 已提交
16
#include <memory>
Y
Yu Yang 已提交
17

18 19 20 21 22 23
namespace paddle {
namespace operators {

// Declares the proto of the softmax_with_cross_entropy operator: its inputs
// (Logits, Label), outputs (Softmax, Loss), attributes (soft_label,
// numeric_stable_mode, ignore_index) and the user-facing documentation.
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers every input, output and attribute together with its
  // description, then attaches the operator-level documentation.
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The unscaled log probabilities "
             "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
             "and K is the class number.");
    AddInput("Label",
             "(Tensor) The ground truth which is a 2-D tensor. If soft_label "
             "is set to false, Label is a Tensor<int64> with shape [N x 1]. If "
             "soft_label is set to true, Label is a Tensor<float/double> with "
             "shape [N x K].");
    // Softmax is marked intermediate: its description states it is kept for
    // the backward calculation.
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
        "The outputs value of softmax activation by given the input batch, "
        "which will be used in backward calculation.")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
              "entropy loss with shape [N x 1].");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: true), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used.")
        .SetDefault(true);
    // Adjacent string literals are concatenated verbatim, so every fragment
    // that continues a sentence must end with an explicit space.
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and "
        "does not contribute to the input gradient. Only valid if soft_label "
        "is set to False")
        .SetDefault(-100);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.

When the attribute soft_label is set false, this operator expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.

The equation is as follows:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=1}^{K}\exp(\text{Logit}_i)\right),
j = 1,..., N$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=1}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{k=1}^{K}\exp(\text{Logit}_k)\right)\right),
j = 1,...,N$$

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

98
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
99 100 101 102 103 104 105 106 107 108
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");

    PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
                   "Output(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");

    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
C
caoying03 已提交
109
    PADDLE_ENFORCE_EQ(
Q
qiaolongfei 已提交
110
        logits_dims.size(), 2UL,
111
        "The input of softmax_with_cross_entropy should be a 2-D tensor.");
Q
qiaolongfei 已提交
112
    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
C
caoying03 已提交
113
                      "The labels should be a 2-D tensor.");
114

115
    if (ctx->Attrs().Get<bool>("soft_label")) {
Q
qiaolongfei 已提交
116
      PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
117
                        "If Attr(soft_label) == true, the 2nd dimension of "
118 119
                        "Input(X) and Input(Label) should be equal.");
    } else {
Q
qiaolongfei 已提交
120
      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
121
                        "If Attr(soft_label) == false, the 2nd dimension of "
122 123 124
                        "Input(Label) should be 1.");
    }

Q
qiaolongfei 已提交
125 126
    ctx->SetOutputDim("Softmax", logits_dims);
    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
127

Q
qiaolongfei 已提交
128 129
    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
C
caoying03 已提交
130
  }
Y
Yu Yang 已提交
131

132
 protected:
133
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
134
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
135 136
    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
                                   ctx.device_context());
Y
Yu Yang 已提交
137
  }
C
caoying03 已提交
138 139 140 141 142 143
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

144
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
145 146 147 148 149 150 151 152 153 154 155
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Softmax"),
                   "Input(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@Grad) should be not null.");

    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
C
caoying03 已提交
156
                      "The labels should be a 2-D tensor.");
157

158
    if (ctx->Attrs().Get<bool>("soft_label")) {
Q
qiaolongfei 已提交
159
      PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
160
                        "When Attr(soft_label) == true, the 2nd dimension of "
161 162
                        "Input(X) and Input(Label) should be equal.");
    } else {
Q
qiaolongfei 已提交
163
      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
164
                        "When Attr(soft_label) == false, the 2nd dimension of "
165 166
                        "Input(Label) should be 1.");
    }
C
caoying03 已提交
167

Q
qiaolongfei 已提交
168 169
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
170
  }
Y
Yu Yang 已提交
171

172
 protected:
173
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
174
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
175
    return framework::OpKernelType(
Y
Yu Yang 已提交
176
        ctx.Input<Tensor>(framework::GradVarName("Loss"))->type(),
Y
Yu Yang 已提交
177
        ctx.device_context());
Y
Yu Yang 已提交
178
  }
179 180
};

181 182 183 184 185
// Builds the OpDesc of softmax_with_cross_entropy_grad from the forward
// operator's inputs, outputs and attributes.
class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Returns a newly created grad OpDesc wired as follows:
  //   inputs : Label, Softmax, Softmax@GRAD, Loss@GRAD
  //   outputs: Logits@GRAD
  //   attrs  : copied unchanged from the forward operator
  std::unique_ptr<framework::OpDesc> Apply() const override {
    // Own the descriptor from the moment it is created so that it is
    // released if any of the setters below throws.
    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    grad_op->SetType("softmax_with_cross_entropy_grad");
    grad_op->SetInput("Label", Input("Label"));
    grad_op->SetInput("Softmax", Output("Softmax"));
    grad_op->SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax"));
    grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
    grad_op->SetAttrMap(Attrs());
    return grad_op;
  }
};

199 200 201 202 203
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

204
// Forward operator, registered with its proto maker and grad-op desc maker.
REGISTER_OPERATOR(softmax_with_cross_entropy,
                  ops::SoftmaxWithCrossEntropyOp,
                  ops::SoftmaxWithCrossEntropyOpMaker,
                  ops::SoftmaxGradMaker);

// Backward operator.
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                  ops::SoftmaxWithCrossEntropyOpGrad);

// CPU kernels for the forward operator (float and double).
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
                       ops::SoftmaxWithCrossEntropyKernel<float>,
                       ops::SoftmaxWithCrossEntropyKernel<double>);

// CPU kernels for the backward operator (float and double).
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
                       ops::SoftmaxWithCrossEntropyGradKernel<double>);