提交 9087e3a9 编写于 作者: Y Yu Yang

Follow comments to use struct get return value.

上级 e67aac1b
......@@ -647,33 +647,24 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
}
void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
bool containMacroMicroInfo = getStatsInfo(&precision,
&recall,
&f1,
&macroAvgPrecision,
&macroAvgRecall,
&macroAvgF1Score,
&microAvgPrecision,
&microAvgRecall,
&microAvgF1Score);
PrintStatsInfo info;
bool containMacroMicroInfo = getStatsInfo(&info);
os << "positive_label=" << config_.positive_label()
<< " precision=" << precision << " recall=" << recall
<< " F1-score=" << f1;
<< " precision=" << info.precision << " recall=" << info.recall
<< " F1-score=" << info.f1;
if (containMacroMicroInfo) {
os << "macro-average-precision=" << macroAvgPrecision
<< " macro-average-recall=" << macroAvgRecall
<< " macro-average-F1-score=" << macroAvgF1Score;
os << "macro-average-precision=" << info.macroAvgPrecision
<< " macro-average-recall=" << info.macroAvgRecall
<< " macro-average-F1-score=" << info.macroAvgF1Score;
if (!isMultiBinaryLabel_) {
// precision and recall are equal in this case
os << " micro-average-precision=" << microAvgPrecision;
os << " micro-average-precision=" << info.microAvgPrecision;
} else {
os << " micro-average-precision=" << microAvgPrecision
<< " micro-average-recall=" << microAvgRecall
<< " micro-average-F1-score=" << microAvgF1Score;
os << " micro-average-precision=" << info.microAvgPrecision
<< " micro-average-recall=" << info.microAvgRecall
<< " micro-average-F1-score=" << info.microAvgF1Score;
}
};
}
}
void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output,
......@@ -756,31 +747,22 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
void PrecisionRecallEvaluator::storeLocalValues() const {
if (this->values_.size() == 0) {
double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
bool containMacroMicroInfo = getStatsInfo(&precision,
&recall,
&f1,
&macroAvgPrecision,
&macroAvgRecall,
&macroAvgF1Score,
&microAvgPrecision,
&microAvgRecall,
&microAvgF1Score);
values_["precision"] = precision;
values_["recal"] = recall;
values_["F1-score"] = f1;
PrintStatsInfo info;
bool containMacroMicroInfo = getStatsInfo(&info);
values_["precision"] = info.precision;
values_["recal"] = info.recall;
values_["F1-score"] = info.f1;
if (containMacroMicroInfo) {
values_["macro-average-precision"] = macroAvgPrecision;
values_["macro-average-recall"] = macroAvgRecall;
values_["macro-average-F1-score"] = macroAvgF1Score;
values_["macro-average-precision"] = info.macroAvgPrecision;
values_["macro-average-recall"] = info.macroAvgRecall;
values_["macro-average-F1-score"] = info.macroAvgF1Score;
if (!isMultiBinaryLabel_) {
// precision and recall are equal in this case
values_["micro-average-precision"] = microAvgPrecision;
values_["micro-average-precision"] = info.microAvgPrecision;
} else {
values_["micro-average-precision"] = microAvgPrecision;
values_["micro-average-recall"] = microAvgRecall;
values_["micro-average-F1-score"] = microAvgF1Score;
values_["micro-average-precision"] = info.microAvgPrecision;
values_["micro-average-recall"] = info.microAvgRecall;
values_["micro-average-F1-score"] = info.microAvgF1Score;
}
}
}
......@@ -836,23 +818,16 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete[] buf;
}
bool PrecisionRecallEvaluator::getStatsInfo(double* precision,
double* recall,
double* f1,
double* macroAvgPrecision,
double* macroAvgRecall,
double* macroAvgF1Score,
double* microAvgPrecision,
double* microAvgRecall,
double* microAvgF1Score) const {
bool PrecisionRecallEvaluator::getStatsInfo(
PrecisionRecallEvaluator::PrintStatsInfo* info) const {
int label = config_.positive_label();
if (label != -1) {
CHECK(label >= 0 && label < (int)statsInfo_.size())
<< "positive_label [" << label << "] should be in range [0, "
<< statsInfo_.size() << ")";
*precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
*recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
*f1 = calcF1Score(*precision, *recall);
info->precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
info->recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
info->f1 = calcF1Score(info->precision, info->recall);
return false;
}
......@@ -861,23 +836,26 @@ bool PrecisionRecallEvaluator::getStatsInfo(double* precision,
double microTotalTP = 0;
double microTotalFP = 0;
double microTotalFN = 0;
*macroAvgPrecision = 0;
*macroAvgRecall = 0;
info->macroAvgPrecision = 0;
info->macroAvgRecall = 0;
size_t numLabels = statsInfo_.size();
for (size_t i = 0; i < numLabels; ++i) {
microTotalTP += statsInfo_[i].TP;
microTotalFP += statsInfo_[i].FP;
microTotalFN += statsInfo_[i].FN;
*macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
*macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
info->macroAvgPrecision +=
calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
info->macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
}
*macroAvgPrecision /= numLabels;
*macroAvgRecall /= numLabels;
*macroAvgF1Score = calcF1Score(*macroAvgPrecision, *macroAvgRecall);
*microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
*microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
*microAvgF1Score = calcF1Score(*microAvgPrecision, *microAvgRecall);
info->macroAvgPrecision /= numLabels;
info->macroAvgRecall /= numLabels;
info->macroAvgF1Score =
calcF1Score(info->macroAvgPrecision, info->macroAvgRecall);
info->microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
info->microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
info->microAvgF1Score =
calcF1Score(info->microAvgPrecision, info->microAvgRecall);
return true;
}
......
......@@ -379,15 +379,19 @@ private:
IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_;
bool getStatsInfo(double* precision,
double* recall,
double* f1,
double* macroAvgPrecision,
double* macroAvgRecall,
double* macroAvgF1Score,
double* microAvgPrecision,
double* microAvgRecall,
double* microAvgF1Score) const;
struct PrintStatsInfo {
double precision;
double recall;
double f1;
double macroAvgPrecision;
double macroAvgRecall;
double macroAvgF1Score;
double microAvgPrecision;
double microAvgRecall;
double microAvgF1Score;
};
bool getStatsInfo(PrintStatsInfo* info) const;
void calcStatsInfo(const MatrixPtr& output,
const IVectorPtr& label,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册