提交 b1ab8b56 编写于 作者: Y Yu Yang

Use plain C++ 03 to implement getStatsInfo.

上级 bb751b7f
......@@ -626,78 +626,34 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
return 0;
}
template <typename T1, typename T2>
void PrecisionRecallEvaluator::printStatsHelper(T1 labelCallback,
T2 microAvgCallback) const {
int label = config_.positive_label();
if (label != -1) {
CHECK(label >= 0 && label < (int)statsInfo_.size())
<< "positive_label [" << label << "] should be in range [0, "
<< statsInfo_.size() << ")";
double precision =
calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
double recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
labelCallback(label, precision, recall, calcF1Score(precision, recall));
return;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double microTotalTP = 0;
double microTotalFP = 0;
double microTotalFN = 0;
double macroAvgPrecision = 0;
double macroAvgRecall = 0;
size_t numLabels = statsInfo_.size();
for (size_t i = 0; i < numLabels; ++i) {
microTotalTP += statsInfo_[i].TP;
microTotalFP += statsInfo_[i].FP;
microTotalFN += statsInfo_[i].FN;
macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
}
macroAvgPrecision /= numLabels;
macroAvgRecall /= numLabels;
double macroAvgF1Score = calcF1Score(macroAvgPrecision, macroAvgRecall);
double microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
double microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
double microAvgF1Score = calcF1Score(microAvgPrecision, microAvgRecall);
microAvgCallback(macroAvgPrecision,
macroAvgRecall,
macroAvgF1Score,
isMultiBinaryLabel_,
microAvgPrecision,
microAvgRecall,
microAvgF1Score);
}
void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
this->printStatsHelper(
[&os](int label, double precision, double recall, double f1) {
os << "positive_label=" << label << " precision=" << precision
<< " recall=" << recall << " F1-score=" << f1;
},
[&os](double macroAvgPrecision,
double macroAvgRecall,
double macroAvgF1Score,
bool isMultiBinaryLabel,
double microAvgPrecision,
double microAvgRecall,
double microAvgF1Score) {
os << "macro-average-precision=" << macroAvgPrecision
<< " macro-average-recall=" << macroAvgRecall
<< " macro-average-F1-score=" << macroAvgF1Score;
if (!isMultiBinaryLabel) {
// precision and recall are equal in this case
os << " micro-average-precision=" << microAvgPrecision;
} else {
os << " micro-average-precision=" << microAvgPrecision
<< " micro-average-recall=" << microAvgRecall
<< " micro-average-F1-score=" << microAvgF1Score;
}
});
double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
bool containMacroMicroInfo = getStatsInfo(&precision,
&recall,
&f1,
&macroAvgPrecision,
&macroAvgRecall,
&macroAvgF1Score,
&microAvgPrecision,
&microAvgRecall,
&microAvgF1Score);
os << "positive_label=" << config_.positive_label()
<< " precision=" << precision << " recall=" << recall
<< " F1-score=" << f1;
if (containMacroMicroInfo) {
os << "macro-average-precision=" << macroAvgPrecision
<< " macro-average-recall=" << macroAvgRecall
<< " macro-average-F1-score=" << macroAvgF1Score;
if (!isMultiBinaryLabel_) {
// precision and recall are equal in this case
os << " micro-average-precision=" << microAvgPrecision;
} else {
os << " micro-average-precision=" << microAvgPrecision
<< " micro-average-recall=" << microAvgRecall
<< " micro-average-F1-score=" << microAvgF1Score;
}
};
}
void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output,
......@@ -780,32 +736,33 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
void PrecisionRecallEvaluator::storeLocalValues() const {
if (this->values_.size() == 0) {
this->printStatsHelper(
[this](int label, double precision, double recall, double f1) {
values_["positive_label"] = (double)label;
values_["precision"] = precision;
values_["recal"] = recall;
values_["F1-score"] = f1;
},
[this](double macroAvgPrecision,
double macroAvgRecall,
double macroAvgF1Score,
bool isMultiBinaryLabel,
double microAvgPrecision,
double microAvgRecall,
double microAvgF1Score) {
values_["macro-average-precision"] = macroAvgPrecision;
values_["macro-average-recall"] = macroAvgRecall;
values_["macro-average-F1-score"] = macroAvgF1Score;
if (!isMultiBinaryLabel) {
// precision and recall are equal in this case
values_["micro-average-precision"] = microAvgPrecision;
} else {
values_["micro-average-precision"] = microAvgPrecision;
values_["micro-average-recall"] = microAvgRecall;
values_["micro-average-F1-score"] = microAvgF1Score;
}
});
double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
bool containMacroMicroInfo = getStatsInfo(&precision,
&recall,
&f1,
&macroAvgPrecision,
&macroAvgRecall,
&macroAvgF1Score,
&microAvgPrecision,
&microAvgRecall,
&microAvgF1Score);
values_["precision"] = precision;
values_["recal"] = recall;
values_["F1-score"] = f1;
if (containMacroMicroInfo) {
values_["macro-average-precision"] = macroAvgPrecision;
values_["macro-average-recall"] = macroAvgRecall;
values_["macro-average-F1-score"] = macroAvgF1Score;
if (!isMultiBinaryLabel_) {
// precision and recall are equal in this case
values_["micro-average-precision"] = microAvgPrecision;
} else {
values_["micro-average-precision"] = microAvgPrecision;
values_["micro-average-recall"] = microAvgRecall;
values_["micro-average-F1-score"] = microAvgF1Score;
}
}
}
}
......@@ -865,6 +822,51 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete[] buf;
}
bool PrecisionRecallEvaluator::getStatsInfo(double* precision,
double* recall,
double* f1,
double* macroAvgPrecision,
double* macroAvgRecall,
double* macroAvgF1Score,
double* microAvgPrecision,
double* microAvgRecall,
double* microAvgF1Score) const {
int label = config_.positive_label();
if (label != -1) {
CHECK(label >= 0 && label < (int)statsInfo_.size())
<< "positive_label [" << label << "] should be in range [0, "
<< statsInfo_.size() << ")";
*precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
*recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
*f1 = calcF1Score(*precision, *recall);
return false;
}
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
// macro average method: precision = (precision1+precision2)/2
double microTotalTP = 0;
double microTotalFP = 0;
double microTotalFN = 0;
*macroAvgPrecision = 0;
*macroAvgRecall = 0;
size_t numLabels = statsInfo_.size();
for (size_t i = 0; i < numLabels; ++i) {
microTotalTP += statsInfo_[i].TP;
microTotalFP += statsInfo_[i].FP;
microTotalFN += statsInfo_[i].FN;
*macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
*macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
}
*macroAvgPrecision /= numLabels;
*macroAvgRecall /= numLabels;
*macroAvgF1Score = calcF1Score(*macroAvgPrecision, *macroAvgRecall);
*microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
*microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
*microAvgF1Score = calcF1Score(*microAvgPrecision, *microAvgRecall);
return true;
}
REGISTER_EVALUATOR(pnpair, PnpairEvaluator);
void PnpairEvaluator::start() {
Evaluator::start();
......
......@@ -125,7 +125,7 @@ public:
* has multiple field, the name could be `evaluator_name.field1`. For example
* the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get
* names will return `precision_recall_evaluator.precision`,
* `precision_recall.recal`, etc.
* `precision_recall_evaluator.recal`, etc.
*
* Also, if current Evaluator is a combined evaluator. getNames will return
* all names of all evaluators inside the combined evaluator.
......@@ -387,8 +387,15 @@ private:
IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_;
template <typename T1, typename T2>
void printStatsHelper(T1 labelCallback, T2 microAvgCallback) const;
bool getStatsInfo(double* precision,
double* recall,
double* f1,
double* macroAvgPrecision,
double* macroAvgRecall,
double* macroAvgF1Score,
double* microAvgPrecision,
double* microAvgRecall,
double* microAvgF1Score) const;
void calcStatsInfo(const MatrixPtr& output,
const IVectorPtr& label,
......
......@@ -37,10 +37,10 @@ namespace paddle {
*
* Error __must_check bar() {
* // do something.
* Status s = foo(); // invoke other method return status.
* if (!s) return s;
* Error err = foo(); // invoke other method return status.
* if (err) return err;
* // do something else.
* return Status();
* return Error();
* }
* @endcode{cpp}
*
......@@ -53,8 +53,8 @@ namespace paddle {
*
* int foo(Error* error) {
* // Do something.
* Error s = bar();
* if (!s) {
* Error err = bar();
* if (err) {
* *error = s;
* return 0;
* }
......@@ -68,10 +68,10 @@ namespace paddle {
* }
*
* Error foobar() {
* Error s;
* Error err;
* // do something.
* foo(&s);
* if (!s) return s;
* foo(&err);
* if (err) return err;
* }
* @endcode{cpp}
*
......@@ -112,18 +112,22 @@ public:
}
/**
* @brief operator bool, return True if there is no error.
* @brief operator bool, return True if there is something error.
*/
operator bool() const { return msg_ == nullptr; }
operator bool() const { return !this->isOK(); }
bool isOK() const { return *this; }
/**
* @brief isOK return True if there is no error.
* @return True if no error.
*/
bool isOK() const { return msg_ == nullptr; }
/**
* @brief check this status by glog.
* @note It is a temp method used during cleaning Paddle code. It will be
* removed later.
*/
void check() const { CHECK(*this) << msg(); }
void check() const { CHECK(this->isOK()) << msg(); }
private:
std::shared_ptr<std::string> msg_;
......
......@@ -18,17 +18,17 @@ limitations under the License. */
TEST(Error, testAll) {
paddle::Error error;
ASSERT_TRUE(error);
error = paddle::Error("I'm the error");
ASSERT_FALSE(error);
error = paddle::Error("I'm the error");
ASSERT_TRUE(error);
ASSERT_STREQ("I'm the error", error.msg());
error = paddle::Error("error2");
ASSERT_FALSE(error);
ASSERT_TRUE(error);
ASSERT_STREQ("error2", error.msg());
int i = 3;
auto error3 = paddle::Error("error%d", i);
ASSERT_FALSE(error3);
ASSERT_TRUE(error3);
ASSERT_STREQ("error3", error3.msg());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册