/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"

namespace paddle {
namespace operators {

class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The input tensor of unscaled "
             "log probabilities, whose dimension :attr:`axis` should be scaled "
             "by softmax.");
    AddInput(
        "Label",
        "(Tensor) The input tensor of ground truth label. If :attr:`soft_label` "
        "is set to false, Label is a Tensor<int64> in same shape with "
        "Input(Logits) except the shape in dimension :attr:`axis` as 1. If "
        "soft_label is set to true, Label is a Tensor<float/double> in same "
        "shape with Input(Logits).");
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A tensor in same shape with "
        "Input(Logits). "
        "The output values of the softmax activation given the input batch, "
        "which will be used in the backward calculation.")
        .AsIntermediate();
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    AddOutput(
        "Backprop",
        "(Tensor, default: Tensor<float>), A tensor in same shape with "
        "Input(Logits). "
        "The intermediate value used for backward calculation. The calculation "
        "is: "
        "exp(logits - max_logits) / sum(exp(logits - max_logits)) - labels, "
        "where labels is one-hot. "
        "Currently, the tensor is generated and used in npu/mlu kernel. ")
        .AsIntermediate();
#endif
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A tensor in same shape with "
              "Input(Logits) "
              "except the shape in dimension :attr:`axis` as 1. The cross "
              "entropy loss.");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddAttr<bool>(
        "use_softmax",
        "(bool, default: true), A flag to indicate whether to do softmax ")
        .SetDefault(true);
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: true), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used.")
        .SetDefault(true);
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and "
        "does not contribute to the input gradient. Only valid if soft_label "
        "is set to False")
        .SetDefault(-100);
    AddAttr<int>("axis",
                 "The dimension index of Input(Logits) to perform softmax, "
                 "default -1 for last dimension")
        .SetDefault(-1);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross-entropy loss with softmax is used extensively as the output layer. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.

When the attribute soft_label is set to false, this operator expects mutually
exclusive hard labels: each sample in a batch belongs to exactly one class with
a probability of 1.0, i.e., each sample in the batch has a single label.

The equation is as follows:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1,..., K$$
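
For example, with Logits = [1.0, 2.0, 3.0] and hard Label = 2, the loss is
$$-3.0 + \log(e^{1.0} + e^{2.0} + e^{3.0}) \approx 0.4076$$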

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K$$

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

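  // Shape inference for the forward op: Logits and Label must have the same
  // rank and match in every dimension except `axis`; Softmax keeps the Logits
  // shape, while Loss takes the Logits shape with the `axis` dimension set to 1.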
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Logits"),
        true,
        platform::errors::InvalidArgument("Input(Logits) should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"),
        true,
        platform::errors::InvalidArgument("Input(Label) should not be null."));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"),
                      true,
                      platform::errors::InvalidArgument(
                          "Output(Softmax) should not be null."));
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Backprop"),
                      true,
                      platform::errors::InvalidArgument(
                          "Output(Backprop) should not be null."));
#endif
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Loss"),
        true,
        platform::errors::InvalidArgument("Output(Loss) should not be null."));

    auto axis = ctx->Attrs().Get<int>("axis");
    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
    auto logits_rank = logits_dims.size();
    PADDLE_ENFORCE_GE(axis,
                      -logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis,
                      logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));

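    // Normalize a possibly negative axis into the range [0, logits_rank).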
    axis = phi::funcs::CanonicalAxis(axis, logits_rank);

    PADDLE_ENFORCE_EQ(logits_dims.size(),
                      labels_dims.size(),
                      platform::errors::InvalidArgument(
                          "Input(Logits) and Input(Label) should have "
                          "the same rank."));

    for (int i = 0; i < logits_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(logits_dims[i],
                            labels_dims[i],
                            platform::errors::InvalidArgument(
                                "Input(Logits) and Input(Label) should have "
                                "the same shape in all dimensions except axis."));
        }
      }
    }

    auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
    if (axis != logits_rank - 1) {
      PADDLE_ENFORCE_EQ(numeric_stable_mode,
                        true,
                        platform::errors::InvalidArgument(
                            "Attr(axis) can only be -1 "
                            "when not in numeric_stable_mode."));
    }

    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(logits_dims[axis],
                          labels_dims[axis],
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true, "
                              "the axis dimension of "
                              "Input(Logits) and Input(Label) should be equal."));
      }
    } else {
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(
            labels_dims[axis],
            1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
      }
    }

    ctx->SetOutputDim("Softmax", logits_dims);
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    ctx->SetOutputDim("Backprop", logits_dims);
    ctx->ShareLoD("Logits", /*->*/ "Backprop");
#endif
    logits_dims[axis] = 1;
    ctx->SetOutputDim("Loss", logits_dims);

    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
  }

 protected:
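  // The kernel is chosen based on the data type of the Logits input.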
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(
        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace());
  }
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

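  // Shape inference for the backward op: Logits@GRAD takes the shape of the
  // saved Softmax output.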
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")),
                      true,
                      platform::errors::InvalidArgument(
                          "Input(Loss@Grad) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"),
                      true,
                      platform::errors::InvalidArgument(
                          "Input(Softmax) should not be null."));
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    PADDLE_ENFORCE_EQ(ctx->HasInput("Backprop"),
                      true,
                      platform::errors::InvalidArgument(
                          "Input(Backprop) should not be null."));
#endif
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"),
        true,
        platform::errors::InvalidArgument("Input(Label) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")),
                      true,
                      platform::errors::InvalidArgument(
                          "Output(Logits@Grad) should not be null."));

    auto axis = ctx->Attrs().Get<int>("axis");
    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
    auto softmax_rank = softmax_dims.size();
    PADDLE_ENFORCE_GE(axis,
                      -softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis,
                      softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));

    axis = phi::funcs::CanonicalAxis(axis, softmax_rank);
    for (int i = 0; i < softmax_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              softmax_dims[i],
              labels_dims[i],
              platform::errors::InvalidArgument(
                  "Input(Logits) and Input(Label) should have the same shape "
                  "in all dimensions except axis."));
        }
      }
    }

    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(softmax_dims[axis],
                          labels_dims[axis],
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true, "
                              "the axis dimension of "
                              "Input(Softmax) and Input(Label) should be equal."));
      }
    } else {
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(
            labels_dims[axis],
            1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
      }
    }

    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
                              ctx, framework::GradVarName("Loss")),
                          ctx.GetPlace());
  }
};

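// Describes how to build the backward op: it consumes Label, the saved Softmax
// output (plus Backprop on NPU/MLU builds) and Loss@GRAD, produces Logits@GRAD,
// and copies the forward op's attributes.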
template <typename T>
class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("softmax_with_cross_entropy_grad");
    grad_op->SetInput("Label", this->Input("Label"));
    grad_op->SetInput("Softmax", this->Output("Softmax"));
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    grad_op->SetInput("Backprop", this->Output("Backprop"));
#endif
    grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"),
                       this->InputGrad("Logits"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

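// Inplace pairs: the forward op may reuse the Logits buffer for Softmax, and
// the backward op may reuse the Softmax buffer for Logits@GRAD.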
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInferer,
                           {"Logits", "Softmax"});

DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInferer,
                           {"Softmax", framework::GradVarName("Logits")});

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

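// Register the forward op with its gradient makers and inplace inferer, and
// the backward op with its inplace inferer.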
REGISTER_OPERATOR(softmax_with_cross_entropy,
                  ops::SoftmaxWithCrossEntropyOp,
                  ops::SoftmaxWithCrossEntropyOpMaker,
                  ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
                  ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
                  ops::SoftmaxWithCrossEntropyInplaceInferer);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                  ops::SoftmaxWithCrossEntropyOpGrad,
                  ops::SoftmaxWithCrossEntropyGradInplaceInferer);

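// Version checkpoints record that the `use_softmax` attribute (and, on NPU/MLU
// builds, the `Backprop` output) were added after the op was first released.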
REGISTER_OP_VERSION(softmax_with_cross_entropy)
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [use_softmax] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "use_softmax", "A flag to indicate whether to do softmax", true))
    .AddCheckpoint(
        R"ROC(
                Add a new dispensable/intermediate output [backprop] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewOutput(
            "Backprop",
            "The intermediate value used for backward calculation. The "
            "calculation is: "
            "exp(logits - max_logits) / sum(exp(logits - max_logits)) - labels, "
            "where labels is one-hot. "
            "Currently, the tensor is generated and used in npu/mlu kernel. "));
#else
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [use_softmax] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "use_softmax", "A flag to indicate whether to do softmax", true));
#endif