softmax_with_cross_entropy_op.cc 11.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
S
sneaxiy 已提交
16
#include <memory>
Z
Zeng Jinle 已提交
17 18 19
#include <string>
#include <unordered_map>
#include <vector>
Y
Yu Yang 已提交
20

21 22 23 24 25 26
namespace paddle {
namespace operators {

class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the inputs, outputs and attributes of the
  // softmax_with_cross_entropy operator together with its user-facing
  // documentation. Note: adjacent string literals are concatenated by the
  // compiler, so every fragment that continues a sentence must end with a
  // space.
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The input tensor of unscaled "
             "log probabilities, whose dimension :attr:`axis` should be scaled "
             "by softmax.");
    AddInput(
        "Label",
        "(Tensor) The input tensor of ground truth label. If "
        ":attr:`soft_label` is set to false, Label is a Tensor<int64> in same "
        "shape with Input(Logits) except the shape in dimension :attr:`axis` "
        "as 1. If soft_label is set to true, Label is a Tensor<float/double> "
        "in same shape with Input(Logits).");
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A tensor in same shape with "
        "Input(Logits). "
        "The outputs value of softmax activation by given the input batch, "
        "which will be used in backward calculation.")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A tensor in same shape with "
              "Input(Logits) "
              "except the shape in dimension :attr:`axis` as 1. The cross "
              "entropy loss.");
    AddAttr<bool>(
        "soft_label",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    // InferShape rejects numeric_stable_mode == false when axis is not the
    // last dimension, so that constraint is documented here too.
    AddAttr<bool>(
        "numeric_stable_mode",
        "(bool, default: true), A flag to indicate whether to use more "
        "numerically stable algorithm. This flag is only valid when "
        "soft_label is false and GPU is used. It must be true when "
        ":attr:`axis` is not the last dimension.")
        .SetDefault(true);
    AddAttr<int>(
        "ignore_index",
        "(int, default -100), Specifies a target value that is ignored and "
        "does not contribute to the input gradient. Only valid if soft_label "
        "is set to False")
        .SetDefault(-100);
    AddAttr<int>("axis",
                 "The dimension index of Input(Logits) to perform softmax, "
                 "default -1 for last dimension")
        .SetDefault(-1);
    AddComment(R"DOC(
Softmax With Cross Entropy Operator.

Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values along dimension :attr:`axis`
(the last dimension by default) of the input tensor, after which the
cross-entropy loss is computed. This provides a more numerically stable
gradient.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.

When the attribute soft_label is set false, this operator expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.

The equation is as follows, where $K$ is the size of dimension :attr:`axis`
(the number of classes) and $N$ is the number of samples (the product of all
the other dimensions):

1) Hard label (one-hot label, so every sample has exactly one class)

$$Loss_j =  -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K-1}\exp(\text{Logit}_i)\right),
j = 1,..., N$$

2) Soft label (each sample can have a distribution over all classes)

$$Loss_j =  -\sum_{i=0}^{K-1}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{k=0}^{K-1}\exp(\text{Logit}_k)\right)\right),
j = 1,...,N$$

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

110
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
111 112 113 114 115 116 117 118
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");

    PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
                   "Output(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");

119
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
120 121
    auto logits_dims = ctx->GetInputDim("Logits");
    auto labels_dims = ctx->GetInputDim("Label");
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
    auto logits_rank = logits_dims.size();
    PADDLE_ENFORCE(axis >= -logits_rank && axis < logits_rank,
                   "Attr(axis) value should be in range [-R, R-1], "
                   "R is the rank of Input(Logits).");

    axis = CanonicalAxis(axis, logits_rank);
    for (int i = 0; i < logits_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              logits_dims[i], labels_dims[i],
              "Input(Logits) and Input(Label) should in same shape in "
              "dimensions except axis.");
        }
      }
    }
138

139 140 141 142 143
    auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
    if (axis != logits_rank - 1) {
      PADDLE_ENFORCE(
          numeric_stable_mode,
          "Attr(axis) can only be -1 when not in numeric_stable_mode.");
144
    }
145

146 147 148 149 150 151
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis],
                          "If Attr(soft_label) == true, the axis dimension of "
152 153
                          "Input(X) and Input(Label) should be equal.");
      }
154
    } else {
155 156 157 158 159
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
                          "If Attr(soft_label) == false, the axis dimension of "
                          "Input(Label) should be 1.");
      }
160 161
    }

Q
qiaolongfei 已提交
162
    ctx->SetOutputDim("Softmax", logits_dims);
163 164 165

    logits_dims[axis] = 1;
    ctx->SetOutputDim("Loss", logits_dims);
166

Q
qiaolongfei 已提交
167 168
    ctx->ShareLoD("Logits", /*->*/ "Softmax");
    ctx->ShareLoD("Logits", /*->*/ "Loss");
C
caoying03 已提交
169
  }
Y
Yu Yang 已提交
170

171
 protected:
172
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
173
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
174 175
    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
                                   ctx.device_context());
Y
Yu Yang 已提交
176
  }
C
caoying03 已提交
177 178 179 180 181 182
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

183
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
qiaolongfei 已提交
184 185 186 187 188 189 190 191
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Softmax"),
                   "Input(Softmax) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@Grad) should be not null.");

192
    auto axis = ctx->Attrs().Get<int>("axis");
Q
qiaolongfei 已提交
193 194
    auto softmax_dims = ctx->GetInputDim("Softmax");
    auto labels_dims = ctx->GetInputDim("Label");
195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
    auto softmax_rank = softmax_dims.size();
    PADDLE_ENFORCE(axis >= -softmax_rank && axis < softmax_rank,
                   "Attr(axis) value should be in range [-R, R-1], "
                   "R is the rank of Input(Logits).");

    axis = CanonicalAxis(axis, softmax_rank);
    for (int i = 0; i < softmax_rank; i++) {
      if (i != axis) {
        if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
          PADDLE_ENFORCE_EQ(
              softmax_dims[i], labels_dims[i],
              "Input(Logits) and Input(Label) should in same shape in "
              "dimensions except axis.");
        }
      }
210
    }
211

212 213 214 215 216 217 218
    bool soft_label = ctx->Attrs().Get<bool>("soft_label");
    if (soft_label) {
      if (ctx->IsRuntime() ||
          (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
        PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis],
                          "If Attr(soft_label) == true, the axis dimension of "
                          "Input(X) and Input(Label) should be equal.");
219
      }
220
    } else {
221 222 223 224 225
      if (ctx->IsRuntime() || labels_dims[axis] > 0) {
        PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
                          "If Attr(soft_label) == false, the axis dimension of "
                          "Input(Label) should be 1.");
      }
226
    }
C
caoying03 已提交
227

Q
qiaolongfei 已提交
228 229
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Softmax"));
230
  }
Y
Yu Yang 已提交
231

232
 protected:
233
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
234
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
235
    return framework::OpKernelType(
Y
Yu Yang 已提交
236
        ctx.Input<Tensor>(framework::GradVarName("Loss"))->type(),
Y
Yu Yang 已提交
237
        ctx.device_context());
Y
Yu Yang 已提交
238
  }
239 240
};

241 242 243 244 245
// Builds the backward op description for softmax_with_cross_entropy. The
// grad op is fed Label, the forward Softmax output and the Loss gradient;
// Logits itself is not passed to it.
class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    // Own the description from the start so it is released if any of the
    // setters below throws (the former raw `new` leaked in that case).
    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    grad_op->SetType("softmax_with_cross_entropy_grad");
    grad_op->SetInput("Label", Input("Label"));
    grad_op->SetInput("Softmax", Output("Softmax"));
    grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
    // All forward attributes (axis, soft_label, ...) are copied unchanged.
    grad_op->SetAttrMap(Attrs());
    return grad_op;
  }
};

258 259 260 261 262
// In-place hint for the forward op: Output(Softmax) may reuse the buffer of
// Input(Logits). Neither the op description nor the device flag affects the
// result.
class SoftmaxWithCrossEntropyInplaceInference
    : public framework::InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const framework::OpDesc& op_desc, bool use_cuda) const {
    std::unordered_map<std::string, std::string> inplace_pairs;
    inplace_pairs.emplace("Logits", "Softmax");  // input -> output
    return inplace_pairs;
  }
};

Z
Zeng Jinle 已提交
267 268 269 270
// In-place hint for the backward op: Output(Logits@GRAD) may reuse the buffer
// of Input(Softmax). Neither the op description nor the device flag affects
// the result.
class SoftmaxWithCrossEntropyGradInplaceInference
    : public framework::InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const framework::OpDesc& op_desc, bool use_cuda) const {
    std::unordered_map<std::string, std::string> inplace_pairs;
    inplace_pairs.emplace("Softmax", framework::GradVarName("Logits"));
    return inplace_pairs;
  }
};

276 277 278 279 280
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

281
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
282 283
                  ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker,
                  ops::SoftmaxWithCrossEntropyInplaceInference);
284
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
Z
Zeng Jinle 已提交
285 286
                  ops::SoftmaxWithCrossEntropyOpGrad,
                  ops::SoftmaxWithCrossEntropyGradInplaceInference);
287
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
C
caoying03 已提交
288 289
                       ops::SoftmaxWithCrossEntropyKernel<float>,
                       ops::SoftmaxWithCrossEntropyKernel<double>);
290
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
C
caoying03 已提交
291 292
                       ops::SoftmaxWithCrossEntropyGradKernel<float>,
                       ops::SoftmaxWithCrossEntropyGradKernel<double>);