From faa4da4835e0622387e2ab1f6481875c5243c5c9 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Mon, 4 Sep 2017 13:02:27 +0800 Subject: [PATCH] fix ctc edit distance in v2 API. --- .../gserver/evaluators/CTCErrorEvaluator.cpp | 56 +++++++++++++++---- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp index 132119015f..8e2dc020cd 100644 --- a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp +++ b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp @@ -26,6 +26,7 @@ private: int numTimes_, numClasses_, numSequences_, blank_; real deletions_, insertions_, substitutions_; int seqClassficationError_; + mutable std::unordered_map evalResults_; std::vector path2String(const std::vector& path) { std::vector str; @@ -183,6 +184,18 @@ private: return stringAlignment(gtStr, recogStr); } + void storeLocalValues() const { + evalResults_["error"] = numSequences_ ? totalScore_ / numSequences_ : 0; + evalResults_["deletion_error"] = + numSequences_ ? deletions_ / numSequences_ : 0; + evalResults_["insertion_error"] = + numSequences_ ? insertions_ / numSequences_ : 0; + evalResults_["substitution_error"] = + numSequences_ ? substitutions_ / numSequences_ : 0; + evalResults_["sequence_error"] = + (real)seqClassficationError_ / numSequences_; + } + public: CTCErrorEvaluator() : numTimes_(0), @@ -245,16 +258,12 @@ public: } virtual void printStats(std::ostream& os) const { - os << config_.name() << "=" - << (numSequences_ ? totalScore_ / numSequences_ : 0); - os << " deletions error" - << "=" << (numSequences_ ? deletions_ / numSequences_ : 0); - os << " insertions error" - << "=" << (numSequences_ ? insertions_ / numSequences_ : 0); - os << " substitutions error" - << "=" << (numSequences_ ? substitutions_ / numSequences_ : 0); - os << " sequences error" - << "=" << (real)seqClassficationError_ / numSequences_; + storeLocalValues(); + os << config_.name() << "=" << evalResults_["error"]; + os << " deletions error = " << evalResults_["deletion_error"]; + os << " insertions error = " << evalResults_["insertion_error"]; + os << " substitution error = " << evalResults_["substitution_error"]; + os << " sequence error = " << evalResults_["sequence_error"]; } virtual void distributeEval(ParameterClient2* client) { @@ -272,6 +281,33 @@ public: seqClassficationError_ = (int)buf[4]; numSequences_ = (int)buf[5]; } + + void getNames(std::vector* names) { + storeLocalValues(); + names->reserve(names->size() + evalResults_.size()); + for (auto it = evalResults_.begin(); it != evalResults_.end(); ++it) { + names->push_back(config_.name() + "." + it->first); + } + } + + real getValue(const std::string& name, Error* err) const { + storeLocalValues(); + + const std::string delimiter("."); + std::string::size_type foundPos = name.find(delimiter, 0); + CHECK(foundPos != std::string::npos); + + auto it = evalResults_.find( + name.substr(foundPos + delimiter.size(), name.length())); + if (it == evalResults_.end()) { + *err = Error("Evaluator does not have the key %s", name.c_str()); + return 0.0f; + } + + return it->second; + } + + std::string getTypeImpl() const { return "ctc_edit_distance"; } }; REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator); -- GitLab