diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp
index 89f95438019d8c6a038c9c8204b6ae391f47196e..8fce8df8a3b0e52cc8af4d6848d2ba797757e8cd 100644
--- a/paddle/gserver/evaluators/Evaluator.cpp
+++ b/paddle/gserver/evaluators/Evaluator.cpp
@@ -626,78 +626,34 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
   return 0;
 }
 
-template <typename T1, typename T2>
-void PrecisionRecallEvaluator::printStatsHelper(T1 labelCallback,
-                                                T2 microAvgCallback) const {
-  int label = config_.positive_label();
-  if (label != -1) {
-    CHECK(label >= 0 && label < (int)statsInfo_.size())
-        << "positive_label [" << label << "] should be in range [0, "
-        << statsInfo_.size() << ")";
-    double precision =
-        calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
-    double recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
-    labelCallback(label, precision, recall, calcF1Score(precision, recall));
-    return;
-  }
-
-  // micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
-  // macro average method: precision = (precision1+precision2)/2
-  double microTotalTP = 0;
-  double microTotalFP = 0;
-  double microTotalFN = 0;
-  double macroAvgPrecision = 0;
-  double macroAvgRecall = 0;
-  size_t numLabels = statsInfo_.size();
-  for (size_t i = 0; i < numLabels; ++i) {
-    microTotalTP += statsInfo_[i].TP;
-    microTotalFP += statsInfo_[i].FP;
-    microTotalFN += statsInfo_[i].FN;
-    macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
-    macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
-  }
-  macroAvgPrecision /= numLabels;
-  macroAvgRecall /= numLabels;
-  double macroAvgF1Score = calcF1Score(macroAvgPrecision, macroAvgRecall);
-
-  double microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
-  double microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
-  double microAvgF1Score = calcF1Score(microAvgPrecision, microAvgRecall);
-
-  microAvgCallback(macroAvgPrecision,
-                   macroAvgRecall,
-                   macroAvgF1Score,
-                   isMultiBinaryLabel_,
-                   microAvgPrecision,
-                   microAvgRecall,
-                   microAvgF1Score);
-}
-
 void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
-  this->printStatsHelper(
-      [&os](int label, double precision, double recall, double f1) {
-        os << "positive_label=" << label << " precision=" << precision
-           << " recall=" << recall << " F1-score=" << f1;
-      },
-      [&os](double macroAvgPrecision,
-            double macroAvgRecall,
-            double macroAvgF1Score,
-            bool isMultiBinaryLabel,
-            double microAvgPrecision,
-            double microAvgRecall,
-            double microAvgF1Score) {
-        os << "macro-average-precision=" << macroAvgPrecision
-           << " macro-average-recall=" << macroAvgRecall
-           << " macro-average-F1-score=" << macroAvgF1Score;
-        if (!isMultiBinaryLabel) {
-          // precision and recall are equal in this case
-          os << " micro-average-precision=" << microAvgPrecision;
-        } else {
-          os << " micro-average-precision=" << microAvgPrecision
-             << " micro-average-recall=" << microAvgRecall
-             << " micro-average-F1-score=" << microAvgF1Score;
-        }
-      });
+  double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
+      macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
+  bool containMacroMicroInfo = getStatsInfo(&precision,
+                                            &recall,
+                                            &f1,
+                                            &macroAvgPrecision,
+                                            &macroAvgRecall,
+                                            &macroAvgF1Score,
+                                            &microAvgPrecision,
+                                            &microAvgRecall,
+                                            &microAvgF1Score);
+  os << "positive_label=" << config_.positive_label()
+     << " precision=" << precision << " recall=" << recall
+     << " F1-score=" << f1;
+  if (containMacroMicroInfo) {
+    os << "macro-average-precision=" << macroAvgPrecision
+       << " macro-average-recall=" << macroAvgRecall
+       << " macro-average-F1-score=" << macroAvgF1Score;
+    if (!isMultiBinaryLabel_) {
+      // precision and recall are equal in this case
+      os << " micro-average-precision=" << microAvgPrecision;
+    } else {
+      os << " micro-average-precision=" << microAvgPrecision
+         << " micro-average-recall=" << microAvgRecall
+         << " micro-average-F1-score=" << microAvgF1Score;
+    }
+  };
 }
 
 void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output,
@@ -780,32 +736,33 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
 
 void PrecisionRecallEvaluator::storeLocalValues() const {
   if (this->values_.size() == 0) {
-    this->printStatsHelper(
-        [this](int label, double precision, double recall, double f1) {
-          values_["positive_label"] = (double)label;
-          values_["precision"] = precision;
-          values_["recal"] = recall;
-          values_["F1-score"] = f1;
-        },
-        [this](double macroAvgPrecision,
-               double macroAvgRecall,
-               double macroAvgF1Score,
-               bool isMultiBinaryLabel,
-               double microAvgPrecision,
-               double microAvgRecall,
-               double microAvgF1Score) {
-          values_["macro-average-precision"] = macroAvgPrecision;
-          values_["macro-average-recall"] = macroAvgRecall;
-          values_["macro-average-F1-score"] = macroAvgF1Score;
-          if (!isMultiBinaryLabel) {
-            // precision and recall are equal in this case
-            values_["micro-average-precision"] = microAvgPrecision;
-          } else {
-            values_["micro-average-precision"] = microAvgPrecision;
-            values_["micro-average-recall"] = microAvgRecall;
-            values_["micro-average-F1-score"] = microAvgF1Score;
-          }
-        });
+    double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
+        macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
+    bool containMacroMicroInfo = getStatsInfo(&precision,
+                                              &recall,
+                                              &f1,
+                                              &macroAvgPrecision,
+                                              &macroAvgRecall,
+                                              &macroAvgF1Score,
+                                              &microAvgPrecision,
+                                              &microAvgRecall,
+                                              &microAvgF1Score);
+    values_["precision"] = precision;
+    values_["recal"] = recall;
+    values_["F1-score"] = f1;
+    if (containMacroMicroInfo) {
+      values_["macro-average-precision"] = macroAvgPrecision;
+      values_["macro-average-recall"] = macroAvgRecall;
+      values_["macro-average-F1-score"] = macroAvgF1Score;
+      if (!isMultiBinaryLabel_) {
+        // precision and recall are equal in this case
+        values_["micro-average-precision"] = microAvgPrecision;
+      } else {
+        values_["micro-average-precision"] = microAvgPrecision;
+        values_["micro-average-recall"] = microAvgRecall;
+        values_["micro-average-F1-score"] = microAvgF1Score;
+      }
+    }
   }
 }
 
@@ -865,6 +822,51 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
   delete[] buf;
 }
 
+bool PrecisionRecallEvaluator::getStatsInfo(double* precision,
+                                            double* recall,
+                                            double* f1,
+                                            double* macroAvgPrecision,
+                                            double* macroAvgRecall,
+                                            double* macroAvgF1Score,
+                                            double* microAvgPrecision,
+                                            double* microAvgRecall,
+                                            double* microAvgF1Score) const {
+  int label = config_.positive_label();
+  if (label != -1) {
+    CHECK(label >= 0 && label < (int)statsInfo_.size())
+        << "positive_label [" << label << "] should be in range [0, "
+        << statsInfo_.size() << ")";
+    *precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
+    *recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
+    *f1 = calcF1Score(*precision, *recall);
+    return false;
+  }
+
+  // micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
+  // macro average method: precision = (precision1+precision2)/2
+  double microTotalTP = 0;
+  double microTotalFP = 0;
+  double microTotalFN = 0;
+  *macroAvgPrecision = 0;
+  *macroAvgRecall = 0;
+  size_t numLabels = statsInfo_.size();
+  for (size_t i = 0; i < numLabels; ++i) {
+    microTotalTP += statsInfo_[i].TP;
+    microTotalFP += statsInfo_[i].FP;
+    microTotalFN += statsInfo_[i].FN;
+    *macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
+    *macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
+  }
+  *macroAvgPrecision /= numLabels;
+  *macroAvgRecall /= numLabels;
+  *macroAvgF1Score = calcF1Score(*macroAvgPrecision, *macroAvgRecall);
+
+  *microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
+  *microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
+  *microAvgF1Score = calcF1Score(*microAvgPrecision, *microAvgRecall);
+  return true;
+}
+
 REGISTER_EVALUATOR(pnpair, PnpairEvaluator);
 void PnpairEvaluator::start() {
   Evaluator::start();
diff --git a/paddle/gserver/evaluators/Evaluator.h b/paddle/gserver/evaluators/Evaluator.h
index a5694a088c3c445422081f505137c96a9f5c5152..eb19e6f4dde5d90e8479b7bd2d78d9bcb61e84dd 100644
--- a/paddle/gserver/evaluators/Evaluator.h
+++ b/paddle/gserver/evaluators/Evaluator.h
@@ -125,7 +125,7 @@ public:
   * has multiple field, the name could be `evaluator_name.field1`. For example
   * the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get
   * names will return `precision_recall_evaluator.precision`,
-  * `precision_recall.recal`, etc.
+  * `precision_recall_evaluator.recal`, etc.
   *
   * Also, if current Evaluator is a combined evaluator. getNames will return
   * all names of all evaluators inside the combined evaluator.
@@ -387,8 +387,15 @@ private:
   IVectorPtr cpuLabel_;
   MatrixPtr cpuWeight_;
 
-  template <typename T1, typename T2>
-  void printStatsHelper(T1 labelCallback, T2 microAvgCallback) const;
+  bool getStatsInfo(double* precision,
+                    double* recall,
+                    double* f1,
+                    double* macroAvgPrecision,
+                    double* macroAvgRecall,
+                    double* macroAvgF1Score,
+                    double* microAvgPrecision,
+                    double* microAvgRecall,
+                    double* microAvgF1Score) const;
 
   void calcStatsInfo(const MatrixPtr& output,
                      const IVectorPtr& label,
diff --git a/paddle/utils/Error.h b/paddle/utils/Error.h
index 1ae202890f874c3fd2045d47353ab4f78819ffaf..cda1b5c37dada8d0c6c77fc2fb03bb614d5301b5 100644
--- a/paddle/utils/Error.h
+++ b/paddle/utils/Error.h
@@ -37,10 +37,10 @@ namespace paddle {
  *
  * Error __must_check bar() {
  *   // do something.
- *   Status s = foo();  // invoke other method return status.
- *   if (!s) return s;
+ *   Error err = foo();  // invoke other method return status.
+ *   if (err) return err;
  *   // do something else.
- *   return Status();
+ *   return Error();
  * }
 * @endcode{cpp}
 *
@@ -53,8 +53,8 @@ namespace paddle {
 *
 * int foo(Error* error) {
 *   // Do something.
- *   Error s = bar();
- *   if (!s) {
+ *   Error err = bar();
+ *   if (err) {
 *     *error = s;
 *     return 0;
 *   }
@@ -68,10 +68,10 @@ namespace paddle {
 * }
 *
 * Error foobar() {
- *   Error s;
+ *   Error err;
 *   // do something.
- *   foo(&s);
- *   if (!s) return s;
+ *   foo(&err);
+ *   if (err) return err;
 * }
 * @endcode{cpp}
 *
@@ -112,18 +112,22 @@ public:
   }
 
   /**
-   * @brief operator bool, return True if there is no error.
+   * @brief operator bool, return True if there is something error.
    */
-  operator bool() const { return msg_ == nullptr; }
+  operator bool() const { return !this->isOK(); }
 
-  bool isOK() const { return *this; }
+  /**
+   * @brief isOK return True if there is no error.
+   * @return True if no error.
+   */
+  bool isOK() const { return msg_ == nullptr; }
 
   /**
    * @brief check this status by glog.
   * @note It is a temp method used during cleaning Paddle code. It will be
    * removed later.
   */
-  void check() const { CHECK(*this) << msg(); }
+  void check() const { CHECK(this->isOK()) << msg(); }
 
 private:
   std::shared_ptr<std::string> msg_;
diff --git a/paddle/utils/tests/test_Error.cpp b/paddle/utils/tests/test_Error.cpp
index 85156466e2cafd36d49941836c066a542dbbd60e..fdf326b17a1c8baa87e2a17fafae253565d1e699 100644
--- a/paddle/utils/tests/test_Error.cpp
+++ b/paddle/utils/tests/test_Error.cpp
@@ -18,17 +18,17 @@ limitations under the License. */
 
 TEST(Error, testAll) {
   paddle::Error error;
-  ASSERT_TRUE(error);
-  error = paddle::Error("I'm the error");
   ASSERT_FALSE(error);
+  error = paddle::Error("I'm the error");
+  ASSERT_TRUE(error);
   ASSERT_STREQ("I'm the error", error.msg());
 
   error = paddle::Error("error2");
-  ASSERT_FALSE(error);
+  ASSERT_TRUE(error);
   ASSERT_STREQ("error2", error.msg());
 
   int i = 3;
   auto error3 = paddle::Error("error%d", i);
-  ASSERT_FALSE(error3);
+  ASSERT_TRUE(error3);
   ASSERT_STREQ("error3", error3.msg());
 }
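
For reference, a minimal caller-side sketch of the paddle::Error semantics after this patch: operator bool() now answers "is there an error?", while isOK() is the explicit success check. The loadConfig helper below is hypothetical and exists only to show a caller; it is not part of the change.

// Usage sketch (assumes the post-patch paddle/utils/Error.h shown above).
#include <cstdio>
#include "paddle/utils/Error.h"

static paddle::Error loadConfig(int version) {  // hypothetical helper
  if (version < 0) {
    // printf-style constructor, as exercised in test_Error.cpp
    return paddle::Error("bad version %d", version);
  }
  return paddle::Error();  // default-constructed Error means no error
}

int main() {
  paddle::Error err = loadConfig(-1);
  if (err) {                         // true when an error is present
    std::printf("%s\n", err.msg());  // prints "bad version -1"
  }
  err = loadConfig(3);
  return err.isOK() ? 0 : 1;  // isOK() is the no-error check
}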