softmax_with_cross_entropy_op.cc 12.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
S
sneaxiy 已提交
16
#include <memory>
Z
Zeng Jinle 已提交
17 18 19
#include <string>
#include <unordered_map>
#include <vector>
Y
Yu Yang 已提交
20

21 22 23 24 25 26
namespace paddle {
namespace operators {

// Declares the interface of the softmax_with_cross_entropy operator: its two
// inputs (Logits, Label), two outputs (Softmax, Loss), four attributes
// (soft_label, numeric_stable_mode, ignore_index, axis) and the user-facing
// description text.
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers every input, output, attribute and the operator comment.
  // NOTE: adjacent string literals are concatenated verbatim by the compiler,
  // so each fragment that is followed by another word must end with a space.
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The input tensor of unscaled "
             "log probabilities, whose dimension :attr:`axis` should be scaled "
             "by softmax.");
    AddInput(
        "Label",
        "(Tensor) The input tensor of ground truth label. If "
        ":attr:`soft_label` is set to false, Label is a Tensor<int64> in same "
        "shape with Input(Logits) except the shape in dimension :attr:`axis` "
        "as 1. If soft_label is set to true, Label is a Tensor<float/double> "
        "in same shape with Input(Logits).");
    // Softmax is flagged as an intermediate output; its description states
    // that it is consumed by the backward calculation.
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A tensor in same shape with "
        "Input(Logits). "
        "The outputs value of softmax activation by given the input batch, "
        "which will be used in backward calculation.")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A tensor in same shape with "
              "Input(Logits) "
              "except the shape in dimension :attr:`axis` as 1. The cross "
              "entropy loss.");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: true), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used.")
        .SetDefault(true);
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and "
        "does not contribute to the input gradient. Only valid if soft_label "
        "is set to False")
        .SetDefault(-100);
    AddAttr<int>("axis",
                 "The dimension index of Input(Logits) to perform softmax, "
                 "default -1 for last dimension")
        .SetDefault(-1);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.

When the attribute soft_label is set false, this operators expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.

The equation is as follows:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1,..., K$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K$$

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

110
  void InferShape(framework::InferShapeContext* ctx) const override {
111 112 113 114 115 116 117 118 119 120 121 122 123
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Logits"), true,
        platform::errors::InvalidArgument("Input(Logits) should be not null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"), true,
        platform::errors::InvalidArgument("Input(Label) should be not null."));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"), true,
                      platform::errors::InvalidArgument(
                          "Output(Softmax) should be not null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Loss"), true,
        platform::errors::InvalidArgument("Output(Loss) should be not null."));
Q
qiaolongfei 已提交
124

125
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
126 127
    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
128
    auto logits_rank = logits_dims.size();
129 130 131 132 133 134 135 136
    PADDLE_ENFORCE_GE(axis, -logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis, logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
137 138 139 140 141

    axis = CanonicalAxis(axis, logits_rank);
    for (int i = 0; i < logits_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
142 143 144 145
          PADDLE_ENFORCE_EQ(logits_dims[i], labels_dims[i],
                            platform::errors::InvalidArgument(
                                "Input(Logits) and Input(Label) should in "
                                "same shape in dimensions except axis."));
146 147 148
        }
      }
    }
149

150 151
    auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
    if (axis != logits_rank - 1) {
152 153 154 155
      PADDLE_ENFORCE_EQ(numeric_stable_mode, true,
                        platform::errors::InvalidArgument(
                            "Attr(axis) can only be -1 "
                            "when not in numeric_stable_mode."));
156
    }
157

158 159 160 161 162
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis],
163 164 165 166
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true,  "
                              "the axis dimension of "
                              "Input(X) and Input(Label) should be equal."));
167
      }
168
    } else {
169
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
170 171 172 173 174
        PADDLE_ENFORCE_EQ(
            labels_dims[axis], 1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
175
      }
176 177
    }

Q
qiaolongfei 已提交
178
    ctx->SetOutputDim("Softmax", logits_dims);
179 180 181

    logits_dims[axis] = 1;
    ctx->SetOutputDim("Loss", logits_dims);
182

Q
qiaolongfei 已提交
183 184
    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
C
caoying03 已提交
185
  }
Y
Yu Yang 已提交
186

187
 protected:
188
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
189
      const framework::ExecutionContext& ctx) const override {
190 191 192
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
        ctx.device_context());
Y
Yu Yang 已提交
193
  }
C
caoying03 已提交
194 195 196 197 198 199
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

200
  void InferShape(framework::InferShapeContext* ctx) const override {
201 202 203 204 205 206 207 208 209 210 211 212
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")), true,
                      platform::errors::InvalidArgument(
                          "Input(Loss@Grad) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true,
                      platform::errors::InvalidArgument(
                          "Input(Softmax) should be not null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"), true,
        platform::errors::InvalidArgument("Input(Label) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")), true,
                      platform::errors::InvalidArgument(
                          "Output(Logits@Grad) should be not null."));
Q
qiaolongfei 已提交
213

214
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
215 216
    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
217
    auto softmax_rank = softmax_dims.size();
218 219 220 221 222 223 224 225
    PADDLE_ENFORCE_GE(axis, -softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis, softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
226 227 228 229 230 231 232

    axis = CanonicalAxis(axis, softmax_rank);
    for (int i = 0; i < softmax_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              softmax_dims[i], labels_dims[i],
233 234 235
              platform::errors::InvalidArgument(
                  "Input(Logits) and Input(Label) should in same shape in "
                  "dimensions except axis."));
236 237
        }
      }
238
    }
239

240 241 242 243 244
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis],
245 246 247 248
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true, "
                              "the axis dimension of "
                              "Input(X) and Input(Label) should be equal."));
249
      }
250
    } else {
251
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
252 253 254 255 256
        PADDLE_ENFORCE_EQ(
            labels_dims[axis], 1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
257
      }
258
    }
C
caoying03 已提交
259

Q
qiaolongfei 已提交
260 261
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
262
  }
Y
Yu Yang 已提交
263

264
 protected:
265
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
266
      const framework::ExecutionContext& ctx) const override {
267 268 269
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Loss")),
                                   ctx.device_context());
Y
Yu Yang 已提交
270
  }
271 272
};

H
hong 已提交
273 274
// Describes how to build the "softmax_with_cross_entropy_grad" op from the
// forward op. T is the op representation being built; this file instantiates
// it with framework::OpDesc and imperative::OpBase.
template <typename T>
class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Fills `grad_op` with the type, inputs, output and attributes of the
  // backward op:
  //   inputs : Label (forward input), Softmax (forward output), Loss@GRAD
  //   output : Logits@GRAD
  //   attrs  : copied unchanged from the forward op
  void Apply(GradOpPtr<T> grad_op) const override {
    const auto loss_grad_name = framework::GradVarName("Loss");
    const auto logits_grad_name = framework::GradVarName("Logits");

    grad_op->SetType("softmax_with_cross_entropy_grad");

    grad_op->SetInput("Label", this->Input("Label"));
    grad_op->SetInput("Softmax", this->Output("Softmax"));
    grad_op->SetInput(loss_grad_name, this->OutputGrad("Loss"));

    grad_op->SetOutput(logits_grad_name, this->InputGrad("Logits"));

    grad_op->SetAttrMap(this->Attrs());
  }
};

290
// In-place inferer for the forward op, pairing Input(Logits) with
// Output(Softmax) (presumably letting the output reuse the input's buffer —
// confirm against the DECLARE_INPLACE_OP_INFERER definition).
DECLARE_INPLACE_OP_INFERER(
    SoftmaxWithCrossEntropyInplaceInferer, {"Logits", "Softmax"});

// In-place inferer for the backward op, pairing Input(Softmax) with
// Output(Logits@GRAD) in the same way.
DECLARE_INPLACE_OP_INFERER(
    SoftmaxWithCrossEntropyGradInplaceInferer,
    {"Softmax", framework::GradVarName("Logits")});
Z
Zeng Jinle 已提交
295

296 297 298 299 300
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

301
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
H
hong 已提交
302 303 304
                  ops::SoftmaxWithCrossEntropyOpMaker,
                  ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
                  ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
305
                  ops::SoftmaxWithCrossEntropyInplaceInferer);
306
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
Z
Zeng Jinle 已提交
307
                  ops::SoftmaxWithCrossEntropyOpGrad,
308
                  ops::SoftmaxWithCrossEntropyGradInplaceInferer);
309
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
C
caoying03 已提交
310 311
                       ops::SoftmaxWithCrossEntropyKernel<float>,
                       ops::SoftmaxWithCrossEntropyKernel<double>);
312
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
C
caoying03 已提交
313 314
                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
                       ops::SoftmaxWithCrossEntropyGradKernel<double>);