提交 9087e3a9 编写于 作者: Y Yu Yang

Follow comments to use struct get return value.

上级 e67aac1b
...@@ -647,33 +647,24 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) { ...@@ -647,33 +647,24 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
} }
void PrecisionRecallEvaluator::printStats(std::ostream& os) const { void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
double precision, recall, f1, macroAvgPrecision, macroAvgRecall, PrintStatsInfo info;
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score; bool containMacroMicroInfo = getStatsInfo(&info);
bool containMacroMicroInfo = getStatsInfo(&precision,
&recall,
&f1,
&macroAvgPrecision,
&macroAvgRecall,
&macroAvgF1Score,
&microAvgPrecision,
&microAvgRecall,
&microAvgF1Score);
os << "positive_label=" << config_.positive_label() os << "positive_label=" << config_.positive_label()
<< " precision=" << precision << " recall=" << recall << " precision=" << info.precision << " recall=" << info.recall
<< " F1-score=" << f1; << " F1-score=" << info.f1;
if (containMacroMicroInfo) { if (containMacroMicroInfo) {
os << "macro-average-precision=" << macroAvgPrecision os << "macro-average-precision=" << info.macroAvgPrecision
<< " macro-average-recall=" << macroAvgRecall << " macro-average-recall=" << info.macroAvgRecall
<< " macro-average-F1-score=" << macroAvgF1Score; << " macro-average-F1-score=" << info.macroAvgF1Score;
if (!isMultiBinaryLabel_) { if (!isMultiBinaryLabel_) {
// precision and recall are equal in this case // precision and recall are equal in this case
os << " micro-average-precision=" << microAvgPrecision; os << " micro-average-precision=" << info.microAvgPrecision;
} else { } else {
os << " micro-average-precision=" << microAvgPrecision os << " micro-average-precision=" << info.microAvgPrecision
<< " micro-average-recall=" << microAvgRecall << " micro-average-recall=" << info.microAvgRecall
<< " micro-average-F1-score=" << microAvgF1Score; << " micro-average-F1-score=" << info.microAvgF1Score;
} }
}; }
} }
void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output, void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output,
...@@ -756,31 +747,22 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output, ...@@ -756,31 +747,22 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
void PrecisionRecallEvaluator::storeLocalValues() const { void PrecisionRecallEvaluator::storeLocalValues() const {
if (this->values_.size() == 0) { if (this->values_.size() == 0) {
double precision, recall, f1, macroAvgPrecision, macroAvgRecall, PrintStatsInfo info;
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score; bool containMacroMicroInfo = getStatsInfo(&info);
bool containMacroMicroInfo = getStatsInfo(&precision, values_["precision"] = info.precision;
&recall, values_["recal"] = info.recall;
&f1, values_["F1-score"] = info.f1;
&macroAvgPrecision,
&macroAvgRecall,
&macroAvgF1Score,
&microAvgPrecision,
&microAvgRecall,
&microAvgF1Score);
values_["precision"] = precision;
values_["recal"] = recall;
values_["F1-score"] = f1;
if (containMacroMicroInfo) { if (containMacroMicroInfo) {
values_["macro-average-precision"] = macroAvgPrecision; values_["macro-average-precision"] = info.macroAvgPrecision;
values_["macro-average-recall"] = macroAvgRecall; values_["macro-average-recall"] = info.macroAvgRecall;
values_["macro-average-F1-score"] = macroAvgF1Score; values_["macro-average-F1-score"] = info.macroAvgF1Score;
if (!isMultiBinaryLabel_) { if (!isMultiBinaryLabel_) {
// precision and recall are equal in this case // precision and recall are equal in this case
values_["micro-average-precision"] = microAvgPrecision; values_["micro-average-precision"] = info.microAvgPrecision;
} else { } else {
values_["micro-average-precision"] = microAvgPrecision; values_["micro-average-precision"] = info.microAvgPrecision;
values_["micro-average-recall"] = microAvgRecall; values_["micro-average-recall"] = info.microAvgRecall;
values_["micro-average-F1-score"] = microAvgF1Score; values_["micro-average-F1-score"] = info.microAvgF1Score;
} }
} }
} }
...@@ -836,23 +818,16 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) { ...@@ -836,23 +818,16 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
delete[] buf; delete[] buf;
} }
bool PrecisionRecallEvaluator::getStatsInfo(double* precision, bool PrecisionRecallEvaluator::getStatsInfo(
double* recall, PrecisionRecallEvaluator::PrintStatsInfo* info) const {
double* f1,
double* macroAvgPrecision,
double* macroAvgRecall,
double* macroAvgF1Score,
double* microAvgPrecision,
double* microAvgRecall,
double* microAvgF1Score) const {
int label = config_.positive_label(); int label = config_.positive_label();
if (label != -1) { if (label != -1) {
CHECK(label >= 0 && label < (int)statsInfo_.size()) CHECK(label >= 0 && label < (int)statsInfo_.size())
<< "positive_label [" << label << "] should be in range [0, " << "positive_label [" << label << "] should be in range [0, "
<< statsInfo_.size() << ")"; << statsInfo_.size() << ")";
*precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP); info->precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
*recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN); info->recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
*f1 = calcF1Score(*precision, *recall); info->f1 = calcF1Score(info->precision, info->recall);
return false; return false;
} }
...@@ -861,23 +836,26 @@ bool PrecisionRecallEvaluator::getStatsInfo(double* precision, ...@@ -861,23 +836,26 @@ bool PrecisionRecallEvaluator::getStatsInfo(double* precision,
double microTotalTP = 0; double microTotalTP = 0;
double microTotalFP = 0; double microTotalFP = 0;
double microTotalFN = 0; double microTotalFN = 0;
*macroAvgPrecision = 0; info->macroAvgPrecision = 0;
*macroAvgRecall = 0; info->macroAvgRecall = 0;
size_t numLabels = statsInfo_.size(); size_t numLabels = statsInfo_.size();
for (size_t i = 0; i < numLabels; ++i) { for (size_t i = 0; i < numLabels; ++i) {
microTotalTP += statsInfo_[i].TP; microTotalTP += statsInfo_[i].TP;
microTotalFP += statsInfo_[i].FP; microTotalFP += statsInfo_[i].FP;
microTotalFN += statsInfo_[i].FN; microTotalFN += statsInfo_[i].FN;
*macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP); info->macroAvgPrecision +=
*macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN); calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
info->macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
} }
*macroAvgPrecision /= numLabels; info->macroAvgPrecision /= numLabels;
*macroAvgRecall /= numLabels; info->macroAvgRecall /= numLabels;
*macroAvgF1Score = calcF1Score(*macroAvgPrecision, *macroAvgRecall); info->macroAvgF1Score =
calcF1Score(info->macroAvgPrecision, info->macroAvgRecall);
*microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
*microAvgRecall = calcPrecision(microTotalTP, microTotalFN); info->microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
*microAvgF1Score = calcF1Score(*microAvgPrecision, *microAvgRecall); info->microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
info->microAvgF1Score =
calcF1Score(info->microAvgPrecision, info->microAvgRecall);
return true; return true;
} }
......
...@@ -379,15 +379,19 @@ private: ...@@ -379,15 +379,19 @@ private:
IVectorPtr cpuLabel_; IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_; MatrixPtr cpuWeight_;
bool getStatsInfo(double* precision, struct PrintStatsInfo {
double* recall, double precision;
double* f1, double recall;
double* macroAvgPrecision, double f1;
double* macroAvgRecall, double macroAvgPrecision;
double* macroAvgF1Score, double macroAvgRecall;
double* microAvgPrecision, double macroAvgF1Score;
double* microAvgRecall, double microAvgPrecision;
double* microAvgF1Score) const; double microAvgRecall;
double microAvgF1Score;
};
bool getStatsInfo(PrintStatsInfo* info) const;
void calcStatsInfo(const MatrixPtr& output, void calcStatsInfo(const MatrixPtr& output,
const IVectorPtr& label, const IVectorPtr& label,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册