提交 77a15b67 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #3844 from lcy-seso/fix_ctc_evaluator

fix ctc edit distance evaluator can not print information in v2 API.
...@@ -14,18 +14,20 @@ limitations under the License. */ ...@@ -14,18 +14,20 @@ limitations under the License. */
#include "Evaluator.h" #include "Evaluator.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/utils/StringUtil.h"
namespace paddle { namespace paddle {
/** /**
* calculate sequence-to-sequence edit distance * calculate sequence-to-sequence edit distance
*/ */
class CTCErrorEvaluator : public NotGetableEvaluator { class CTCErrorEvaluator : public Evaluator {
private: private:
MatrixPtr outActivations_; MatrixPtr outActivations_;
int numTimes_, numClasses_, numSequences_, blank_; int numTimes_, numClasses_, numSequences_, blank_;
real deletions_, insertions_, substitutions_; real deletions_, insertions_, substitutions_;
int seqClassficationError_; int seqClassficationError_;
mutable std::unordered_map<std::string, real> evalResults_;
std::vector<int> path2String(const std::vector<int>& path) { std::vector<int> path2String(const std::vector<int>& path) {
std::vector<int> str; std::vector<int> str;
...@@ -183,6 +185,18 @@ private: ...@@ -183,6 +185,18 @@ private:
return stringAlignment(gtStr, recogStr); return stringAlignment(gtStr, recogStr);
} }
void storeLocalValues() const {
evalResults_["error"] = numSequences_ ? totalScore_ / numSequences_ : 0;
evalResults_["deletion_error"] =
numSequences_ ? deletions_ / numSequences_ : 0;
evalResults_["insertion_error"] =
numSequences_ ? insertions_ / numSequences_ : 0;
evalResults_["substitution_error"] =
numSequences_ ? substitutions_ / numSequences_ : 0;
evalResults_["sequence_error"] =
(real)seqClassficationError_ / numSequences_;
}
public: public:
CTCErrorEvaluator() CTCErrorEvaluator()
: numTimes_(0), : numTimes_(0),
...@@ -245,16 +259,12 @@ public: ...@@ -245,16 +259,12 @@ public:
} }
virtual void printStats(std::ostream& os) const { virtual void printStats(std::ostream& os) const {
os << config_.name() << "=" storeLocalValues();
<< (numSequences_ ? totalScore_ / numSequences_ : 0); os << config_.name() << " error = " << evalResults_["error"];
os << " deletions error" os << " deletions error = " << evalResults_["deletion_error"];
<< "=" << (numSequences_ ? deletions_ / numSequences_ : 0); os << " insertions error = " << evalResults_["insertion_error"];
os << " insertions error" os << " substitution error = " << evalResults_["substitution_error"];
<< "=" << (numSequences_ ? insertions_ / numSequences_ : 0); os << " sequence error = " << evalResults_["sequence_error"];
os << " substitutions error"
<< "=" << (numSequences_ ? substitutions_ / numSequences_ : 0);
os << " sequences error"
<< "=" << (real)seqClassficationError_ / numSequences_;
} }
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
...@@ -272,6 +282,37 @@ public: ...@@ -272,6 +282,37 @@ public:
seqClassficationError_ = (int)buf[4]; seqClassficationError_ = (int)buf[4];
numSequences_ = (int)buf[5]; numSequences_ = (int)buf[5];
} }
void getNames(std::vector<std::string>* names) {
storeLocalValues();
names->reserve(names->size() + evalResults_.size());
for (auto it = evalResults_.begin(); it != evalResults_.end(); ++it) {
names->push_back(config_.name() + "." + it->first);
}
}
real getValue(const std::string& name, Error* err) const {
storeLocalValues();
std::vector<std::string> buffers;
paddle::str::split(name, '.', &buffers);
auto it = evalResults_.find(buffers[buffers.size() - 1]);
if (it == evalResults_.end()) {
*err = Error("Evaluator does not have the key %s", name.c_str());
return 0.0f;
}
return it->second;
}
std::string getType(const std::string& name, Error* err) const {
this->getValue(name, err);
if (!err->isOK()) {
return "";
}
return "ctc_edit_distance";
}
}; };
REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator); REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator);
......
...@@ -268,7 +268,13 @@ public: ...@@ -268,7 +268,13 @@ public:
} }
// get type of evaluator // get type of evaluator
std::string getTypeImpl() const { return "chunk"; } std::string getType(const std::string& name, Error* err) const {
this->getValue(name, err);
if (!err->isOK()) {
return "";
}
return "chunk";
}
private: private:
void storeLocalValues() const { void storeLocalValues() const {
......
...@@ -211,6 +211,7 @@ public: ...@@ -211,6 +211,7 @@ public:
*err = Error("Not implemented"); *err = Error("Not implemented");
return .0f; return .0f;
} }
std::string getType(const std::string& name, Error* err) const { std::string getType(const std::string& name, Error* err) const {
*err = Error("Not implemented"); *err = Error("Not implemented");
return ""; return "";
...@@ -331,6 +332,7 @@ private: ...@@ -331,6 +332,7 @@ private:
protected: protected:
std::string getTypeImpl() const; std::string getTypeImpl() const;
}; };
/** /**
* @brief precision, recall and f1 score Evaluator * @brief precision, recall and f1 score Evaluator
* \f[ * \f[
...@@ -358,6 +360,12 @@ public: ...@@ -358,6 +360,12 @@ public:
virtual void distributeEval(ParameterClient2* client); virtual void distributeEval(ParameterClient2* client);
void getNames(std::vector<std::string>* names);
real getValue(const std::string& name, Error* err) const;
std::string getType(const std::string& name, Error* err) const;
struct StatsInfo { struct StatsInfo {
/// numbers of true positives /// numbers of true positives
double TP; double TP;
...@@ -428,11 +436,6 @@ private: ...@@ -428,11 +436,6 @@ private:
mutable std::unordered_map<std::string, real> values_; mutable std::unordered_map<std::string, real> values_;
void storeLocalValues() const; void storeLocalValues() const;
// Evaluator interface
public:
void getNames(std::vector<std::string>* names);
real getValue(const std::string& name, Error* err) const;
std::string getType(const std::string& name, Error* err) const;
}; };
/* /*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册