From 97bfc0dfae147f5514251b077eb26a4ed831b890 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 27 Oct 2017 11:05:57 +0800 Subject: [PATCH] Add comments. --- paddle/operators/precision_recall_op.cc | 50 +++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc index 47a16b946..24246907b 100644 --- a/paddle/operators/precision_recall_op.cc +++ b/paddle/operators/precision_recall_op.cc @@ -22,7 +22,6 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - // may contains weights and StatesInfo PADDLE_ENFORCE(ctx->HasInput("Predictions"), "Input(Predictions) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Labels"), @@ -108,11 +107,54 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { "provided, current state will be accumulated to this state and " "the accumulation state will be as the output state.") .AsDispensable(); - AddOutput("BatchMetrics", ""); - AddOutput("AccumMetrics", ""); - AddOutput("AccumStatesInfo", ""); + AddOutput("BatchMetrics", + "(Tensor, default Tensor), a 1-D tensor with shape {6}." + "This output tensor contains metrics for current batch data." + "The layout is [macro average precision, macro average recall, " + "macro f1 score, micro average precision, micro average recall, " + "micro f1 score]"); + AddOutput("AccumMetrics", + "(Tensor, default Tensor), a 1-D tensor with shape {6}." + "This output tensor contains metrics for accumulated data." + "The layout is [macro average precision, macro average recall, " + "macro f1 score, micro average precision, micro average recall, " + "micro f1 score]"); + AddOutput("AccumStatesInfo", + "(Tensor, default Tensor), a 2-D tensor with shape D x 4, " + "where D is equal to class number. This output tensor contains " + "accumulated state variables used to compute metrics. The layout " + "for each class is [true positives, false positives, " + "true negatives, false negatives]."); AddComment(R"DOC( +When given 'Input(Predictions)' and 'Input(Labels)', this operator can be used +to compute various metrics including: + - macro average precision + - macro average recall + - macro f1 score + - micro average precision + - micro average recall + - micro f1 score + +To compute the above metrics, we need to statistic counts for true positives, +false positives and false negatives. Here count of true negatives is not +necessary, but statisticing it may provide potential usage and the cost is +trivial, so the operator also provides count of true negatives. + +We define state as a 2-D tensor with shape [class number, 4]. Each row of a +state contains statistic variables for corresponding class. Layout of each row +is: TP(true positives), FP(false positives), TN(true negatives), +FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be +calculated by given weight instead of instance count. + +This operator also supports metrics computing for cross-batch situation. To +achieve this, 'Input(StatesInfo)' should be provided. State of current batch +data will be accumulated to 'Input(StatesInfo)' and 'Output(AccumStatesInfo)' +is the accumulation state. + +'Output(BatchMetrics)' is metrics of current batch data while +'Output(AccumStatesInfo)' is metrics of accumulation data. + )DOC"); } }; -- GitLab