Commit 04eaf75c authored by Yu Yang

Add getValue to some evaluators.

Parent 39feacb0
@@ -538,12 +538,15 @@ double RankAucEvaluator::calcRankAuc(real* outputData,
                  : aucTmp / (clickSum * noClickSum);
 }

+std::string RankAucEvaluator::getTypeImpl() const { return "rankauc"; }
+
 // class PrecisionRecallEvaluator
 REGISTER_EVALUATOR(precision_recall, PrecisionRecallEvaluator);

 void PrecisionRecallEvaluator::start() {
   Evaluator::start();
   statsInfo_.clear();
+  values_.clear();
 }

 real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
@@ -603,7 +606,9 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
   return 0;
 }

-void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
+template <typename T1, typename T2>
+void PrecisionRecallEvaluator::printStatsHelper(T1 labelCallback,
+                                                T2 microAvgCallback) const {
   int label = config_.positive_label();
   if (label != -1) {
     CHECK(label >= 0 && label < (int)statsInfo_.size())
@@ -612,9 +617,7 @@ void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
     double precision =
         calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
     double recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
-    os << "positive_label=" << label << " precision=" << precision
-       << " recall=" << recall
-       << " F1-score=" << calcF1Score(precision, recall);
+    labelCallback(label, precision, recall, calcF1Score(precision, recall));
     return;
   }
@@ -636,21 +639,45 @@ void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
   macroAvgPrecision /= numLabels;
   macroAvgRecall /= numLabels;
   double macroAvgF1Score = calcF1Score(macroAvgPrecision, macroAvgRecall);
-  os << "macro-average-precision=" << macroAvgPrecision
-     << " macro-average-recall=" << macroAvgRecall
-     << " macro-average-F1-score=" << macroAvgF1Score;

   double microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
   double microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
   double microAvgF1Score = calcF1Score(microAvgPrecision, microAvgRecall);
-  if (!isMultiBinaryLabel_) {
-    // precision and recall are equal in this case
-    os << " micro-average-precision=" << microAvgPrecision;
-  } else {
-    os << " micro-average-precision=" << microAvgPrecision
-       << " micro-average-recall=" << microAvgRecall
-       << " micro-average-F1-score=" << microAvgF1Score;
-  }
+
+  microAvgCallback(macroAvgPrecision,
+                   macroAvgRecall,
+                   macroAvgF1Score,
+                   isMultiBinaryLabel_,
+                   microAvgPrecision,
+                   microAvgRecall,
+                   microAvgF1Score);
+}
+
+void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
+  this->printStatsHelper(
+      [&os](int label, double precision, double recall, double f1) {
+        os << "positive_label=" << label << " precision=" << precision
+           << " recall=" << recall << " F1-score=" << f1;
+      },
+      [&os](double macroAvgPrecision,
+            double macroAvgRecall,
+            double macroAvgF1Score,
+            bool isMultiBinaryLabel,
+            double microAvgPrecision,
+            double microAvgRecall,
+            double microAvgF1Score) {
+        os << "macro-average-precision=" << macroAvgPrecision
+           << " macro-average-recall=" << macroAvgRecall
+           << " macro-average-F1-score=" << macroAvgF1Score;
+        if (!isMultiBinaryLabel) {
+          // precision and recall are equal in this case
+          os << " micro-average-precision=" << microAvgPrecision;
+        } else {
+          os << " micro-average-precision=" << microAvgPrecision
+             << " micro-average-recall=" << microAvgRecall
+             << " micro-average-F1-score=" << microAvgF1Score;
+        }
+      });
 }

 void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output,
@@ -731,6 +758,69 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
   }
 }

+void PrecisionRecallEvaluator::storeLocalValues() const {
+  if (this->values_.size() == 0) {
+    this->printStatsHelper(
+        [this](int label, double precision, double recall, double f1) {
+          values_["positive_label"] = (double)label;
+          values_["precision"] = precision;
values_["recal"] = recall;
values_["F1-score"] = f1;
},
[this](double macroAvgPrecision,
double macroAvgRecall,
double macroAvgF1Score,
bool isMultiBinaryLabel,
double microAvgPrecision,
double microAvgRecall,
double microAvgF1Score) {
values_["macro-average-precision"] = macroAvgPrecision;
values_["macro-average-recall"] = macroAvgRecall;
values_["macro-average-F1-score"] = macroAvgF1Score;
if (!isMultiBinaryLabel) {
// precision and recall are equal in this case
values_["micro-average-precision"] = microAvgPrecision;
} else {
values_["micro-average-precision"] = microAvgPrecision;
values_["micro-average-recall"] = microAvgRecall;
values_["micro-average-F1-score"] = microAvgF1Score;
}
});
}
}
void PrecisionRecallEvaluator::getNames(std::vector<std::string>* names) {
this->storeLocalValues();
names->clear();
names->reserve(this->values_.size());
for (auto it = this->values_.begin(); it != this->values_.end(); ++it) {
names->push_back(this->config_.name() + "." + it->first);
}
}
+
+real PrecisionRecallEvaluator::getValue(const std::string& name,
+                                        Error* err) const {
+  this->storeLocalValues();
+  auto it = this->values_.find(name);
+  if (it == this->values_.end()) {
+    if (err != nullptr) {
+      *err = Error("No such key %s", name.c_str());
+    }
+    return .0f;
+  }
+  return it->second;
+}
+
+std::string PrecisionRecallEvaluator::getType(const std::string& name,
+                                              Error* err) const {
+  this->storeLocalValues();
+  auto it = this->values_.find(name);
+  if (it == this->values_.end()) {
+    if (err != nullptr) {
+      *err = Error("No such key %s", name.c_str());
+    }
+    return "";
+  }
+  return "precision_recall";
+}

 void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
   size_t size = 4 * statsInfo_.size();
   double* buf = new double[size];
@@ -874,6 +964,8 @@ void PnpairEvaluator::calc(std::vector<PredictionResult>& predictArray) {
             << " calc total special pair: " << special;
 }

+std::string PnpairEvaluator::getTypeImpl() const { return "pnpair"; }
+
 ClassRegistrar<Evaluator> Evaluator::registrar_;

 Evaluator* Evaluator::create(const EvaluatorConfig& config) {
   Evaluator* evaluator = registrar_.createByType(config.type());
@@ -901,27 +993,12 @@ public:
   virtual void eval(const NeuralNetwork& nn) {
     for (const std::string& name : config_.input_layers()) {
-      const Argument& argu = nn.getLayer(name)->getOutput();
-      if (argu.value) {
-        std::ostringstream os;
-        argu.value->print(os);
-        LOG(INFO) << "layer=" << name << " value matrix:\n" << os.str();
-      }
-      if (argu.ids) {
-        std::ostringstream os;
-        argu.ids->print(os, argu.ids->getSize());
-        LOG(INFO) << "layer=" << name << " ids vector:\n" << os.str();
-      }
-      if (auto startPos = argu.sequenceStartPositions) {
-        std::ostringstream os;
-        startPos->getVector(false)->print(os, startPos->getSize());
-        LOG(INFO) << "layer=" << name << " sequence pos vector:\n" << os.str();
-      }
-      if (auto subStartPos = argu.subSequenceStartPositions) {
-        std::ostringstream os;
-        subStartPos->getVector(false)->print(os, subStartPos->getSize());
-        LOG(INFO) << "layer=" << name << " sub-sequence pos vector:\n"
-                  << os.str();
-      }
+      std::vector<std::tuple<std::string, std::string>> out;
+      auto err = nn.getLayerOutputValue(name, &out);
+      err.check();
+      for (auto& each : out) {
+        LOG(INFO) << "layer=" << name << " " << std::get<0>(each) << ":\n"
+                  << std::get<1>(each);
+      }
     }
   }
......
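The core of this commit is a name-keyed metric interface (getNames / getValue / getType) layered on top of the existing printStats text output. The following is a minimal sketch, not part of the commit, of how a caller might use it; logPrecision is a hypothetical helper, Paddle and glog includes are omitted, and the bare key "precision" is one of the keys storeLocalValues() records when a positive label is configured. Names returned by getNames() are prefixed with the evaluator's configured name, as in getNames() above.

#include <string>
#include <vector>

// Sketch only: assumes `evaluator` points at a finished PrecisionRecallEvaluator.
void logPrecision(paddle::Evaluator* evaluator) {
  std::vector<std::string> names;
  evaluator->getNames(&names);  // entries look like "<evaluator name>.<metric key>"
  for (const auto& n : names) {
    LOG(INFO) << "available metric: " << n;
  }

  paddle::Error err;
  auto p = evaluator->getValue("precision", &err);  // bare key, see storeLocalValues()
  if (err.isOK()) {
    LOG(INFO) << "precision = " << p;
  } else {
    LOG(WARNING) << "metric not found";  // err carries "No such key precision"
  }
}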
@@ -132,6 +132,20 @@ public:
     return this->getValueImpl();
   }

+  virtual std::string getValueStr(const std::string& name,
+                                  paddle::Error* err = nullptr) const {
+    paddle::Error localErr;
+    if (err == nullptr) {
+      err = &localErr;
+    }
+    real result = this->getValue(name, err);
+    if (!err->isOK()) {
+      return "";
+    } else {
+      return std::to_string(result);
+    }
+  }
+
   virtual std::string getType(const std::string& name,
                               paddle::Error* err = nullptr) const {
     if (name != config_.name() && err != nullptr) {
@@ -142,7 +156,9 @@ public:
   }

 protected:
-  virtual real getValueImpl() const { return .0f; }
+  virtual real getValueImpl() const {
+    return numSamples_ != .0 ? totalScore_ / numSamples_ : .0;
+  }

   virtual std::string getTypeImpl() const { return "base"; }
@@ -261,6 +277,10 @@ private:
                        real* clickData,
                        real* pvData,
                        size_t size);
+
+  // Evaluator interface
+protected:
+  std::string getTypeImpl() const;
 };

 /**
  * @brief precision, recall and f1 score Evaluator
@@ -310,6 +330,9 @@ private:
   IVectorPtr cpuLabel_;
   MatrixPtr cpuWeight_;

+  template <typename T1, typename T2>
+  void printStatsHelper(T1 labelCallback, T2 microAvgCallback) const;
+
   void calcStatsInfo(const MatrixPtr& output,
                      const IVectorPtr& label,
                      const MatrixPtr& weight);
@@ -341,6 +364,15 @@ private:
       return 0;
     }
   }
+
+  mutable std::unordered_map<std::string, real> values_;
+  void storeLocalValues() const;
+
+  // Evaluator interface
+public:
+  void getNames(std::vector<std::string>* names);
+  real getValue(const std::string& name, Error* err) const;
+  std::string getType(const std::string& name, Error* err) const;
 };

 /*
@@ -387,8 +419,7 @@ public:
   virtual void finish() { calc(predictArray_); }

   virtual void printStats(std::ostream& os) const {
-    os << " pos/neg"
-       << "=" << pairArray_[0] / ((pairArray_[1] <= 0) ? 1.0 : pairArray_[1]);
+    os << " pos/neg=" << this->getValueImpl();
   }

   virtual void distributeEval(ParameterClient2* client) {
@@ -404,6 +435,13 @@ private:
   IVectorPtr cpuLabel_;
   IVectorPtr cpuInfo_;
   MatrixPtr cpuWeight_;
+
+  // Evaluator interface
+protected:
+  real getValueImpl() const {
+    return pairArray_[0] / ((pairArray_[1] <= 0) ? 1.0 : pairArray_[1]);
+  }
+  std::string getTypeImpl() const;
 };

 } // namespace paddle
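The getValueStr() convenience added in this header routes getValue() through a local Error when the caller does not supply one and maps a failed lookup to an empty string. A short usage sketch follows; the evaluator reference and the configured name "pnpair_eval" are illustrative, and Paddle/glog headers are omitted.

#include <string>

// Sketch only: assumes a finished PnpairEvaluator whose configured name is "pnpair_eval",
// so the base getValue() resolves it via the new getValueImpl() shown above.
void logPairRatio(const paddle::Evaluator& evaluator) {
  paddle::Error err;
  std::string text = evaluator.getValueStr("pnpair_eval", &err);
  if (!err.isOK()) {
    LOG(WARNING) << "unknown metric name: pnpair_eval";
    return;
  }
  LOG(INFO) << "pos/neg = " << text;  // std::to_string(getValue(...))
}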
@@ -405,4 +405,42 @@ NeuralNetwork* NeuralNetwork::newNeuralNetwork(const std::string& name,
   }
 }

+Error NeuralNetwork::getLayerOutputValue(
+    const std::string& layerName,
+    std::vector<std::tuple<std::string, std::string>>* out) const {
+  auto& layers = this->config_.layers();
+  auto it = std::find_if(
+      layers.begin(), layers.end(), [&layerName](const LayerConfig& conf) {
+        return conf.name() == layerName;
+      });
+  if (it == layers.end()) {
+    return Error("Cannot find layer %s", layerName.c_str());
+  }
+  auto& layer = this->getLayer(layerName);
+  out->reserve(4);
+  auto& argu = layer->getOutput();
+  if (argu.value) {
+    std::ostringstream os;
+    argu.value->print(os);
+    out->push_back({"value", os.str()});
+  }
+  if (argu.ids) {
+    std::ostringstream os;
+    argu.ids->print(os, argu.ids->getSize());
+    out->push_back({"ids", os.str()});
+  }
+  if (auto startPos = argu.sequenceStartPositions) {
+    std::ostringstream os;
+    startPos->getVector(false)->print(os, startPos->getSize());
+    out->push_back({"sequence pos", os.str()});
+  }
+  if (auto subStartPos = argu.subSequenceStartPositions) {
+    std::ostringstream os;
+    subStartPos->getVector(false)->print(os, subStartPos->getSize());
+    out->push_back({"sub-sequence pos", os.str()});
+  }
+  return Error();
+}
+
 } // namespace paddle
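getLayerOutputValue collects up to four printable views of a layer's output (value matrix, ids, sequence positions, sub-sequence positions) as (label, text) tuples and reports a missing layer through the returned Error. Besides the printer evaluator shown earlier, a caller could look like the sketch below; "__fc_layer_0__" is an illustrative layer name and Paddle/glog headers are omitted.

#include <string>
#include <tuple>
#include <vector>

// Sketch only.
void logLayerOutput(const paddle::NeuralNetwork& nn) {
  std::vector<std::tuple<std::string, std::string>> out;
  paddle::Error err = nn.getLayerOutputValue("__fc_layer_0__", &out);
  if (!err.isOK()) {
    LOG(ERROR) << "cannot find layer __fc_layer_0__";
    return;
  }
  for (auto& each : out) {
    // get<0> is the label ("value", "ids", "sequence pos", "sub-sequence pos");
    // get<1> is the text produced by the corresponding print() call.
    LOG(INFO) << std::get<0>(each) << ":\n" << std::get<1>(each);
  }
}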
@@ -128,6 +128,10 @@ public:
   static NeuralNetwork* newNeuralNetwork(const std::string& name = "",
                                          NeuralNetwork* rootNetwork = nullptr);

+  Error __must_check getLayerOutputValue(
+      const std::string& layerName,
+      std::vector<std::tuple<std::string, std::string>>* out) const;
+
 protected:
   /**
    * The constructor of NeuralNetwork.
......
@@ -116,6 +116,8 @@ public:
    */
   operator bool() const { return msg_ == nullptr; }

+  bool isOK() const { return *this; }
+
   /**
    * @brief check this status by glog.
    * @note It is a temp method used during cleaning Paddle code. It will be
......
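isOK() is a readable alias for the existing operator bool(), so status checks read naturally in conditions such as the getValueStr() implementation above. A minimal sketch, assuming only the printf-style Error constructor already used elsewhere in this commit; Paddle headers are omitted.

#include <cassert>

// Sketch only.
void errorIsOkExample() {
  paddle::Error ok;                        // default-constructed Error means success
  assert(ok.isOK());                       // equivalent to static_cast<bool>(ok)

  paddle::Error failed("something went wrong");
  if (!failed.isOK()) {
    // handle or log the failure here
  }
}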