/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {

// Declares the inputs, outputs, attributes and user-facing documentation of
// the softmax_with_cross_entropy operator.
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The input tensor of unscaled "
             "log probabilities, whose dimension :attr:`axis` should be scaled "
             "by softmax.");
    AddInput("Label",
             "(Tensor) The input tensor of ground truth label. If "
             ":attr:`soft_label` is set to false, Label is a Tensor<int64> in "
             "the same shape as Input(Logits) except the shape in dimension "
             ":attr:`axis` as 1. If soft_label is set to true, Label is a "
             "Tensor<float/double> in the same shape as Input(Logits).");
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A tensor in same shape with "
        "Input(Logits). "
        "The outputs value of softmax activation by given the input batch, "
        "which will be used in backward calculation.")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A tensor in same shape with "
              "Input(Logits) "
              "except the shape in dimension :attr:`axis` as 1. The cross "
              "entropy loss.");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddAttr<bool>(
        "softmax_switch",
        "(bool, default: true), A flag to indicate whether to do softmax ")
        .SetDefault(true);
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: true), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used.")
        .SetDefault(true);
    // Adjacent literals are concatenated verbatim, so each fragment must
    // carry its own separating space.
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and "
        "does not contribute to the input gradient. Only valid if soft_label "
        "is set to False")
        .SetDefault(-100);
    AddAttr<int>("axis",
                 "The dimension index of Input(Logits) to perform softmax, "
                 "default -1 for last dimension")
        .SetDefault(-1);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values along dimension :attr:`axis`
of the input tensor, after which cross-entropy loss is computed. This provides
a more numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.

When the attribute soft_label is set false, this operator expects mutually
exclusive hard labels: each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.

The equation is as follows:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1,..., K$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K$$

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

115
  void InferShape(framework::InferShapeContext* ctx) const override {
116 117 118 119 120 121 122 123 124 125 126 127 128
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Logits"), true,
        platform::errors::InvalidArgument("Input(Logits) should be not null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"), true,
        platform::errors::InvalidArgument("Input(Label) should be not null."));

    PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"), true,
                      platform::errors::InvalidArgument(
                          "Output(Softmax) should be not null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Loss"), true,
        platform::errors::InvalidArgument("Output(Loss) should be not null."));
Q
qiaolongfei 已提交
129

130
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
131 132
    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
133
    auto logits_rank = logits_dims.size();
134 135 136 137 138 139 140 141
    PADDLE_ENFORCE_GE(axis, -logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis, logits_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
142 143 144 145 146

    axis = CanonicalAxis(axis, logits_rank);
    for (int i = 0; i < logits_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
147 148 149 150
          PADDLE_ENFORCE_EQ(logits_dims[i], labels_dims[i],
                            platform::errors::InvalidArgument(
                                "Input(Logits) and Input(Label) should in "
                                "same shape in dimensions except axis."));
151 152 153
        }
      }
    }
154

155 156
    auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
    if (axis != logits_rank - 1) {
157 158 159 160
      PADDLE_ENFORCE_EQ(numeric_stable_mode, true,
                        platform::errors::InvalidArgument(
                            "Attr(axis) can only be -1 "
                            "when not in numeric_stable_mode."));
161
    }
162

163 164 165 166 167
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis],
168 169 170 171
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true,  "
                              "the axis dimension of "
                              "Input(X) and Input(Label) should be equal."));
172
      }
173
    } else {
174
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
175 176 177 178 179
        PADDLE_ENFORCE_EQ(
            labels_dims[axis], 1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
180
      }
181 182
    }

Q
qiaolongfei 已提交
183
    ctx->SetOutputDim("Softmax", logits_dims);
184 185 186

    logits_dims[axis] = 1;
    ctx->SetOutputDim("Loss", logits_dims);
187

Q
qiaolongfei 已提交
188 189
    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
C
caoying03 已提交
190
  }
Y
Yu Yang 已提交
191

192
 protected:
193
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
194
      const framework::ExecutionContext& ctx) const override {
195 196 197
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
        ctx.device_context());
Y
Yu Yang 已提交
198
  }
C
caoying03 已提交
199 200 201 202 203 204
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

205
  void InferShape(framework::InferShapeContext* ctx) const override {
206 207 208 209 210 211 212 213 214 215 216 217
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")), true,
                      platform::errors::InvalidArgument(
                          "Input(Loss@Grad) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true,
                      platform::errors::InvalidArgument(
                          "Input(Softmax) should be not null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Label"), true,
        platform::errors::InvalidArgument("Input(Label) should be not null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")), true,
                      platform::errors::InvalidArgument(
                          "Output(Logits@Grad) should be not null."));
Q
qiaolongfei 已提交
218

219
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
220 221
    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
222
    auto softmax_rank = softmax_dims.size();
223 224 225 226 227 228 229 230
    PADDLE_ENFORCE_GE(axis, -softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
    PADDLE_ENFORCE_LT(axis, softmax_rank,
                      platform::errors::InvalidArgument(
                          "Attr(axis) value should be in range [-R, R-1], "
                          "R is the rank of Input(Logits)."));
231 232 233 234 235 236 237

    axis = CanonicalAxis(axis, softmax_rank);
    for (int i = 0; i < softmax_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              softmax_dims[i], labels_dims[i],
238 239 240
              platform::errors::InvalidArgument(
                  "Input(Logits) and Input(Label) should in same shape in "
                  "dimensions except axis."));
241 242
        }
      }
243
    }
244

245 246 247 248 249
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis],
250 251 252 253
                          platform::errors::InvalidArgument(
                              "If Attr(soft_label) == true, "
                              "the axis dimension of "
                              "Input(X) and Input(Label) should be equal."));
254
      }
255
    } else {
256
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
257 258 259 260 261
        PADDLE_ENFORCE_EQ(
            labels_dims[axis], 1UL,
            platform::errors::InvalidArgument("If Attr(soft_label) == false, "
                                              "the axis dimension of "
                                              "Input(Label) should be 1."));
262
      }
263
    }
C
caoying03 已提交
264

Q
qiaolongfei 已提交
265 266
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
267
  }
Y
Yu Yang 已提交
268

269
 protected:
270
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
271
      const framework::ExecutionContext& ctx) const override {
272 273 274
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Loss")),
                                   ctx.device_context());
Y
Yu Yang 已提交
275
  }
276 277
};

H
hong 已提交
278 279
template <typename T>
class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
280
 public:
H
hong 已提交
281
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
282 283

 protected:
284
  void Apply(GradOpPtr<T> grad_op) const override {
Y
Yu Yang 已提交
285
    grad_op->SetType("softmax_with_cross_entropy_grad");
H
hong 已提交
286 287 288 289 290 291
    grad_op->SetInput("Label", this->Input("Label"));
    grad_op->SetInput("Softmax", this->Output("Softmax"));
    grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"),
                       this->InputGrad("Logits"));
    grad_op->SetAttrMap(this->Attrs());
292 293 294
  }
};

295
// Forward op: declares the (Logits -> Softmax) pair as an in-place
// candidate, so Softmax may share the buffer of Logits.
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInferer,
                           {"Logits", "Softmax"});

// Backward op: declares the (Softmax -> Logits@GRAD) pair as an in-place
// candidate, so the logits gradient may share the buffer of Softmax.
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInferer,
                           {"Softmax", framework::GradVarName("Logits")});

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

306
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
H
hong 已提交
307 308 309
                  ops::SoftmaxWithCrossEntropyOpMaker,
                  ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
                  ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
310
                  ops::SoftmaxWithCrossEntropyInplaceInferer);
311
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
Z
Zeng Jinle 已提交
312
                  ops::SoftmaxWithCrossEntropyOpGrad,
313
                  ops::SoftmaxWithCrossEntropyGradInplaceInferer);
314
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
C
caoying03 已提交
315 316
                       ops::SoftmaxWithCrossEntropyKernel<float>,
                       ops::SoftmaxWithCrossEntropyKernel<double>);
317
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
C
caoying03 已提交
318 319
                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
                       ops::SoftmaxWithCrossEntropyGradKernel<double>);
320 321 322 323 324 325 326
REGISTER_OP_VERSION(softmax_with_cross_entropy)
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [softmax_switch] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "softmax_switch", "A flag to indicate whether to do softmax",
            true));