提交 b1ab8b56 编写于 作者: Y Yu Yang

Use plain C++ 03 to implement getStatsInfo.

上级 bb751b7f
...@@ -626,78 +626,34 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) { ...@@ -626,78 +626,34 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
return 0; return 0;
} }
template <typename T1, typename T2>
void PrecisionRecallEvaluator::printStatsHelper(T1 labelCallback,
T2 microAvgCallback) const {
int label = config_.positive_label();
if (label != -1) {
CHECK(label >= 0 && label < (int)statsInfo_.size())
<< "positive_label [" << label << "] should be in range [0, "
<< statsInfo_.size() << ")";
double precision =
calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
double recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
labelCallback(label, precision, recall, calcF1Score(precision, recall));
return;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double microTotalTP = 0;
double microTotalFP = 0;
double microTotalFN = 0;
double macroAvgPrecision = 0;
double macroAvgRecall = 0;
size_t numLabels = statsInfo_.size();
for (size_t i = 0; i < numLabels; ++i) {
microTotalTP += statsInfo_[i].TP;
microTotalFP += statsInfo_[i].FP;
microTotalFN += statsInfo_[i].FN;
macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
}
macroAvgPrecision /= numLabels;
macroAvgRecall /= numLabels;
double macroAvgF1Score = calcF1Score(macroAvgPrecision, macroAvgRecall);
double microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
double microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
double microAvgF1Score = calcF1Score(microAvgPrecision, microAvgRecall);
microAvgCallback(macroAvgPrecision,
macroAvgRecall,
macroAvgF1Score,
isMultiBinaryLabel_,
microAvgPrecision,
microAvgRecall,
microAvgF1Score);
}
void PrecisionRecallEvaluator::printStats(std::ostream& os) const { void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
this->printStatsHelper( double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
[&os](int label, double precision, double recall, double f1) { macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
os << "positive_label=" << label << " precision=" << precision bool containMacroMicroInfo = getStatsInfo(&precision,
<< " recall=" << recall << " F1-score=" << f1; &recall,
}, &f1,
[&os](double macroAvgPrecision, &macroAvgPrecision,
double macroAvgRecall, &macroAvgRecall,
double macroAvgF1Score, &macroAvgF1Score,
bool isMultiBinaryLabel, &microAvgPrecision,
double microAvgPrecision, &microAvgRecall,
double microAvgRecall, &microAvgF1Score);
double microAvgF1Score) { os << "positive_label=" << config_.positive_label()
os << "macro-average-precision=" << macroAvgPrecision << " precision=" << precision << " recall=" << recall
<< " macro-average-recall=" << macroAvgRecall << " F1-score=" << f1;
<< " macro-average-F1-score=" << macroAvgF1Score; if (containMacroMicroInfo) {
if (!isMultiBinaryLabel) { os << "macro-average-precision=" << macroAvgPrecision
// precision and recall are equal in this case << " macro-average-recall=" << macroAvgRecall
os << " micro-average-precision=" << microAvgPrecision; << " macro-average-F1-score=" << macroAvgF1Score;
} else { if (!isMultiBinaryLabel_) {
os << " micro-average-precision=" << microAvgPrecision // precision and recall are equal in this case
<< " micro-average-recall=" << microAvgRecall os << " micro-average-precision=" << microAvgPrecision;
<< " micro-average-F1-score=" << microAvgF1Score; } else {
} os << " micro-average-precision=" << microAvgPrecision
}); << " micro-average-recall=" << microAvgRecall
<< " micro-average-F1-score=" << microAvgF1Score;
}
};
} }
void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output, void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output,
...@@ -780,32 +736,33 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output, ...@@ -780,32 +736,33 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
void PrecisionRecallEvaluator::storeLocalValues() const { void PrecisionRecallEvaluator::storeLocalValues() const {
if (this->values_.size() == 0) { if (this->values_.size() == 0) {
this->printStatsHelper( double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
[this](int label, double precision, double recall, double f1) { macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
values_["positive_label"] = (double)label; bool containMacroMicroInfo = getStatsInfo(&precision,
values_["precision"] = precision; &recall,
values_["recal"] = recall; &f1,
values_["F1-score"] = f1; &macroAvgPrecision,
}, &macroAvgRecall,
[this](double macroAvgPrecision, &macroAvgF1Score,
double macroAvgRecall, &microAvgPrecision,
double macroAvgF1Score, &microAvgRecall,
bool isMultiBinaryLabel, &microAvgF1Score);
double microAvgPrecision, values_["precision"] = precision;
double microAvgRecall, values_["recal"] = recall;
double microAvgF1Score) { values_["F1-score"] = f1;
values_["macro-average-precision"] = macroAvgPrecision; if (containMacroMicroInfo) {
values_["macro-average-recall"] = macroAvgRecall; values_["macro-average-precision"] = macroAvgPrecision;
values_["macro-average-F1-score"] = macroAvgF1Score; values_["macro-average-recall"] = macroAvgRecall;
if (!isMultiBinaryLabel) { values_["macro-average-F1-score"] = macroAvgF1Score;
// precision and recall are equal in this case if (!isMultiBinaryLabel_) {
values_["micro-average-precision"] = microAvgPrecision; // precision and recall are equal in this case
} else { values_["micro-average-precision"] = microAvgPrecision;
values_["micro-average-precision"] = microAvgPrecision; } else {
values_["micro-average-recall"] = microAvgRecall; values_["micro-average-precision"] = microAvgPrecision;
values_["micro-average-F1-score"] = microAvgF1Score; values_["micro-average-recall"] = microAvgRecall;
} values_["micro-average-F1-score"] = microAvgF1Score;
}); }
}
} }
} }
...@@ -865,6 +822,51 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) { ...@@ -865,6 +822,51 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete[] buf; delete[] buf;
} }
bool PrecisionRecallEvaluator::getStatsInfo(double* precision,
double* recall,
double* f1,
double* macroAvgPrecision,
double* macroAvgRecall,
double* macroAvgF1Score,
double* microAvgPrecision,
double* microAvgRecall,
double* microAvgF1Score) const {
int label = config_.positive_label();
if (label != -1) {
CHECK(label >= 0 && label < (int)statsInfo_.size())
<< "positive_label [" << label << "] should be in range [0, "
<< statsInfo_.size() << ")";
*precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
*recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
*f1 = calcF1Score(*precision, *recall);
return false;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double microTotalTP = 0;
double microTotalFP = 0;
double microTotalFN = 0;
*macroAvgPrecision = 0;
*macroAvgRecall = 0;
size_t numLabels = statsInfo_.size();
for (size_t i = 0; i < numLabels; ++i) {
microTotalTP += statsInfo_[i].TP;
microTotalFP += statsInfo_[i].FP;
microTotalFN += statsInfo_[i].FN;
*macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
*macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
}
*macroAvgPrecision /= numLabels;
*macroAvgRecall /= numLabels;
*macroAvgF1Score = calcF1Score(*macroAvgPrecision, *macroAvgRecall);
*microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
*microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
*microAvgF1Score = calcF1Score(*microAvgPrecision, *microAvgRecall);
return true;
}
REGISTER_EVALUATOR(pnpair, PnpairEvaluator); REGISTER_EVALUATOR(pnpair, PnpairEvaluator);
void PnpairEvaluator::start() { void PnpairEvaluator::start() {
Evaluator::start(); Evaluator::start();
......
...@@ -125,7 +125,7 @@ public: ...@@ -125,7 +125,7 @@ public:
* has multiple field, the name could be `evaluator_name.field1`. For example * has multiple field, the name could be `evaluator_name.field1`. For example
* the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get * the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get
* names will return `precision_recall_evaluator.precision`, * names will return `precision_recall_evaluator.precision`,
* `precision_recall.recal`, etc. * `precision_recall_evaluator.recal`, etc.
* *
* Also, if current Evaluator is a combined evaluator. getNames will return * Also, if current Evaluator is a combined evaluator. getNames will return
* all names of all evaluators inside the combined evaluator. * all names of all evaluators inside the combined evaluator.
...@@ -387,8 +387,15 @@ private: ...@@ -387,8 +387,15 @@ private:
IVectorPtr cpuLabel_; IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_; MatrixPtr cpuWeight_;
template <typename T1, typename T2> bool getStatsInfo(double* precision,
void printStatsHelper(T1 labelCallback, T2 microAvgCallback) const; double* recall,
double* f1,
double* macroAvgPrecision,
double* macroAvgRecall,
double* macroAvgF1Score,
double* microAvgPrecision,
double* microAvgRecall,
double* microAvgF1Score) const;
void calcStatsInfo(const MatrixPtr& output, void calcStatsInfo(const MatrixPtr& output,
const IVectorPtr& label, const IVectorPtr& label,
......
...@@ -37,10 +37,10 @@ namespace paddle { ...@@ -37,10 +37,10 @@ namespace paddle {
* *
* Error __must_check bar() { * Error __must_check bar() {
* // do something. * // do something.
* Status s = foo(); // invoke other method return status. * Error err = foo(); // invoke other method return status.
* if (!s) return s; * if (err) return err;
* // do something else. * // do something else.
* return Status(); * return Error();
* } * }
* @endcode{cpp} * @endcode{cpp}
* *
...@@ -53,8 +53,8 @@ namespace paddle { ...@@ -53,8 +53,8 @@ namespace paddle {
* *
* int foo(Error* error) { * int foo(Error* error) {
* // Do something. * // Do something.
* Error s = bar(); * Error err = bar();
* if (!s) { * if (err) {
* *error = s; * *error = s;
* return 0; * return 0;
* } * }
...@@ -68,10 +68,10 @@ namespace paddle { ...@@ -68,10 +68,10 @@ namespace paddle {
* } * }
* *
* Error foobar() { * Error foobar() {
* Error s; * Error err;
* // do something. * // do something.
* foo(&s); * foo(&err);
* if (!s) return s; * if (err) return err;
* } * }
* @endcode{cpp} * @endcode{cpp}
* *
...@@ -112,18 +112,22 @@ public: ...@@ -112,18 +112,22 @@ public:
} }
/** /**
* @brief operator bool, return True if there is no error. * @brief operator bool, return True if there is something error.
*/ */
operator bool() const { return msg_ == nullptr; } operator bool() const { return !this->isOK(); }
bool isOK() const { return *this; } /**
* @brief isOK return True if there is no error.
* @return True if no error.
*/
bool isOK() const { return msg_ == nullptr; }
/** /**
* @brief check this status by glog. * @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be * @note It is a temp method used during cleaning Paddle code. It will be
* removed later. * removed later.
*/ */
void check() const { CHECK(*this) << msg(); } void check() const { CHECK(this->isOK()) << msg(); }
private: private:
std::shared_ptr<std::string> msg_; std::shared_ptr<std::string> msg_;
......
...@@ -18,17 +18,17 @@ limitations under the License. */ ...@@ -18,17 +18,17 @@ limitations under the License. */
TEST(Error, testAll) { TEST(Error, testAll) {
paddle::Error error; paddle::Error error;
ASSERT_TRUE(error);
error = paddle::Error("I'm the error");
ASSERT_FALSE(error); ASSERT_FALSE(error);
error = paddle::Error("I'm the error");
ASSERT_TRUE(error);
ASSERT_STREQ("I'm the error", error.msg()); ASSERT_STREQ("I'm the error", error.msg());
error = paddle::Error("error2"); error = paddle::Error("error2");
ASSERT_FALSE(error); ASSERT_TRUE(error);
ASSERT_STREQ("error2", error.msg()); ASSERT_STREQ("error2", error.msg());
int i = 3; int i = 3;
auto error3 = paddle::Error("error%d", i); auto error3 = paddle::Error("error%d", i);
ASSERT_FALSE(error3); ASSERT_TRUE(error3);
ASSERT_STREQ("error3", error3.msg()); ASSERT_STREQ("error3", error3.msg());
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册