// cross_entropy_op.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/cross_entropy_op.h"

namespace paddle {
namespace operators {

20
class CrossEntropyOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
21 22 23
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

24
 protected:
D
dongzhihong 已提交
25
  void InferShape(const framework::InferShapeContext &ctx) const override {
C
caoying03 已提交
26
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
27
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
C
caoying03 已提交
28 29 30
                            "Input(Label) should be not null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
                            "Output(Y) should be not null.");
31 32 33

    auto x = ctx.Input<Tensor>("X");
    auto label = ctx.Input<Tensor>("Label");
C
caoying03 已提交
34
    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
35
    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
C
caoying03 已提交
36
                      "Input(Label)'s rank should be 2.");
37
    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
C
caoying03 已提交
38
                      "The 1st dimension of Input(X) and Input(Label) should "
39
                      "be equal.");
C
caoying03 已提交
40
    if (ctx.Attr<bool>("softLabel")) {
41
      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
C
caoying03 已提交
42
                        "If Attr(softLabel) == true, the 2nd dimension of "
C
caoying03 已提交
43
                        "Input(X) and Input(Label) should be equal.");
44
    } else {
45
      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
C
caoying03 已提交
46
                        "If Attr(softLabel) == false, the 2nd dimension of "
C
caoying03 已提交
47
                        "Input(Label) should be 1.");
48
    }
49

D
dangqingqing 已提交
50
    ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
51
    ctx.ShareLoD("X", /*->*/ "Y");
Q
Qiao Longfei 已提交
52 53 54
  }
};

class CrossEntropyGradientOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
56 57 58
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

Y
Yan Chunwei 已提交
59
 protected:
D
dongzhihong 已提交
60
  void InferShape(const framework::InferShapeContext &ctx) const override {
C
caoying03 已提交
61
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
62
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
C
caoying03 已提交
63
                            "Input(Label) should be not null.");
64
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
C
caoying03 已提交
65
                            "Input(Y@GRAD) shoudl be not null.");
C
caoying03 已提交
66 67
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")),
                            "Output(X@GRAD) should be not null.");
68

69
    auto x = ctx.Input<Tensor>("X");
70 71
    auto label = ctx.Input<Tensor>("Label");
    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
C
caoying03 已提交
72 73 74
    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
    PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
                      "Input(Y@Grad)'s rank should be 2.");
75
    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
C
caoying03 已提交
76
                      "Input(Label)'s rank should be 2.");
77
    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
C
caoying03 已提交
78
                      "The 1st dimension of Input(X) and Input(Label) should "
79 80
                      "be equal.");
    PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
C
caoying03 已提交
81
                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
82 83
                      "be equal.");
    PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
C
caoying03 已提交
84
                      "The 2nd dimension of Input(Y@Grad) should be 1.");
C
caoying03 已提交
85
    if (ctx.Attr<bool>("softLabel")) {
86
      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
C
caoying03 已提交
87
                        "When Attr(softLabel) == true, the 2nd dimension of "
C
caoying03 已提交
88
                        "Input(X) and Input(Label) should be equal.");
89 90
    } else {
      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
C
caoying03 已提交
91
                        "When Attr(softLabel) == false, the 2nd dimension of "
C
caoying03 已提交
92
                        "Input(Label) should be 1.");
93
    }
Y
Yan Chunwei 已提交
94

D
dangqingqing 已提交
95
    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
96
    dx->Resize(x->dims());
Y
Yan Chunwei 已提交
97 98 99
  }
};

class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
101
 public:
102 103
  CrossEntropyOpMaker(framework::OpProto *proto,
                      framework::OpAttrChecker *op_checker)
104
      : OpProtoAndCheckerMaker(proto, op_checker) {
C
caoying03 已提交
105 106 107 108 109
    AddInput("X",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
             "where N is the batch size and D is the number of classes. "
             "This input is a probability computed by the previous operator, "
             "which is almost always the result of a softmax operator.");
C
caoying03 已提交
110 111 112 113 114 115 116 117
    AddInput(
        "Label",
        "(Tensor, default Tensor<int>), the ground truth which is "
        "a 2-D tensor. "
        "When softLabel is set to false, `Label` is a Tensor<int> with shape "
        "[N x 1]. "
        "When softLabel is set to true, `Label` is a Tensor<float/double> "
        "with shape [N x K].");
C
caoying03 已提交
118
    AddOutput("Y",
C
caoying03 已提交
119
              "(Tensor, default Tensor<float>), a 2-D tensor "
C
caoying03 已提交
120 121
              "with shape [N x 1]. The cross entropy loss.");
    AddAttr<bool>(
C
caoying03 已提交
122
        "softLabel",
C
caoying03 已提交
123 124
        "(bool, default false), a flag to indicate whether to interpretate "
        "the given labels as soft labels.")
125
        .SetDefault(false);
Q
Qiao Longfei 已提交
126
    AddComment(R"DOC(
127
CrossEntropy Operator.
Q
Qiao Longfei 已提交
128

129 130 131
It supports both standard cross-entropy and soft-label cross-entropy loss
computation.
1) One-hot cross-entropy:
C
caoying03 已提交
132
    softLabel = false, Label[i, 0] indicates the class index for sample i:
133

134
                Y[i] = -log(X[i, Label[i]])
Q
Qiao Longfei 已提交
135

136
2) Soft-label cross-entropy:
C
caoying03 已提交
137
    softLabel = true, Label[i, j] indicates the soft label of class j
138
    for sample i:
139

140
                Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
141

142
   Please make sure that in this case the summuation of each row of Label
143 144 145 146 147 148
   equals one.

3) One-hot cross-entropy with vecterized Input(Label):
     As a special case of 2), when each row of Input(Label) has only one
     non-zero element (equals 1), soft-label cross-entropy degenerates to a
     one-hot cross-entropy with one-hot label representation.
D
dangqingqing 已提交
149 150 151

Both the input `X` and `Label` can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD with input `X`.
Q
Qiao Longfei 已提交
152 153 154 155 156 157
)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the forward op together with its gradient op (the framework
// links cross_entropy_grad as the gradient of cross_entropy), then bind
// the CPU kernels. Only float is instantiated here; the GPU kernels are
// registered in the corresponding .cu file.
REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
            cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
                       ops::CrossEntropyGradientOpKernel<float>);