/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {

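// OpMaker: declares the operator's inputs (Logits, Label), outputs (Softmax,
// Loss, and Backprop on Ascend builds), attributes, and documentation.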
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The input tensor of unscaled "
             "log probabilities, whose dimension :attr:`axis` should be scaled "
             "by softmax.");
    AddInput(
        "Label",
        "(Tensor) The input tensor of groud truth label. If :attr:`soft_label` "
36 37 38 39
        "is set to false, Label is a Tensor<int64> in same shape with "
        "Input(Logits) except the shape in dimension :attr:`axis` as 1. If "
        "soft_label is set to true, Label is a Tensor<float/double> in same "
        "shape with Input(Logits).");
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A tensor in same shape with "
        "Input(Logits). "
C
caoying03 已提交
44 45
        "The outputs value of softmax activation by given the input batch, "
        "which will be used in backward calculation.")
        .AsIntermediate();
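    // On Ascend (NPU) builds, an extra intermediate output caches the
    // backprop term softmax(logits) - one_hot(labels) so the NPU kernel can
    // reuse it in the backward pass.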
#ifdef PADDLE_WITH_ASCEND_CL
    AddOutput(
        "Backprop",
        "(Tensor, default: Tensor<float>), A tensor in same shape with "
        "Input(Logits). "
        "The intermediate value used for backward calculation. The calculation "
        "is :"
        "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
        "where labels is ont-hot."
        "Currently, the tensor is generated and used in npu kernel only. ")
        .AsIntermediate()
        .AsDispensable();
#endif
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A tensor in same shape with "
              "Input(Logits) "
              "except the shape in dimension :attr:`axis` as 1. The cross "
              "entropy loss.");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpretant "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddAttr<bool>(
        "use_softmax",
        "(bool, default: true), A flag to indicate whether to do softmax ")
        .SetDefault(true);
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: true), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used.")
        .SetDefault(true);
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and"
        "does not contribute to the input gradient. Only valid if soft_label"
        "is set to False")
        .SetDefault(-100);
    AddAttr<int>("axis",
                 "The dimension index of Input(Logits) to perform softmax,"
                 "default -1 for last dimension")
        .SetDefault(-1);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross-entropy loss combined with softmax is widely used as the output layer. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
the softmax operator, since that would produce incorrect results.

When the attribute soft_label is set to false, this operator expects mutually
exclusive hard labels; each sample in a batch is in exactly one class with a
probability of 1.0, i.e., each sample in the batch has a single label.

The equation is as follows:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1,..., K$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K$$

)DOC");
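    // Illustrative example (not from the original documentation): with hard
    // label 0 and logits [2.0, 1.0, 0.1],
    //   Loss = -2.0 + log(exp(2.0) + exp(1.0) + exp(0.1)) ~= 0.417.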
  }
};

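// Forward operator: checks input/output presence, validates that Logits and
// Label agree in shape (except along the softmax axis), and infers the shapes
// of the Softmax and Loss outputs.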
class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Logits"), true,
        platform::errors::InvalidArgument("Input(Logits) should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"), true,
        platform::errors::InvalidArgument("Input(Label) should not be null."));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"), true,
                      platform::errors::InvalidArgument(
                          "Output(Softmax) should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Loss"), true,
        platform::errors::InvalidArgument("Output(Loss) should not be null."));

    auto axis = ctx->Attrs().Get<int>("axis");
    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
    auto logits_rank = logits_dims.size();
    PADDLE_ENFORCE_GE(axis, -logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis, logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));

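    // CanonicalAxis maps a negative axis (e.g. -1) to its non-negative
    // equivalent so that the dimension comparisons below are uniform.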
    axis = CanonicalAxis(axis, logits_rank);
    for (int i = 0; i < logits_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(logits_dims[i], labels_dims[i],
                            platform::errors::InvalidArgument(
                                "Input(Logits) and Input(Label) should in "
                                "same shape in dimensions except axis."));
        }
      }
    }

    auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
    if (axis != logits_rank - 1) {
      PADDLE_ENFORCE_EQ(numeric_stable_mode, true,
                        platform::errors::InvalidArgument(
                            "Attr(axis) can only be -1 "
                            "when not in numeric_stable_mode."));
    }

    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis],
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true,  "
                              "the axis dimension of "
                              "Input(X) and Input(Label) should be equal."));
      }
    } else {
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(
            labels_dims[axis], 1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
      }
    }

    ctx->SetOutputDim("Softmax", logits_dims);
#ifdef PADDLE_WITH_ASCEND_CL
    ctx->SetOutputDim("Backprop", logits_dims);
    ctx->ShareLoD("Logits", /*->*/ "Backprop");
#endif
    logits_dims[axis] = 1;
    ctx->SetOutputDim("Loss", logits_dims);

    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
        ctx.device_context());
  }
};

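// Gradient operator: consumes the cached Softmax output, Label, and Loss@GRAD,
// and produces Logits@GRAD with the same shape as Softmax.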
class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")), true,
                      platform::errors::InvalidArgument(
                          "Input(Loss@Grad) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true,
                      platform::errors::InvalidArgument(
                          "Input(Softmax) should be not null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"), true,
        platform::errors::InvalidArgument("Input(Label) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")), true,
                      platform::errors::InvalidArgument(
                          "Output(Logits@Grad) should be not null."));

    auto axis = ctx->Attrs().Get<int>("axis");
    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
    auto softmax_rank = softmax_dims.size();
    PADDLE_ENFORCE_GE(axis, -softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis, softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));

    axis = CanonicalAxis(axis, softmax_rank);
    for (int i = 0; i < softmax_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              softmax_dims[i], labels_dims[i],
              platform::errors::InvalidArgument(
                  "Input(Logits) and Input(Label) should in same shape in "
                  "dimensions except axis."));
        }
      }
    }

    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis],
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true, "
                              "the axis dimension of "
                              "Input(X) and Input(Label) should be equal."));
      }
    } else {
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(
            labels_dims[axis], 1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
      }
    }

    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Loss")),
                                   ctx.device_context());
  }
};

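// GradMaker: wires the forward op's Label, Softmax (and Backprop on Ascend
// builds), and Loss@GRAD into the grad op description, and forwards all
// attributes.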
template <typename T>
class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("softmax_with_cross_entropy_grad");
    grad_op->SetInput("Label", this->Input("Label"));
    grad_op->SetInput("Softmax", this->Output("Softmax"));
#ifdef PADDLE_WITH_ASCEND_CL
    grad_op->SetInput("Backprop", this->Output("Backprop"));
#endif
    grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"),
                       this->InputGrad("Logits"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

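// In-place inference: the forward kernel may reuse the Logits buffer for
// Softmax, and the backward kernel may reuse the Softmax buffer for
// Logits@GRAD.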
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInferer,
                           {"Logits", "Softmax"});

DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInferer,
                           {"Softmax", framework::GradVarName("Logits")});

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

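// Register the forward and backward operators with their grad makers and
// in-place rules, the CPU kernels, and the op version checkpoints below.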
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
                  ops::SoftmaxWithCrossEntropyOpMaker,
                  ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
                  ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
                  ops::SoftmaxWithCrossEntropyInplaceInferer);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                  ops::SoftmaxWithCrossEntropyOpGrad,
                  ops::SoftmaxWithCrossEntropyGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
                       ops::SoftmaxWithCrossEntropyKernel<float>,
                       ops::SoftmaxWithCrossEntropyKernel<double>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
                       ops::SoftmaxWithCrossEntropyGradKernel<double>);

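// Version checkpoints record compatibility-relevant changes: the new
// use_softmax attribute and, on Ascend builds, the new Backprop output.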
REGISTER_OP_VERSION(softmax_with_cross_entropy)
#ifdef PADDLE_WITH_ASCEND_CL
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [use_softmax] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "use_softmax", "A flag to indicate whether to do softmax", true))
    .AddCheckpoint(
        R"ROC(
                Add a new dispensable/intermediate output [backprop] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewOutput(
            "Backprop",
            "The intermediate value used for backward calculation. The "
            "calculation is :"
            "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
            "where labels is ont-hot."
            "Currently, the tensor is generated and used in npu kernel "
            "only. "));
#else
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [use_softmax] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "use_softmax", "A flag to indicate whether to do softmax", true));
#endif