diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc index 24246907b1c44a38c962a96b2e0d7d3662e4a9eb..a3f4c07493d0f3bd91daa59ebc14c42e16d807e8 100644 --- a/paddle/operators/precision_recall_op.cc +++ b/paddle/operators/precision_recall_op.cc @@ -136,9 +136,9 @@ to compute various metrics including: - micro average recall - micro f1 score -To compute the above metrics, we need to statistic counts for true positives, +To compute the above metrics, we need to do statistics for true positives, false positives and false negatives. Here count of true negatives is not -necessary, but statisticing it may provide potential usage and the cost is +necessary, but counting it may provide potential usage and the cost is trivial, so the operator also provides count of true negatives. We define state as a 2-D tensor with shape [class number, 4]. Each row of a diff --git a/paddle/operators/precision_recall_op.h b/paddle/operators/precision_recall_op.h index 3bc638ea448c07f0f183257ab933b17e4021e219..2e49bc3bb5478c4dfc9cbcfa797f3b509ac4e5cd 100644 --- a/paddle/operators/precision_recall_op.h +++ b/paddle/operators/precision_recall_op.h @@ -42,8 +42,8 @@ class PrecisionRecallKernel : public framework::OpKernel { const int* labels_data = in1->data(); const T* weights_data = in2 ? in2->data() : nullptr; const T* states_data = in3 ? in3->data() : nullptr; - T* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); - T* accum_metrics_data = out1->mutable_data(ctx.GetPlace()); + double* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); + double* accum_metrics_data = out1->mutable_data(ctx.GetPlace()); out2->mutable_data(ctx.GetPlace()); auto accum_states = EigenMatrix::From(*out2); accum_states.setZero(); @@ -121,7 +121,7 @@ class PrecisionRecallKernel : public framework::OpKernel { } protected: - void ComputeMetrics(const T* states_data, T* metrics_data, + void ComputeMetrics(const T* states_data, double* metrics_data, size_t state_var_num, size_t class_dim) const { T total_tp_count = 0; T total_fp_count = 0;