提交 8411a73e 编写于 作者: Y yangyaming

overload several virtual functions to make ChunkEvaluator output multiple metrics

上级 1ba82069
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/math/Vector.h" #include "paddle/math/Vector.h"
#include "paddle/utils/StringUtil.h"
#include "Evaluator.h" #include "Evaluator.h"
...@@ -121,11 +122,9 @@ public: ...@@ -121,11 +122,9 @@ public:
} }
virtual void printStats(std::ostream& os) const { virtual void printStats(std::ostream& os) const {
double precision = (double)numCorrect_ / numOutputSegments_; storeLocalValues();
double recall = (double)numCorrect_ / numLabelSegments_; os << config_.name() << "=" << values_["F1-score"]
double f1 = << " true_chunks=" << numLabelSegments_
!numCorrect_ ? 0 : 2 * precision * recall / (precision + recall);
os << config_.name() << "=" << f1 << " true_chunks=" << numLabelSegments_
<< " result_chunks=" << numOutputSegments_ << " result_chunks=" << numOutputSegments_
<< " correct_chunks=" << numCorrect_; << " correct_chunks=" << numCorrect_;
} }
...@@ -243,6 +242,53 @@ public: ...@@ -243,6 +242,53 @@ public:
if (tag == tagSingle_) return true; if (tag == tagSingle_) return true;
return false; return false;
} }
public:
// three metrics: precision, recall and F1-score
void getNames(std::vector<std::string>* names) {
this->storeLocalValues();
names->reserve(this->values_.size());
for (auto it = this->values_.begin(); it != this->values_.end(); ++it) {
names->push_back(this->config_.name() + "." + it->first);
}
}
// get value by field name
real getValue(const std::string& name, Error* err) const {
this->storeLocalValues();
std::vector<std::string> buffers;
paddle::str::split(name, '.', &buffers);
auto it = this->values_.find(buffers[buffers.size() - 1]);
if (it == this->values_.end()) { // not found
*err = Error("No such key %s", name.c_str());
return 0.0f;
}
return it->second;
}
// get type of evaluator
std::string getType(const std::string& name, Error* err) const {
this->getValue(name, err);
if (!err->isOK()) {
return std::string();
}
return "chunk";
}
private:
void storeLocalValues() const {
CHECK_GT(numOutputSegments_, 0);
CHECK_GT(numLabelSegments_, 0);
double precision = (double)numCorrect_ / numOutputSegments_;
double recall = (double)numCorrect_ / numLabelSegments_;
values_["precision"] = precision;
values_["recall"] = recall;
values_["F1-score"] =
!numCorrect_ ? 0 : 2 * precision * recall / (precision + recall);
}
mutable std::unordered_map<std::string, real> values_;
}; };
REGISTER_EVALUATOR(chunk, ChunkEvaluator); REGISTER_EVALUATOR(chunk, ChunkEvaluator);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册