softmax_with_cross_entropy_op.cc 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/operators/softmax_with_cross_entropy_op.h"
Y
Yu Yang 已提交
16
#include <paddle/function/TensorType.h>
17 18 19 20 21 22 23

namespace paddle {
namespace operators {

/// Declares the inputs, outputs, attribute and user-facing documentation of
/// the softmax_with_cross_entropy operator.
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
                                 framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Logits are excluded from the gradient op's inputs: the gradient op in
    // this file requires Softmax, Label and Loss@Grad instead.
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The unscaled log probabilities "
             "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
             "and K is the class number.")
        .NotInGradient();
    AddInput(
        "Label",
        "(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
        "tensor. "
        "If softLabel is set to false, Label is a Tensor<int> with shape "
        "[N x 1]. "
        "If softLabel is set to true, Label is a Tensor<float/double> "
        "with shape [N x K].");
    // Intermediate output: kept because the backward calculation reuses it.
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
        "The output values of the softmax activation for the given input "
        "batch, which will be used in backward calculation.")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
              "entropy loss with shape [N x 1].");
    AddAttr<bool>(
        "softLabel",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddComment(R"DOC(
Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. Please do not call this op with the output of softmax operator,
which will produce incorrect results.

When the attribute softLabel is set to false, this operator expects mutually
exclusive hard labels: each sample in a batch is in exactly one class with
probability 1, i.e. each sample in the batch has one and only one label.
When softLabel is set to true, the label of each sample is a probability
distribution over all K classes.

Equation (N is the batch size, K is the class number):

1) hard label (one-hot label)

Loss_j = -\text{Logit}_{j,Label_j} + \log\left(\sum_{i=0}^{K-1}\exp(\text{Logit}_{j,i})\right), j = 1, ..., N

2) soft label (a distribution over all classes)

Loss_j = -\sum_{i=0}^{K-1}\text{Label}_{j,i}\left(\text{Logit}_{j,i}-\log\left(\sum_{k=0}^{K-1}\exp(\text{Logit}_{j,k})\right)\right), j = 1, ..., N

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
Q
qiaolongfei 已提交
86 87 88 89 90 91 92 93 94 95 96
  void InferShape(framework::InferShapeContextBase* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");

    PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
                   "Output(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");

    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
C
caoying03 已提交
97
    PADDLE_ENFORCE_EQ(
Q
qiaolongfei 已提交
98
        logits_dims.size(), 2UL,
99
        "The input of softmax_with_cross_entropy should be a 2-D tensor.");
Q
qiaolongfei 已提交
100
    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
C
caoying03 已提交
101
                      "The labels should be a 2-D tensor.");
102

Q
qiaolongfei 已提交
103 104
    if (ctx->Attrs().Get<bool>("softLabel")) {
      PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
105 106 107
                        "If Attr(softLabel) == true, the 2nd dimension of "
                        "Input(X) and Input(Label) should be equal.");
    } else {
Q
qiaolongfei 已提交
108
      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
109 110 111 112
                        "If Attr(softLabel) == false, the 2nd dimension of "
                        "Input(Label) should be 1.");
    }

Q
qiaolongfei 已提交
113 114
    ctx->SetOutputDim("Softmax", logits_dims);
    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
115

Q
qiaolongfei 已提交
116 117
    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
C
caoying03 已提交
118
  }
Y
Yu Yang 已提交
119 120 121 122 123

  framework::DataType IndicateDataType(
      const framework::ExecutionContext& ctx) const override {
    return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
  }
C
caoying03 已提交
124 125 126 127 128 129 130
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
Q
qiaolongfei 已提交
131 132 133 134 135 136 137 138 139 140 141 142
  void InferShape(framework::InferShapeContextBase* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Softmax"),
                   "Input(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@Grad) should be not null.");

    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
C
caoying03 已提交
143
                      "The labels should be a 2-D tensor.");
144

Q
qiaolongfei 已提交
145 146
    if (ctx->Attrs().Get<bool>("softLabel")) {
      PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
147 148 149
                        "When Attr(softLabel) == true, the 2nd dimension of "
                        "Input(X) and Input(Label) should be equal.");
    } else {
Q
qiaolongfei 已提交
150
      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
151 152 153
                        "When Attr(softLabel) == false, the 2nd dimension of "
                        "Input(Label) should be 1.");
    }
C
caoying03 已提交
154

Q
qiaolongfei 已提交
155 156
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
157
  }
Y
Yu Yang 已提交
158 159 160

  framework::DataType IndicateDataType(
      const framework::ExecutionContext& ctx) const override {
Y
Fix CI  
Yu Yang 已提交
161 162
    return framework::ToDataType(
        ctx.Input<Tensor>(framework::GradVarName("Loss"))->type());
Y
Yu Yang 已提交
163
  }
164 165 166 167 168 169 170 171 172 173 174
};

}  // namespace operators
}  // namespace paddle

// Short alias used by the registration macros below.
namespace ops = paddle::operators;

// Forward op `softmax_with_cross_entropy`, paired with its gradient op
// `softmax_with_cross_entropy_grad`.
REGISTER_OP(softmax_with_cross_entropy,
            ops::SoftmaxWithCrossEntropyOp,
            ops::SoftmaxWithCrossEntropyOpMaker,
            softmax_with_cross_entropy_grad,
            ops::SoftmaxWithCrossEntropyOpGrad);

// CPU kernels; only the float instantiation is registered here.
REGISTER_OP_CPU_KERNEL(
    softmax_with_cross_entropy,
    ops::SoftmaxWithCrossEntropyKernel<float>);
REGISTER_OP_CPU_KERNEL(
    softmax_with_cross_entropy_grad,
    ops::SoftmaxWithCrossEntropyGradKernel<float>);