提交 a523bea8 编写于 作者: C caoying03

fix getType.

上级 0b478e99
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
/** /**
* calculate sequence-to-sequence edit distance * calculate sequence-to-sequence edit distance
*/ */
class CTCErrorEvaluator : public NotGetableEvaluator { class CTCErrorEvaluator : public Evaluator {
private: private:
MatrixPtr outActivations_; MatrixPtr outActivations_;
int numTimes_, numClasses_, numSequences_, blank_; int numTimes_, numClasses_, numSequences_, blank_;
...@@ -307,8 +307,10 @@ public: ...@@ -307,8 +307,10 @@ public:
} }
std::string getType(const std::string& name, Error* err) const { std::string getType(const std::string& name, Error* err) const {
getValue(name, err); this->getValue(name, err);
if (!err->isOK()) return ""; if (!err->isOK()) {
return "";
}
return "ctc_edit_distance"; return "ctc_edit_distance";
} }
}; };
......
...@@ -268,7 +268,13 @@ public: ...@@ -268,7 +268,13 @@ public:
} }
// get type of evaluator // get type of evaluator
std::string getTypeImpl() const { return "chunk"; } std::string getType(const std::string& name, Error* err) const {
this->getValue(name, err);
if (!err->isOK()) {
return "";
}
return "chunk";
}
private: private:
void storeLocalValues() const { void storeLocalValues() const {
......
...@@ -211,6 +211,7 @@ public: ...@@ -211,6 +211,7 @@ public:
*err = Error("Not implemented"); *err = Error("Not implemented");
return .0f; return .0f;
} }
std::string getType(const std::string& name, Error* err) const { std::string getType(const std::string& name, Error* err) const {
*err = Error("Not implemented"); *err = Error("Not implemented");
return ""; return "";
...@@ -331,6 +332,7 @@ private: ...@@ -331,6 +332,7 @@ private:
protected: protected:
std::string getTypeImpl() const; std::string getTypeImpl() const;
}; };
/** /**
* @brief precision, recall and f1 score Evaluator * @brief precision, recall and f1 score Evaluator
* \f[ * \f[
...@@ -358,6 +360,12 @@ public: ...@@ -358,6 +360,12 @@ public:
virtual void distributeEval(ParameterClient2* client); virtual void distributeEval(ParameterClient2* client);
void getNames(std::vector<std::string>* names);
real getValue(const std::string& name, Error* err) const;
std::string getType(const std::string& name, Error* err) const;
struct StatsInfo { struct StatsInfo {
/// numbers of true positives /// numbers of true positives
double TP; double TP;
...@@ -428,11 +436,6 @@ private: ...@@ -428,11 +436,6 @@ private:
mutable std::unordered_map<std::string, real> values_; mutable std::unordered_map<std::string, real> values_;
void storeLocalValues() const; void storeLocalValues() const;
// Evaluator interface
public:
void getNames(std::vector<std::string>* names);
real getValue(const std::string& name, Error* err) const;
std::string getType(const std::string& name, Error* err) const;
}; };
/* /*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册