提交 97bfc0df 编写于 作者: Y yangyaming

Add comments.

上级 65dbbd57
...@@ -22,7 +22,6 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { ...@@ -22,7 +22,6 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
// may contains weights and StatesInfo
PADDLE_ENFORCE(ctx->HasInput("Predictions"), PADDLE_ENFORCE(ctx->HasInput("Predictions"),
"Input(Predictions) should not be null."); "Input(Predictions) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Labels"), PADDLE_ENFORCE(ctx->HasInput("Labels"),
...@@ -108,11 +107,54 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -108,11 +107,54 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
"provided, current state will be accumulated to this state and " "provided, current state will be accumulated to this state and "
"the accumulation state will be as the output state.") "the accumulation state will be as the output state.")
.AsDispensable(); .AsDispensable();
AddOutput("BatchMetrics", ""); AddOutput("BatchMetrics",
AddOutput("AccumMetrics", ""); "(Tensor, default Tensor<float>), a 1-D tensor with shape {6}."
AddOutput("AccumStatesInfo", ""); "This output tensor contains metrics for current batch data."
"The layout is [macro average precision, macro average recall, "
"macro f1 score, micro average precision, micro average recall, "
"micro f1 score]");
AddOutput("AccumMetrics",
"(Tensor, default Tensor<float>), a 1-D tensor with shape {6}."
"This output tensor contains metrics for accumulated data."
"The layout is [macro average precision, macro average recall, "
"macro f1 score, micro average precision, micro average recall, "
"micro f1 score]");
AddOutput("AccumStatesInfo",
"(Tensor, default Tensor<float>), a 2-D tensor with shape D x 4, "
"where D is equal to class number. This output tensor contains "
"accumulated state variables used to compute metrics. The layout "
"for each class is [true positives, false positives, "
"true negatives, false negatives].");
AddComment(R"DOC( AddComment(R"DOC(
When given 'Input(Predictions)' and 'Input(Labels)', this operator can be used
to compute various metrics including:
- macro average precision
- macro average recall
- macro f1 score
- micro average precision
- micro average recall
- micro f1 score
To compute the above metrics, we need to statistic counts for true positives,
false positives and false negatives. Here count of true negatives is not
necessary, but statisticing it may provide potential usage and the cost is
trivial, so the operator also provides count of true negatives.
We define state as a 2-D tensor with shape [class number, 4]. Each row of a
state contains statistic variables for corresponding class. Layout of each row
is: TP(true positives), FP(false positives), TN(true negatives),
FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be
calculated by given weight instead of instance count.
This operator also supports metrics computing for cross-batch situation. To
achieve this, 'Input(StatesInfo)' should be provided. State of current batch
data will be accumulated to 'Input(StatesInfo)' and 'Output(AccumStatesInfo)'
is the accumulation state.
'Output(BatchMetrics)' is metrics of current batch data while
'Output(AccumStatesInfo)' is metrics of accumulation data.
)DOC"); )DOC");
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册