From 9087e3a9bb9aac10236b513c5c9a52af1590b3a7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 23 Feb 2017 16:18:19 +0800 Subject: [PATCH] Follow comments to use struct get return value. --- paddle/gserver/evaluators/Evaluator.cpp | 108 ++++++++++-------------- paddle/gserver/evaluators/Evaluator.h | 22 +++-- 2 files changed, 56 insertions(+), 74 deletions(-) diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index 5911a9ec59a..9db6d252d97 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -647,33 +647,24 @@ real PrecisionRecallEvaluator::evalImp(std::vector& arguments) { } void PrecisionRecallEvaluator::printStats(std::ostream& os) const { - double precision, recall, f1, macroAvgPrecision, macroAvgRecall, - macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score; - bool containMacroMicroInfo = getStatsInfo(&precision, - &recall, - &f1, - ¯oAvgPrecision, - ¯oAvgRecall, - ¯oAvgF1Score, - µAvgPrecision, - µAvgRecall, - µAvgF1Score); + PrintStatsInfo info; + bool containMacroMicroInfo = getStatsInfo(&info); os << "positive_label=" << config_.positive_label() - << " precision=" << precision << " recall=" << recall - << " F1-score=" << f1; + << " precision=" << info.precision << " recall=" << info.recall + << " F1-score=" << info.f1; if (containMacroMicroInfo) { - os << "macro-average-precision=" << macroAvgPrecision - << " macro-average-recall=" << macroAvgRecall - << " macro-average-F1-score=" << macroAvgF1Score; + os << "macro-average-precision=" << info.macroAvgPrecision + << " macro-average-recall=" << info.macroAvgRecall + << " macro-average-F1-score=" << info.macroAvgF1Score; if (!isMultiBinaryLabel_) { // precision and recall are equal in this case - os << " micro-average-precision=" << microAvgPrecision; + os << " micro-average-precision=" << info.microAvgPrecision; } else { - os << " micro-average-precision=" << microAvgPrecision - << " micro-average-recall=" << microAvgRecall - << " micro-average-F1-score=" << microAvgF1Score; + os << " micro-average-precision=" << info.microAvgPrecision + << " micro-average-recall=" << info.microAvgRecall + << " micro-average-F1-score=" << info.microAvgF1Score; } - }; + } } void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output, @@ -756,31 +747,22 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output, void PrecisionRecallEvaluator::storeLocalValues() const { if (this->values_.size() == 0) { - double precision, recall, f1, macroAvgPrecision, macroAvgRecall, - macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score; - bool containMacroMicroInfo = getStatsInfo(&precision, - &recall, - &f1, - ¯oAvgPrecision, - ¯oAvgRecall, - ¯oAvgF1Score, - µAvgPrecision, - µAvgRecall, - µAvgF1Score); - values_["precision"] = precision; - values_["recal"] = recall; - values_["F1-score"] = f1; + PrintStatsInfo info; + bool containMacroMicroInfo = getStatsInfo(&info); + values_["precision"] = info.precision; + values_["recal"] = info.recall; + values_["F1-score"] = info.f1; if (containMacroMicroInfo) { - values_["macro-average-precision"] = macroAvgPrecision; - values_["macro-average-recall"] = macroAvgRecall; - values_["macro-average-F1-score"] = macroAvgF1Score; + values_["macro-average-precision"] = info.macroAvgPrecision; + values_["macro-average-recall"] = info.macroAvgRecall; + values_["macro-average-F1-score"] = info.macroAvgF1Score; if (!isMultiBinaryLabel_) { // precision and recall are equal in this case - values_["micro-average-precision"] = microAvgPrecision; + values_["micro-average-precision"] = info.microAvgPrecision; } else { - values_["micro-average-precision"] = microAvgPrecision; - values_["micro-average-recall"] = microAvgRecall; - values_["micro-average-F1-score"] = microAvgF1Score; + values_["micro-average-precision"] = info.microAvgPrecision; + values_["micro-average-recall"] = info.microAvgRecall; + values_["micro-average-F1-score"] = info.microAvgF1Score; } } } @@ -836,23 +818,16 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) { delete[] buf; } -bool PrecisionRecallEvaluator::getStatsInfo(double* precision, - double* recall, - double* f1, - double* macroAvgPrecision, - double* macroAvgRecall, - double* macroAvgF1Score, - double* microAvgPrecision, - double* microAvgRecall, - double* microAvgF1Score) const { +bool PrecisionRecallEvaluator::getStatsInfo( + PrecisionRecallEvaluator::PrintStatsInfo* info) const { int label = config_.positive_label(); if (label != -1) { CHECK(label >= 0 && label < (int)statsInfo_.size()) << "positive_label [" << label << "] should be in range [0, " << statsInfo_.size() << ")"; - *precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP); - *recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN); - *f1 = calcF1Score(*precision, *recall); + info->precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP); + info->recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN); + info->f1 = calcF1Score(info->precision, info->recall); return false; } @@ -861,23 +836,26 @@ bool PrecisionRecallEvaluator::getStatsInfo(double* precision, double microTotalTP = 0; double microTotalFP = 0; double microTotalFN = 0; - *macroAvgPrecision = 0; - *macroAvgRecall = 0; + info->macroAvgPrecision = 0; + info->macroAvgRecall = 0; size_t numLabels = statsInfo_.size(); for (size_t i = 0; i < numLabels; ++i) { microTotalTP += statsInfo_[i].TP; microTotalFP += statsInfo_[i].FP; microTotalFN += statsInfo_[i].FN; - *macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP); - *macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN); + info->macroAvgPrecision += + calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP); + info->macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN); } - *macroAvgPrecision /= numLabels; - *macroAvgRecall /= numLabels; - *macroAvgF1Score = calcF1Score(*macroAvgPrecision, *macroAvgRecall); - - *microAvgPrecision = calcPrecision(microTotalTP, microTotalFP); - *microAvgRecall = calcPrecision(microTotalTP, microTotalFN); - *microAvgF1Score = calcF1Score(*microAvgPrecision, *microAvgRecall); + info->macroAvgPrecision /= numLabels; + info->macroAvgRecall /= numLabels; + info->macroAvgF1Score = + calcF1Score(info->macroAvgPrecision, info->macroAvgRecall); + + info->microAvgPrecision = calcPrecision(microTotalTP, microTotalFP); + info->microAvgRecall = calcPrecision(microTotalTP, microTotalFN); + info->microAvgF1Score = + calcF1Score(info->microAvgPrecision, info->microAvgRecall); return true; } diff --git a/paddle/gserver/evaluators/Evaluator.h b/paddle/gserver/evaluators/Evaluator.h index c4110ec1c0b..b114500e2b7 100644 --- a/paddle/gserver/evaluators/Evaluator.h +++ b/paddle/gserver/evaluators/Evaluator.h @@ -379,15 +379,19 @@ private: IVectorPtr cpuLabel_; MatrixPtr cpuWeight_; - bool getStatsInfo(double* precision, - double* recall, - double* f1, - double* macroAvgPrecision, - double* macroAvgRecall, - double* macroAvgF1Score, - double* microAvgPrecision, - double* microAvgRecall, - double* microAvgF1Score) const; + struct PrintStatsInfo { + double precision; + double recall; + double f1; + double macroAvgPrecision; + double macroAvgRecall; + double macroAvgF1Score; + double microAvgPrecision; + double microAvgRecall; + double microAvgF1Score; + }; + + bool getStatsInfo(PrintStatsInfo* info) const; void calcStatsInfo(const MatrixPtr& output, const IVectorPtr& label, -- GitLab