diff --git a/paddle/gserver/evaluators/ChunkEvaluator.cpp b/paddle/gserver/evaluators/ChunkEvaluator.cpp index 13f02e51fe9e3831103982130bfdaa3255e1d174..b94a641b4a24d8bda214b375e6155c23a35af955 100644 --- a/paddle/gserver/evaluators/ChunkEvaluator.cpp +++ b/paddle/gserver/evaluators/ChunkEvaluator.cpp @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/math/Vector.h" +#include "paddle/utils/StringUtil.h" #include "Evaluator.h" @@ -121,11 +122,9 @@ public: } virtual void printStats(std::ostream& os) const { - double precision = (double)numCorrect_ / numOutputSegments_; - double recall = (double)numCorrect_ / numLabelSegments_; - double f1 = - !numCorrect_ ? 0 : 2 * precision * recall / (precision + recall); - os << config_.name() << "=" << f1 << " true_chunks=" << numLabelSegments_ + storeLocalValues(); + os << config_.name() << "=" << values_["F1-score"] + << " true_chunks=" << numLabelSegments_ << " result_chunks=" << numOutputSegments_ << " correct_chunks=" << numCorrect_; } @@ -243,6 +242,53 @@ public: if (tag == tagSingle_) return true; return false; } + +public: + // three metrics: precision, recall and F1-score + void getNames(std::vector* names) { + this->storeLocalValues(); + names->reserve(this->values_.size()); + for (auto it = this->values_.begin(); it != this->values_.end(); ++it) { + names->push_back(this->config_.name() + "." + it->first); + } + } + + // get value by field name + real getValue(const std::string& name, Error* err) const { + this->storeLocalValues(); + std::vector buffers; + paddle::str::split(name, '.', &buffers); + auto it = this->values_.find(buffers[buffers.size() - 1]); + if (it == this->values_.end()) { // not found + *err = Error("No such key %s", name.c_str()); + return 0.0f; + } + + return it->second; + } + + // get type of evaluator + std::string getType(const std::string& name, Error* err) const { + this->getValue(name, err); + if (!err->isOK()) { + return std::string(); + } + return "chunk"; + } + +private: + void storeLocalValues() const { + CHECK_GT(numOutputSegments_, 0); + CHECK_GT(numLabelSegments_, 0); + double precision = (double)numCorrect_ / numOutputSegments_; + double recall = (double)numCorrect_ / numLabelSegments_; + values_["precision"] = precision; + values_["recall"] = recall; + values_["F1-score"] = + !numCorrect_ ? 0 : 2 * precision * recall / (precision + recall); + } + + mutable std::unordered_map values_; }; REGISTER_EVALUATOR(chunk, ChunkEvaluator);