From d2b10cc0b1b6a3267698f0d63d721ca99dc6ecf6 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 27 Oct 2017 15:18:28 +0800 Subject: [PATCH] Refine doc and fix data type of metrics. --- paddle/operators/precision_recall_op.cc | 4 ++-- paddle/operators/precision_recall_op.h | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc index 24246907b1c..a3f4c07493d 100644 --- a/paddle/operators/precision_recall_op.cc +++ b/paddle/operators/precision_recall_op.cc @@ -136,9 +136,9 @@ to compute various metrics including: - micro average recall - micro f1 score -To compute the above metrics, we need to statistic counts for true positives, +To compute the above metrics, we need to do statistics for true positives, false positives and false negatives. Here count of true negatives is not -necessary, but statisticing it may provide potential usage and the cost is +necessary, but counting it may provide potential usage and the cost is trivial, so the operator also provides count of true negatives. We define state as a 2-D tensor with shape [class number, 4]. Each row of a diff --git a/paddle/operators/precision_recall_op.h b/paddle/operators/precision_recall_op.h index 3bc638ea448..2e49bc3bb54 100644 --- a/paddle/operators/precision_recall_op.h +++ b/paddle/operators/precision_recall_op.h @@ -42,8 +42,8 @@ class PrecisionRecallKernel : public framework::OpKernel { const int* labels_data = in1->data(); const T* weights_data = in2 ? in2->data() : nullptr; const T* states_data = in3 ? in3->data() : nullptr; - T* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); - T* accum_metrics_data = out1->mutable_data(ctx.GetPlace()); + double* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); + double* accum_metrics_data = out1->mutable_data(ctx.GetPlace()); out2->mutable_data(ctx.GetPlace()); auto accum_states = EigenMatrix::From(*out2); accum_states.setZero(); @@ -121,7 +121,7 @@ class PrecisionRecallKernel : public framework::OpKernel { } protected: - void ComputeMetrics(const T* states_data, T* metrics_data, + void ComputeMetrics(const T* states_data, double* metrics_data, size_t state_var_num, size_t class_dim) const { T total_tp_count = 0; T total_fp_count = 0; -- GitLab