softmax_with_cross_entropy_op.cc 8.2 KB
Newer Older
1 2 3 4 5 6
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#include "paddle/operators/softmax_with_cross_entropy_op.h"
Y
Yu Yang 已提交
16

17 18 19 20 21 22
namespace paddle {
namespace operators {

// Declares the interface of the softmax_with_cross_entropy operator: its two
// inputs (Logits, Label), its two outputs (Softmax, Loss), the `soft_label`
// attribute, and the user-facing documentation string.
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers every input/output/attribute on the given proto and checker.
  SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
                                 framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The unscaled log probabilities "
             "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
             "and K is the class number.");
    AddInput("Label",
             "(Tensor) The ground truth which is a 2-D tensor. If soft_label "
             "is set to false, Label is a Tensor<int64> with shape [N x 1]. If "
             "soft_label is set to true, Label is a Tensor<float/double> with "
             "shape [N x K].");
    // Softmax is marked intermediate: it is produced by the forward pass and
    // consumed by the gradient operator.
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
        "The outputs value of softmax activation by given the input batch, "
        "which will be used in backward calculation.")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
              "entropy loss with shape [N x 1].");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    // In the equations below j indexes the N samples of the batch and the
    // sums run over the K classes.
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.

When the attribute soft_label is set false, this operator expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.

The equation is as follows:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K-1}\exp(\text{Logit}_i)\right),
j = 1,..., N$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K-1}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{k=0}^{K-1}\exp(\text{Logit}_k)\right)\right),
j = 1,...,N$$

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

87
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
88 89 90 91 92 93 94 95 96 97
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");

    PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
                   "Output(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");

    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
C
caoying03 已提交
98
    PADDLE_ENFORCE_EQ(
Q
qiaolongfei 已提交
99
        logits_dims.size(), 2UL,
100
        "The input of softmax_with_cross_entropy should be a 2-D tensor.");
Q
qiaolongfei 已提交
101
    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
C
caoying03 已提交
102
                      "The labels should be a 2-D tensor.");
103

104
    if (ctx->Attrs().Get<bool>("soft_label")) {
Q
qiaolongfei 已提交
105
      PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
106
                        "If Attr(soft_label) == true, the 2nd dimension of "
107 108
                        "Input(X) and Input(Label) should be equal.");
    } else {
Q
qiaolongfei 已提交
109
      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
110
                        "If Attr(soft_label) == false, the 2nd dimension of "
111 112 113
                        "Input(Label) should be 1.");
    }

Q
qiaolongfei 已提交
114 115
    ctx->SetOutputDim("Softmax", logits_dims);
    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
116

Q
qiaolongfei 已提交
117 118
    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
C
caoying03 已提交
119
  }
Y
Yu Yang 已提交
120

121
 protected:
Y
Yu Yang 已提交
122
  framework::OpKernelType GetKernelType(
Y
Yu Yang 已提交
123
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
124 125 126
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
        ctx.device_context());
Y
Yu Yang 已提交
127
  }
C
caoying03 已提交
128 129 130 131 132 133
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

134
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
135 136 137 138 139 140 141 142 143 144 145
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Softmax"),
                   "Input(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@Grad) should be not null.");

    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
C
caoying03 已提交
146
                      "The labels should be a 2-D tensor.");
147

148
    if (ctx->Attrs().Get<bool>("soft_label")) {
Q
qiaolongfei 已提交
149
      PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
150
                        "When Attr(soft_label) == true, the 2nd dimension of "
151 152
                        "Input(X) and Input(Label) should be equal.");
    } else {
Q
qiaolongfei 已提交
153
      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
154
                        "When Attr(soft_label) == false, the 2nd dimension of "
155 156
                        "Input(Label) should be 1.");
    }
C
caoying03 已提交
157

Q
qiaolongfei 已提交
158 159
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
160
  }
Y
Yu Yang 已提交
161

162
 protected:
Y
Yu Yang 已提交
163
  framework::OpKernelType GetKernelType(
Y
Yu Yang 已提交
164
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
165 166 167 168
    return framework::OpKernelType(
        framework::ToDataType(
            ctx.Input<Tensor>(framework::GradVarName("Loss"))->type()),
        ctx.device_context());
Y
Yu Yang 已提交
169
  }
170 171
};

172 173 174 175 176
// Builds the description of the softmax_with_cross_entropy_grad operator from
// the forward operator's inputs, outputs and attributes.
class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Returns a newly allocated grad-op description. Its inputs are Label,
  // Softmax, Loss, Softmax@GRAD and Loss@GRAD, and its output is Logits@GRAD.
  // The forward operator's attribute map is copied over unchanged.
  std::unique_ptr<framework::OpDescBind> Apply() const override {
    // Take ownership immediately so the descriptor is released if any of the
    // calls below throws, instead of leaking a raw pointer.
    std::unique_ptr<framework::OpDescBind> grad_op(new framework::OpDescBind());
    grad_op->SetType("softmax_with_cross_entropy_grad");
    grad_op->SetInput("Label", Input("Label"));
    grad_op->SetInput("Softmax", Output("Softmax"));
    grad_op->SetInput("Loss", Output("Loss"));
    grad_op->SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax"));
    grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
    grad_op->SetAttrMap(Attrs());
    return grad_op;
  }
};

191 192 193 194 195
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

196
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
Y
Yu Yang 已提交
197
                  ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker);
198 199
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                  ops::SoftmaxWithCrossEntropyOpGrad);
200
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
C
caoying03 已提交
201 202
                       ops::SoftmaxWithCrossEntropyKernel<float>,
                       ops::SoftmaxWithCrossEntropyKernel<double>);
203
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
C
caoying03 已提交
204 205
                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
                       ops::SoftmaxWithCrossEntropyGradKernel<double>);