diff --git a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp index 8e2dc020cd84883fdbae1e5f1e969345e7a7b445..928c77a088ff1c96c990310051a530293b6b88b7 100644 --- a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp +++ b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include "Evaluator.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h" +#include "paddle/utils/StringUtil.h" namespace paddle { @@ -259,7 +260,7 @@ public: virtual void printStats(std::ostream& os) const { storeLocalValues(); - os << config_.name() << "=" << evalResults_["error"]; + os << config_.name() << " error = " << evalResults_["error"]; os << " deletions error = " << evalResults_["deletion_error"]; os << " insertions error = " << evalResults_["insertion_error"]; os << " substitution error = " << evalResults_["substitution_error"]; @@ -293,12 +294,10 @@ public: real getValue(const std::string& name, Error* err) const { storeLocalValues(); - const std::string delimiter("."); - std::string::size_type foundPos = name.find(delimiter, 0); - CHECK(foundPos != std::string::npos); + std::vector buffers; + paddle::str::split(name, '.', &buffers); + auto it = evalResults_.find(buffers[buffers.size() - 1]); - auto it = evalResults_.find( - name.substr(foundPos + delimiter.size(), name.length())); if (it == evalResults_.end()) { *err = Error("Evaluator does not have the key %s", name.c_str()); return 0.0f; @@ -307,7 +306,11 @@ public: return it->second; } - std::string getTypeImpl() const { return "ctc_edit_distance"; } + std::string getType(const std::string& name, Error* err) const { + getValue(name, err); + if (!err->isOK()) return ""; + return "ctc_edit_distance"; + } }; REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator);