/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"

namespace paddle {
namespace operators {

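// Declares the operator's interface: its inputs (Logits, Label), outputs
// (Softmax, Loss, and Backprop on NPU/MLU builds), its attributes, and the
// user-facing documentation string.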
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The input tensor of unscaled "
             "log probabilities, whose dimension :attr:`axis` should be scaled "
             "by softmax.");
    AddInput(
        "Label",
        "(Tensor) The input tensor of ground truth label. If :attr:`soft_label` "
        "is set to false, Label is a Tensor<int64> with the same shape as "
        "Input(Logits), except that the dimension :attr:`axis` is 1. If "
        "soft_label is set to true, Label is a Tensor<float/double> with the "
        "same shape as Input(Logits).");
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A tensor with the same shape as "
        "Input(Logits). "
        "The output values of the softmax activation for the given input "
        "batch, which will be used in the backward calculation.")
        .AsIntermediate();
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    AddOutput(
        "Backprop",
        "(Tensor, default: Tensor<float>), A tensor in same shape with "
        "Input(Logits). "
        "The intermediate value used for backward calculation. The calculation "
        "is :"
        "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
        "where labels is ont-hot."
        "Currently, the tensor is generated and used in npu/mlu kernel. ")
        .AsIntermediate();
#endif
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A tensor with the same shape "
              "as Input(Logits) "
              "except that the dimension :attr:`axis` is 1. The cross "
              "entropy loss.");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddAttr<bool>(
        "use_softmax",
        "(bool, default: true), A flag to indicate whether to do softmax ")
        .SetDefault(true);
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: true), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used.")
        .SetDefault(true);
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and"
        "does not contribute to the input gradient. Only valid if soft_label"
        "is set to False")
        .SetDefault(-100);
    AddAttr<int>("axis",
                 "The dimension index of Input(Logits) to perform softmax,"
                 "default -1 for last dimension")
        .SetDefault(-1);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of the
softmax operator, since that would produce incorrect results.

When the attribute soft_label is set to false, this operator expects mutually
exclusive hard labels: each sample in a batch is in exactly one class with a
probability of 1.0, so each sample in the batch has a single label.

The equation is as follows:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1,..., K$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K$$

)DOC");
  }
};
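
// Illustrative Python-side usage of this operator (a sketch only; the exact
// high-level API name and default arguments are assumptions, not defined in
// this file):
//
//   loss = paddle.nn.functional.softmax_with_cross_entropy(
//       logits, label, soft_label=False, axis=-1)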

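// Forward operator: verifies that the required inputs/outputs exist, validates
// the shapes of Logits and Label against Attr(axis) and Attr(soft_label), and
// infers the dims of Softmax and Loss (Loss keeps Logits' shape with the axis
// dimension reduced to 1).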
class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Logits"), true,
        platform::errors::InvalidArgument("Input(Logits) should be not null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"), true,
        platform::errors::InvalidArgument("Input(Label) should be not null."));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"), true,
                      platform::errors::InvalidArgument(
                          "Output(Softmax) should be not null."));
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Backprop"), true,
                      platform::errors::InvalidArgument(
                          "Output(Backprop) should be not null."));
#endif
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Loss"), true,
        platform::errors::InvalidArgument("Output(Loss) should be not null."));

    auto axis = ctx->Attrs().Get<int>("axis");
    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
    auto logits_rank = logits_dims.size();
    PADDLE_ENFORCE_GE(axis, -logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis, logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));

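    // Map a possibly negative axis into the canonical range [0, logits_rank).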
    axis = phi::funcs::CanonicalAxis(axis, logits_rank);
    for (int i = 0; i < logits_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(logits_dims[i], labels_dims[i],
                            platform::errors::InvalidArgument(
                                "Input(Logits) and Input(Label) should in "
                                "same shape in dimensions except axis."));
        }
      }
    }

    auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
    if (axis != logits_rank - 1) {
      PADDLE_ENFORCE_EQ(numeric_stable_mode, true,
                        platform::errors::InvalidArgument(
                            "Attr(axis) can only be -1 "
                            "when not in numeric_stable_mode."));
    }

    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis],
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true,  "
                              "the axis dimension of "
                              "Input(X) and Input(Label) should be equal."));
      }
    } else {
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(
            labels_dims[axis], 1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
      }
    }

    ctx->SetOutputDim("Softmax", logits_dims);
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    ctx->SetOutputDim("Backprop", logits_dims);
    ctx->ShareLoD("Logits", /*->*/ "Backprop");
#endif
    logits_dims[axis] = 1;
    ctx->SetOutputDim("Loss", logits_dims);

    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
        ctx.device_context());
  }
};

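// Gradient operator: takes Loss@Grad, the saved Softmax output, and Label
// (plus Backprop on NPU/MLU builds), re-validates the shapes, and sets
// Logits@Grad to the shape of Softmax.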
class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")), true,
                      platform::errors::InvalidArgument(
                          "Input(Loss@Grad) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true,
                      platform::errors::InvalidArgument(
                          "Input(Softmax) should be not null."));
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    PADDLE_ENFORCE_EQ(ctx->HasInput("Backprop"), true,
                      platform::errors::InvalidArgument(
                          "Input(Backprop) should be not null."));
#endif
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"), true,
        platform::errors::InvalidArgument("Input(Label) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")), true,
                      platform::errors::InvalidArgument(
                          "Output(Logits@Grad) should be not null."));

    auto axis = ctx->Attrs().Get<int>("axis");
    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
    auto softmax_rank = softmax_dims.size();
    PADDLE_ENFORCE_GE(axis, -softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis, softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));

    axis = phi::funcs::CanonicalAxis(axis, softmax_rank);
    for (int i = 0; i < softmax_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              softmax_dims[i], labels_dims[i],
              platform::errors::InvalidArgument(
                  "Input(Logits) and Input(Label) should in same shape in "
                  "dimensions except axis."));
        }
      }
    }

    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis],
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true, "
                              "the axis dimension of "
                              "Input(X) and Input(Label) should be equal."));
      }
    } else {
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(
            labels_dims[axis], 1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
      }
    }

    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Loss")),
                                   ctx.device_context());
  }
};

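// Describes how to build the grad op from the forward op: the backward pass
// consumes the saved Softmax output rather than recomputing softmax from
// Logits, and forwards all attributes unchanged.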
template <typename T>
class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("softmax_with_cross_entropy_grad");
    grad_op->SetInput("Label", this->Input("Label"));
    grad_op->SetInput("Softmax", this->Output("Softmax"));
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    grad_op->SetInput("Backprop", this->Output("Backprop"));
#endif
    grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"),
                       this->InputGrad("Logits"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

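// Inplace rules: the forward op may write Softmax into the Logits buffer, and
// the grad op may write Logits@Grad into the Softmax buffer.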
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInferer,
                           {"Logits", "Softmax"});

DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInferer,
                           {"Softmax", framework::GradVarName("Logits")});

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

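// Registration: the forward op with its maker, grad makers, and inplace rule;
// the grad op with its inplace rule; and version checkpoints recording the
// addition of the use_softmax attribute (and, on NPU/MLU builds, the Backprop
// output).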
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
                  ops::SoftmaxWithCrossEntropyOpMaker,
                  ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
                  ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
                  ops::SoftmaxWithCrossEntropyInplaceInferer);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                  ops::SoftmaxWithCrossEntropyOpGrad,
                  ops::SoftmaxWithCrossEntropyGradInplaceInferer);

REGISTER_OP_VERSION(softmax_with_cross_entropy)
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [use_softmax] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "use_softmax", "A flag to indicate whether to do softmax", true))
    .AddCheckpoint(
        R"ROC(
                Add a new dispensable/intermediate output [backprop] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewOutput(
            "Backprop",
            "The intermediate value used for backward calculation. The "
            "calculation is :"
            "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
            "where labels is ont-hot."
            "Currently, the tensor is generated and used in npu/mlu kernel. "));
#else
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [use_softmax] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "use_softmax", "A flag to indicate whether to do softmax", true));
#endif