提交 faa4da48 编写于 作者: C caoying03

fix ctc edit distance in v2 API.

上级 c1feb27f
...@@ -26,6 +26,7 @@ private: ...@@ -26,6 +26,7 @@ private:
int numTimes_, numClasses_, numSequences_, blank_; int numTimes_, numClasses_, numSequences_, blank_;
real deletions_, insertions_, substitutions_; real deletions_, insertions_, substitutions_;
int seqClassficationError_; int seqClassficationError_;
mutable std::unordered_map<std::string, real> evalResults_;
std::vector<int> path2String(const std::vector<int>& path) { std::vector<int> path2String(const std::vector<int>& path) {
std::vector<int> str; std::vector<int> str;
...@@ -183,6 +184,18 @@ private: ...@@ -183,6 +184,18 @@ private:
return stringAlignment(gtStr, recogStr); return stringAlignment(gtStr, recogStr);
} }
void storeLocalValues() const {
evalResults_["error"] = numSequences_ ? totalScore_ / numSequences_ : 0;
evalResults_["deletion_error"] =
numSequences_ ? deletions_ / numSequences_ : 0;
evalResults_["insertion_error"] =
numSequences_ ? insertions_ / numSequences_ : 0;
evalResults_["substitution_error"] =
numSequences_ ? substitutions_ / numSequences_ : 0;
evalResults_["sequence_error"] =
(real)seqClassficationError_ / numSequences_;
}
public: public:
CTCErrorEvaluator() CTCErrorEvaluator()
: numTimes_(0), : numTimes_(0),
...@@ -245,16 +258,12 @@ public: ...@@ -245,16 +258,12 @@ public:
} }
virtual void printStats(std::ostream& os) const { virtual void printStats(std::ostream& os) const {
os << config_.name() << "=" storeLocalValues();
<< (numSequences_ ? totalScore_ / numSequences_ : 0); os << config_.name() << "=" << evalResults_["error"];
os << " deletions error" os << " deletions error = " << evalResults_["deletion_error"];
<< "=" << (numSequences_ ? deletions_ / numSequences_ : 0); os << " insertions error = " << evalResults_["insertion_error"];
os << " insertions error" os << " substitution error = " << evalResults_["substitution_error"];
<< "=" << (numSequences_ ? insertions_ / numSequences_ : 0); os << " sequence error = " << evalResults_["sequence_error"];
os << " substitutions error"
<< "=" << (numSequences_ ? substitutions_ / numSequences_ : 0);
os << " sequences error"
<< "=" << (real)seqClassficationError_ / numSequences_;
} }
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
...@@ -272,6 +281,33 @@ public: ...@@ -272,6 +281,33 @@ public:
seqClassficationError_ = (int)buf[4]; seqClassficationError_ = (int)buf[4];
numSequences_ = (int)buf[5]; numSequences_ = (int)buf[5];
} }
void getNames(std::vector<std::string>* names) {
storeLocalValues();
names->reserve(names->size() + evalResults_.size());
for (auto it = evalResults_.begin(); it != evalResults_.end(); ++it) {
names->push_back(config_.name() + "." + it->first);
}
}
real getValue(const std::string& name, Error* err) const {
storeLocalValues();
const std::string delimiter(".");
std::string::size_type foundPos = name.find(delimiter, 0);
CHECK(foundPos != std::string::npos);
auto it = evalResults_.find(
name.substr(foundPos + delimiter.size(), name.length()));
if (it == evalResults_.end()) {
*err = Error("Evaluator does not have the key %s", name.c_str());
return 0.0f;
}
return it->second;
}
std::string getTypeImpl() const { return "ctc_edit_distance"; }
}; };
REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator); REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册