/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/metrics/precision_recall_op.h"
Y
yangyaming 已提交
16

Y
yangyaming 已提交
17 18 19 20 21 22 23 24
namespace paddle {
namespace operators {

class PrecisionRecallOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
    PADDLE_ENFORCE_EQ(ctx->HasInput("MaxProbs"), true,
                      platform::errors::InvalidArgument(
                          "Input(MaxProbs) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Indices"), true,
                      platform::errors::InvalidArgument(
                          "Input(Indices) should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Labels"), true,
        platform::errors::InvalidArgument("Input(Labels) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("BatchMetrics"), true,
                      platform::errors::InvalidArgument(
                          "Output(BatchMetrics) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("AccumMetrics"), true,
                      platform::errors::InvalidArgument(
                          "Output(AccumMetrics) should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("AccumStatesInfo"), true,
                      platform::errors::InvalidArgument(
                          "Output(AccumStatesInfo) should not be null."));
Y
yangyaming 已提交
43

Y
yangyaming 已提交
44 45 46
    int64_t cls_num =
        static_cast<int64_t>(ctx->Attrs().Get<int>("class_number"));
    auto max_probs_dims = ctx->GetInputDim("MaxProbs");
Y
yangyaming 已提交
47 48
    auto labels_dims = ctx->GetInputDim("Labels");

P
phlrain 已提交
49 50
    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_EQ(
51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
          max_probs_dims[1], 1,
          platform::errors::InvalidArgument(
              "Each instance contains one max probability, so the shape of "
              "Input(MaxProbs) should be [batch_size, 1]. But received (%d)",
              max_probs_dims[1]));
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Indices"), max_probs_dims,
                        platform::errors::InvalidArgument(
                            "The shape of Input(Indices) should bes same with "
                            "max_probs_dims, But received (%d) != (%d)",
                            ctx->GetInputDim("Indices"), max_probs_dims));
      PADDLE_ENFORCE_EQ(max_probs_dims[0], labels_dims[0],
                        platform::errors::InvalidArgument(
                            "The 1st dimension of Input(MaxProbs) and "
                            "Input(Labels) both are batch_size and the shape "
                            "should be the same. But received (%d) != (%d)",
                            max_probs_dims[0], labels_dims[0]));
P
phlrain 已提交
67
      PADDLE_ENFORCE_EQ(
68 69 70 71 72
          labels_dims[1], 1,
          platform::errors::InvalidArgument(
              "The 2nd dimension of Input(Labels) contains instance label and "
              "the shape should be equal to 1. But received (%d)",
              labels_dims[1]));
P
phlrain 已提交
73
    }
Y
yangyaming 已提交
74 75
    if (ctx->HasInput("Weights")) {
      auto weights_dims = ctx->GetInputDim("Weights");
P
phlrain 已提交
76 77

      if (ctx->IsRuntime()) {
78 79 80 81
        PADDLE_ENFORCE_EQ(
            weights_dims, framework::make_ddim({max_probs_dims[0], 1}),
            platform::errors::InvalidArgument(
                "The shape of Input(Weights) should be [batch_size, 1]."));
P
phlrain 已提交
82
      }
Y
yangyaming 已提交
83 84 85
    }
    if (ctx->HasInput("StatesInfo")) {
      auto states_dims = ctx->GetInputDim("StatesInfo");
P
phlrain 已提交
86 87

      if (ctx->IsRuntime()) {
88 89 90 91
        PADDLE_ENFORCE_EQ(
            states_dims, framework::make_ddim({cls_num, 4}),
            platform::errors::InvalidArgument(
                "The shape of Input(StatesInfo) should be [class_number, 4]."));
P
phlrain 已提交
92
      }
Y
yangyaming 已提交
93 94 95 96 97 98 99 100 101 102 103 104
    }

    // Layouts of BatchMetrics and AccumMetrics both are:
    // [
    //  macro average precision, macro average recall, macro average F1 score,
    //  micro average precision, micro average recall, micro average F1 score
    // ]
    ctx->SetOutputDim("BatchMetrics", {6});
    ctx->SetOutputDim("AccumMetrics", {6});
    // Shape of AccumStatesInfo is [class_number, 4]
    // The layout of each row is:
    // [ TP, FP, TN, FN ]
Y
yangyaming 已提交
105
    ctx->SetOutputDim("AccumStatesInfo", {cls_num, 4});
Y
yangyaming 已提交
106
  }
Y
yangyaming 已提交
107 108

 protected:
109
  framework::OpKernelType GetExpectedKernelType(
Y
yangyaming 已提交
110
      const framework::ExecutionContext &ctx) const override {
111 112 113
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "MaxProbs"),
        ctx.device_context());
Y
yangyaming 已提交
114
  }
Y
yangyaming 已提交
115 116 117 118
};

// Proto maker for the `precision_recall` operator: declares its inputs,
// outputs, the `class_number` attribute, and the user-facing documentation.
class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers three required inputs (MaxProbs, Indices, Labels), two
  // dispensable inputs (Weights, StatesInfo), three outputs (BatchMetrics,
  // AccumMetrics, AccumStatesInfo) and the integer attribute class_number.
  void Make() override {
    AddInput("MaxProbs",
             "(Tensor, default Tensor<float>) A 2-D tensor with shape N x 1, "
             "where N is the batch size. Each row contains the max probability "
             "of an instance which computed by the previous top_k (k=1) "
             "operator.");
    AddInput("Indices",
             "(Tensor, default Tensor<int>) A 2-D tensor with shape N x 1, "
             "where N is the batch size. Each row contains the corresponding "
             "index which computed by the previous top_k (k=1) operator.");
    AddInput("Labels",
             "(Tensor, default Tensor<int>) A 2-D tensor with shape N x 1, "
             "where N is the batch size. Each element is a label and the "
             "value should be in [0, class_number - 1].");
    AddInput("Weights",
             "(Tensor, default Tensor<float>) A 2-D tensor with shape N x 1, "
             "where N is the batch size. This input is optional. If provided, "
             "weight of instance would be considered when computing metrics.")
        .AsDispensable();
    // NOTE(review): this description says Tensor<int> while the matching
    // Output(AccumStatesInfo) below says Tensor<float>; confirm against the
    // kernel which dtype is actually expected.
    AddInput("StatesInfo",
             "(Tensor, default Tensor<int>) A 2-D tensor with shape D x 4, "
             "where D is the number of classes. This input is optional. If "
             "provided, current state will be accumulated to this state and "
             "the accumulation state will be the output state.")
        .AsDispensable();
    AddOutput("BatchMetrics",
              "(Tensor, default Tensor<float>) A 1-D tensor with shape {6}. "
              "This output tensor contains metrics for current batch data. "
              "The layout is [macro average precision, macro average recall, "
              "macro f1 score, micro average precision, micro average recall, "
              "micro f1 score].");
    AddOutput("AccumMetrics",
              "(Tensor, default Tensor<float>) A 1-D tensor with shape {6}. "
              "This output tensor contains metrics for accumulated data. "
              "The layout is [macro average precision, macro average recall, "
              "macro f1 score, micro average precision, micro average recall, "
              "micro f1 score].");
    AddOutput("AccumStatesInfo",
              "(Tensor, default Tensor<float>) A 2-D tensor with shape D x 4, "
              "where D is equal to class number. This output tensor contains "
              "accumulated state variables used to compute metrics. The layout "
              "for each class is [true positives, false positives, "
              "true negatives, false negatives].");
    AddAttr<int>("class_number", "(int) Number of classes to be evaluated.");
    // The final paragraph names Output(AccumMetrics) as the accumulated
    // metrics; Output(AccumStatesInfo) is the accumulated state, as described
    // in the paragraph before it.
    AddComment(R"DOC(
Precision Recall Operator.

When given Input(Indices) and Input(Labels), this operator can be used
to compute various metrics including:
1. macro average precision
2. macro average recall
3. macro f1 score
4. micro average precision
5. micro average recall
6. micro f1 score

To compute the above metrics, we need to do statistics for true positives,
false positives and false negatives. Here the count of true negatives is not
necessary, but counting it may provide potential usage and the cost is
trivial, so the operator also provides the count of true negatives.

We define state as a 2-D tensor with shape [class_number, 4]. Each row of a
state contains statistic variables for corresponding class. Layout of each row
is: TP(true positives), FP(false positives), TN(true negatives),
FN(false negatives). If Input(Weights) is provided, TP, FP, TN, FN will be
calculated by given weight instead of the instance count.

This operator also supports metrics computing for cross-batch situation. To
achieve this, Input(StatesInfo) should be provided. State of current batch
data will be accumulated to Input(StatesInfo) and Output(AccumStatesInfo)
is the accumulation state.

Output(BatchMetrics) is metrics of current batch data while
Output(AccumMetrics) is metrics of accumulation data.

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
202 203 204 205
REGISTER_OPERATOR(
    precision_recall, ops::PrecisionRecallOp, ops::PrecisionRecallOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
Y
yangyaming 已提交
206 207 208
REGISTER_OP_CPU_KERNEL(
    precision_recall,
    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, float>,
Y
yangyaming 已提交
209
    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, double>);