/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace operators {

// Declares the interface of the softmax_with_cross_entropy operator: its
// inputs, outputs, attributes and user-facing documentation.
//
// Adjacent string literals below are concatenated by the compiler, so every
// piece except the last ends with a space to keep words separated.
class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The input tensor of unscaled "
             "log probabilities, whose dimension :attr:`axis` should be scaled "
             "by softmax.");
    AddInput("Label",
             "(Tensor) The input tensor of ground truth label. If "
             ":attr:`soft_label` is set to false, Label is a Tensor<int64> in "
             "same shape with Input(Logits) except the shape in dimension "
             ":attr:`axis` as 1. If soft_label is set to true, Label is a "
             "Tensor<float/double> in same shape with Input(Logits).");
    // Softmax is marked intermediate: the grad op consumes it (see
    // SoftmaxGradMaker) rather than it being a primary user-facing result.
    AddOutput("Softmax",
              "(Tensor, default: Tensor<float>), A tensor in same shape with "
              "Input(Logits). "
              "The output values of softmax activation for the given input "
              "batch, which will be used in backward calculation.")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A tensor in same shape with "
              "Input(Logits) "
              "except the shape in dimension :attr:`axis` as 1. The cross "
              "entropy loss.");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: true), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used.")
        .SetDefault(true);
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and "
        "does not contribute to the input gradient. Only valid if soft_label "
        "is set to False")
        .SetDefault(-100);
    AddAttr<int>("axis",
                 "The dimension index of Input(Logits) to perform softmax, "
                 "default -1 for last dimension")
        .SetDefault(-1);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values along dimension :attr:`axis`
of the input tensor, after which cross-entropy loss is computed. This provides
a more numerically stable gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.

When the attribute soft_label is set false, this operator expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.

The equation is as follows, where $N$ is the number of samples, $K$ is the
number of classes, and $\text{Logit}_i$ / $\text{Label}_i$ denote the $i$-th
entry of sample $j$:

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K-1}\exp(\text{Logit}_i)\right),
j = 1,..., N$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K-1}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{k=0}^{K-1}\exp(\text{Logit}_k)\right)\right),
j = 1,...,N$$

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

110
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
111 112 113 114 115 116 117 118
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");

    PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
                   "Output(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");

119
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
120 121
    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
    auto logits_rank = logits_dims.size();
    PADDLE_ENFORCE(axis >= -logits_rank && axis < logits_rank,
                   "Attr(axis) value should be in range [-R, R-1], "
                   "R is the rank of Input(Logits).");

    axis = CanonicalAxis(axis, logits_rank);
    for (int i = 0; i < logits_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              logits_dims[i], labels_dims[i],
              "Input(Logits) and Input(Label) should in same shape in "
              "dimensions except axis.");
        }
      }
    }
138

139 140 141 142 143
    auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
    if (axis != logits_rank - 1) {
      PADDLE_ENFORCE(
          numeric_stable_mode,
          "Attr(axis) can only be -1 when not in numeric_stable_mode.");
144
    }
145

146 147 148 149 150 151
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis],
                          "If Attr(soft_label) == true, the axis dimension of "
152 153
                          "Input(X) and Input(Label) should be equal.");
      }
154
    } else {
155 156 157 158 159
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
                          "If Attr(soft_label) == false, the axis dimension of "
                          "Input(Label) should be 1.");
      }
160 161
    }

Q
qiaolongfei 已提交
162
    ctx->SetOutputDim("Softmax", logits_dims);
163 164 165

    logits_dims[axis] = 1;
    ctx->SetOutputDim("Loss", logits_dims);
166

Q
qiaolongfei 已提交
167 168
    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
C
caoying03 已提交
169
  }
Y
Yu Yang 已提交
170

171
 protected:
172
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
173
      const framework::ExecutionContext& ctx) const override {
174 175 176
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
        ctx.device_context());
Y
Yu Yang 已提交
177
  }
C
caoying03 已提交
178 179 180 181 182 183
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

184
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
185 186 187 188 189 190 191 192
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Softmax"),
                   "Input(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@Grad) should be not null.");

193
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
194 195
    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
    auto softmax_rank = softmax_dims.size();
    PADDLE_ENFORCE(axis >= -softmax_rank && axis < softmax_rank,
                   "Attr(axis) value should be in range [-R, R-1], "
                   "R is the rank of Input(Logits).");

    axis = CanonicalAxis(axis, softmax_rank);
    for (int i = 0; i < softmax_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              softmax_dims[i], labels_dims[i],
              "Input(Logits) and Input(Label) should in same shape in "
              "dimensions except axis.");
        }
      }
211
    }
212

213 214 215 216 217 218 219
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis],
                          "If Attr(soft_label) == true, the axis dimension of "
                          "Input(X) and Input(Label) should be equal.");
220
      }
221
    } else {
222 223 224 225 226
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
                          "If Attr(soft_label) == false, the axis dimension of "
                          "Input(Label) should be 1.");
      }
227
    }
C
caoying03 已提交
228

Q
qiaolongfei 已提交
229 230
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
231
  }
Y
Yu Yang 已提交
232

233
 protected:
234
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
235
      const framework::ExecutionContext& ctx) const override {
236 237 238
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Loss")),
                                   ctx.device_context());
Y
Yu Yang 已提交
239
  }
240 241
};

// Builds the softmax_with_cross_entropy_grad op for a forward
// softmax_with_cross_entropy op. T is the op representation being filled in;
// this file instantiates it for framework::OpDesc and imperative::OpBase.
template <typename T>
class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  std::unique_ptr<T> Apply() const override {
    // Own the op from construction so it is released if any setter throws.
    std::unique_ptr<T> grad_op(new T());
    grad_op->SetType("softmax_with_cross_entropy_grad");
    // The backward pass reads the labels, the forward Softmax output and the
    // incoming Loss gradient, and writes the Logits gradient.
    grad_op->SetInput("Label", this->Input("Label"));
    grad_op->SetInput("Softmax", this->Output("Softmax"));
    grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"),
                       this->InputGrad("Logits"));
    // The grad op reuses the forward attributes (axis, soft_label, ...).
    grad_op->SetAttrMap(this->Attrs());
    return grad_op;
  }
};

// In-place hint for the forward op: Softmax may share storage with Logits.
// NOTE(review): pairs are presumably {input, output}; confirm against the
// DECLARE_INPLACE_OP_INFERER definition.
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInference,
                           {"Logits", "Softmax"});

// In-place hint for the grad op: the Logits gradient may share storage with
// Softmax.
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInference,
                           {"Softmax", framework::GradVarName("Logits")});
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward op: shape inference, proto/doc maker, grad makers for both
// framework::OpDesc and imperative::OpBase, and its in-place hint.
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
                  ops::SoftmaxWithCrossEntropyOpMaker,
                  ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
                  ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
                  ops::SoftmaxWithCrossEntropyInplaceInference);
// Backward op and its in-place hint.
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                  ops::SoftmaxWithCrossEntropyOpGrad,
                  ops::SoftmaxWithCrossEntropyGradInplaceInference);
// CPU kernels for float and double.
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
                       ops::SoftmaxWithCrossEntropyKernel<float>,
                       ops::SoftmaxWithCrossEntropyKernel<double>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
                       ops::SoftmaxWithCrossEntropyGradKernel<double>);