/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace paddle {
namespace operators {

// Declares the inputs, outputs, attributes and documentation of the
// softmax_with_cross_entropy operator. These strings surface directly in
// the generated Python API docs, so they must be well-formed English.
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The input tensor of unscaled "
             "log probabilities, whose dimension :attr:`axis` should be scaled "
             "by softmax.");
    AddInput(
        "Label",
        // NOTE: fixed typo "groud" -> "ground".
        "(Tensor) The input tensor of ground truth label. If "
        ":attr:`soft_label` "
        "is set to false, Label is a Tensor<int64> in same shape with "
        "Input(Logits) except the shape in dimension :attr:`axis` as 1. If "
        "soft_label is set to true, Label is a Tensor<float/double> in same "
        "shape with Input(Logits).");
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A tensor in same shape with "
        "Input(Logits). "
        "The outputs value of softmax activation by given the input batch, "
        "which will be used in backward calculation.")
        // Saved for the backward pass; not a user-facing result.
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A tensor in same shape with "
              "Input(Logits) "
              "except the shape in dimension :attr:`axis` as 1. The cross "
              "entropy loss.");
    AddAttr<bool>(
        "soft_label",
        // NOTE: fixed typo "interpretant" -> "interpret".
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: true), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used.")
        .SetDefault(true);
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and"
        "does not contribute to the input gradient. Only valid if soft_label"
        "is set to False")
        .SetDefault(-100);
    AddAttr<int>("axis",
                 "The dimension index of Input(Logits) to perform softmax,"
                 "default -1 for last dimension")
        .SetDefault(-1);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.

When the attribute soft_label is set false, this operator expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.

The equation is as follows:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1,..., K$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K$$

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

110
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
111 112 113 114 115 116 117 118
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");

    PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
                   "Output(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");

119
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
120 121
    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
    auto logits_rank = logits_dims.size();
    PADDLE_ENFORCE(axis >= -logits_rank && axis < logits_rank,
                   "Attr(axis) value should be in range [-R, R-1], "
                   "R is the rank of Input(Logits).");

    axis = CanonicalAxis(axis, logits_rank);
    for (int i = 0; i < logits_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              logits_dims[i], labels_dims[i],
              "Input(Logits) and Input(Label) should in same shape in "
              "dimensions except axis.");
        }
      }
    }
138

139 140 141 142 143
    auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
    if (axis != logits_rank - 1) {
      PADDLE_ENFORCE(
          numeric_stable_mode,
          "Attr(axis) can only be -1 when not in numeric_stable_mode.");
144
    }
145

146 147 148 149 150 151
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis],
                          "If Attr(soft_label) == true, the axis dimension of "
152 153
                          "Input(X) and Input(Label) should be equal.");
      }
154
    } else {
155 156 157 158 159
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
                          "If Attr(soft_label) == false, the axis dimension of "
                          "Input(Label) should be 1.");
      }
160 161
    }

Q
qiaolongfei 已提交
162
    ctx->SetOutputDim("Softmax", logits_dims);
163 164 165

    logits_dims[axis] = 1;
    ctx->SetOutputDim("Loss", logits_dims);
166

Q
qiaolongfei 已提交
167 168
    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
C
caoying03 已提交
169
  }
Y
Yu Yang 已提交
170

171
 protected:
172
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
173
      const framework::ExecutionContext& ctx) const override {
174 175 176
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
        ctx.device_context());
Y
Yu Yang 已提交
177
  }
C
caoying03 已提交
178 179 180 181 182 183
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

184
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
185 186 187 188 189 190 191 192
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Softmax"),
                   "Input(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@Grad) should be not null.");

193
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
194 195
    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
    auto softmax_rank = softmax_dims.size();
    PADDLE_ENFORCE(axis >= -softmax_rank && axis < softmax_rank,
                   "Attr(axis) value should be in range [-R, R-1], "
                   "R is the rank of Input(Logits).");

    axis = CanonicalAxis(axis, softmax_rank);
    for (int i = 0; i < softmax_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              softmax_dims[i], labels_dims[i],
              "Input(Logits) and Input(Label) should in same shape in "
              "dimensions except axis.");
        }
      }
211
    }
212

213 214 215 216 217 218 219
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis],
                          "If Attr(soft_label) == true, the axis dimension of "
                          "Input(X) and Input(Label) should be equal.");
220
      }
221
    } else {
222 223 224 225 226
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
                          "If Attr(soft_label) == false, the axis dimension of "
                          "Input(Label) should be 1.");
      }
227
    }
C
caoying03 已提交
228

Q
qiaolongfei 已提交
229 230
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
231
  }
Y
Yu Yang 已提交
232

233
 protected:
234
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
235
      const framework::ExecutionContext& ctx) const override {
236 237 238
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Loss")),
                                   ctx.device_context());
Y
Yu Yang 已提交
239
  }
240 241
};

H
hong 已提交
242 243
// Builds the descriptor of softmax_with_cross_entropy_grad from the forward
// op: the backward pass consumes Label, the saved Softmax output, and the
// incoming Loss gradient, and produces the gradient w.r.t. Logits.
template <typename T>
class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("softmax_with_cross_entropy_grad");
    // Inputs of the backward op.
    op->SetInput("Label", this->Input("Label"));
    op->SetInput("Softmax", this->Output("Softmax"));
    op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
    // Single output: gradient of the unscaled logits.
    op->SetOutput(framework::GradVarName("Logits"),
                  this->InputGrad("Logits"));
    // The backward op reuses every attribute of the forward op.
    op->SetAttrMap(this->Attrs());
  }
};

// Forward op may reuse the Logits buffer for the Softmax output.
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInference,
                           {"Logits", "Softmax"});

// Backward op may reuse the Softmax buffer for the Logits gradient.
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInference,
                           {"Softmax", framework::GradVarName("Logits")});

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the forward op with its proto maker, grad-op makers for both the
// static graph (OpDesc) and imperative (OpBase) modes, and its inplace rule.
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
                  ops::SoftmaxWithCrossEntropyOpMaker,
                  ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
                  ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
                  ops::SoftmaxWithCrossEntropyInplaceInference);
// Register the backward op with its inplace rule.
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                  ops::SoftmaxWithCrossEntropyOpGrad,
                  ops::SoftmaxWithCrossEntropyGradInplaceInference);
// CPU kernels for float and double element types.
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
                       ops::SoftmaxWithCrossEntropyKernel<float>,
                       ops::SoftmaxWithCrossEntropyKernel<double>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
                       ops::SoftmaxWithCrossEntropyGradKernel<double>);