cross_entropy_op.cc 6.6 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/cross_entropy_op.h"

namespace paddle {
namespace operators {

20
class CrossEntropyOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
21 22 23
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

24
 protected:
D
dongzhihong 已提交
25
  void InferShape(const framework::InferShapeContext &ctx) const override {
C
caoying03 已提交
26
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
27
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
C
caoying03 已提交
28 29 30
                            "Input(Label) should be not null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
                            "Output(Y) should be not null.");
31 32 33

    auto x = ctx.Input<Tensor>("X");
    auto label = ctx.Input<Tensor>("Label");
C
caoying03 已提交
34
    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
35
    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
C
caoying03 已提交
36
                      "Input(Label)'s rank should be 2.");
37
    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
C
caoying03 已提交
38
                      "The 1st dimension of Input(X) and Input(Label) should "
39
                      "be equal.");
40
    if (ctx.Attr<bool>("soft_label")) {
41
      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
C
caoying03 已提交
42 43
                        "If Attr(soft_label) == true, the 2nd dimension of "
                        "Input(X) and Input(Label) should be equal.");
44
    } else {
45
      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
C
caoying03 已提交
46 47
                        "If Attr(soft_label) == false, the 2nd dimension of "
                        "Input(Label) should be 1.");
48
    }
49

D
dangqingqing 已提交
50
    ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
51
    ctx.ShareLoD("X", /*->*/ "Y");
Q
Qiao Longfei 已提交
52 53 54
  }
};

55
class CrossEntropyGradientOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
56 57 58
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

Y
Yan Chunwei 已提交
59
 protected:
D
dongzhihong 已提交
60
  void InferShape(const framework::InferShapeContext &ctx) const override {
C
caoying03 已提交
61
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
62
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
C
caoying03 已提交
63
                            "Input(Label) should be not null.");
64
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
C
caoying03 已提交
65
                            "Input(Y@GRAD) shoudl be not null.");
66

67
    auto x = ctx.Input<Tensor>("X");
68 69
    auto label = ctx.Input<Tensor>("Label");
    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
C
caoying03 已提交
70 71 72
    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
    PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
                      "Input(Y@Grad)'s rank should be 2.");
73
    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
C
caoying03 已提交
74
                      "Input(Label)'s rank should be 2.");
75
    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
C
caoying03 已提交
76
                      "The 1st dimension of Input(X) and Input(Label) should "
77 78
                      "be equal.");
    PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
C
caoying03 已提交
79
                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
80 81
                      "be equal.");
    PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
C
caoying03 已提交
82
                      "The 2nd dimension of Input(Y@Grad) should be 1.");
83
    if (ctx.Attr<bool>("soft_label")) {
84
      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
C
caoying03 已提交
85 86
                        "When Attr(soft_label) == true, the 2nd dimension of "
                        "Input(X) and Input(Label) should be equal.");
87 88
    } else {
      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
C
caoying03 已提交
89 90
                        "When Attr(soft_label) == false, the 2nd dimension of "
                        "Input(Label) should be 1.");
91
    }
Y
Yan Chunwei 已提交
92

D
dangqingqing 已提交
93
    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
94
    dx->Resize(x->dims());
Y
Yan Chunwei 已提交
95 96 97
  }
};

98
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
99
 public:
100 101
  CrossEntropyOpMaker(framework::OpProto *proto,
                      framework::OpAttrChecker *op_checker)
102
      : OpProtoAndCheckerMaker(proto, op_checker) {
C
caoying03 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
    AddInput("X",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
             "where N is the batch size and D is the number of classes. "
             "This input is a probability computed by the previous operator, "
             "which is almost always the result of a softmax operator.");
    AddInput("Label",
             "(Tensor, default Tensor<int>), the ground truth which is "
             "a 1-D or 2-D tensor. "
             "When soft_label is set to 0, `Label` is a Tensor<int> with shape "
             "[N x 1]. "
             "When soft_label is set to 1, `Label` is a Tensor<float/double> "
             "with shape [N x K].");
    AddOutput("Y",
              "(Tensor, default Tensor<float>), a 1-D tensor "
              "with shape [N x 1]. The cross entropy loss.");
    AddAttr<bool>(
        "soft_label",
        "(bool, default false), a flag to indicate whether to interpretate "
        "the given labels as soft labels.")
122
        .SetDefault(false);
Q
Qiao Longfei 已提交
123
    AddComment(R"DOC(
124
CrossEntropy Operator.
Q
Qiao Longfei 已提交
125

126 127 128
It supports both standard cross-entropy and soft-label cross-entropy loss
computation.
1) One-hot cross-entropy:
129
    soft_label = False, Label[i, 0] indicates the class index for sample i:
130

131
                Y[i] = -log(X[i, Label[i]])
Q
Qiao Longfei 已提交
132

133
2) Soft-label cross-entropy:
134
    soft_label = True, Label[i, j] indicates the soft label of class j
135
    for sample i:
136

137
                Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
138

139
   Please make sure that in this case the summuation of each row of Label
140 141 142 143 144 145
   equals one.

3) One-hot cross-entropy with vecterized Input(Label):
     As a special case of 2), when each row of Input(Label) has only one
     non-zero element (equals 1), soft-label cross-entropy degenerates to a
     one-hot cross-entropy with one-hot label representation.
D
dangqingqing 已提交
146 147 148

Both the input `X` and `Label` can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD with input `X`.
Q
Qiao Longfei 已提交
149 150 151 152 153 154
)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

D
dongzhihong 已提交
155
namespace ops = paddle::operators;
// Tie the forward op to its gradient op, then register the CPU kernels
// (float only). Fix: removed scrape residue ("156 157 158 ...") that was
// interleaved into this registration block.
REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
            cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
                       ops::CrossEntropyGradientOpKernel<float>);