From a523bea8e585bd63f4167e012a05b03ad435b574 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Mon, 4 Sep 2017 20:40:33 +0800 Subject: [PATCH] fix getType. --- paddle/gserver/evaluators/CTCErrorEvaluator.cpp | 8 +++++--- paddle/gserver/evaluators/ChunkEvaluator.cpp | 8 +++++++- paddle/gserver/evaluators/Evaluator.h | 13 ++++++++----- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp index 928c77a088f..92087fa32b1 100644 --- a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp +++ b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp @@ -21,7 +21,7 @@ namespace paddle { /** * calculate sequence-to-sequence edit distance */ -class CTCErrorEvaluator : public NotGetableEvaluator { +class CTCErrorEvaluator : public Evaluator { private: MatrixPtr outActivations_; int numTimes_, numClasses_, numSequences_, blank_; @@ -307,8 +307,10 @@ public: } std::string getType(const std::string& name, Error* err) const { - getValue(name, err); - if (!err->isOK()) return ""; + this->getValue(name, err); + if (!err->isOK()) { + return ""; + } return "ctc_edit_distance"; } }; diff --git a/paddle/gserver/evaluators/ChunkEvaluator.cpp b/paddle/gserver/evaluators/ChunkEvaluator.cpp index 1658282f3a5..a2ab15eedee 100644 --- a/paddle/gserver/evaluators/ChunkEvaluator.cpp +++ b/paddle/gserver/evaluators/ChunkEvaluator.cpp @@ -268,7 +268,13 @@ public: } // get type of evaluator - std::string getTypeImpl() const { return "chunk"; } + std::string getType(const std::string& name, Error* err) const { + this->getValue(name, err); + if (!err->isOK()) { + return ""; + } + return "chunk"; + } private: void storeLocalValues() const { diff --git a/paddle/gserver/evaluators/Evaluator.h b/paddle/gserver/evaluators/Evaluator.h index b114500e2b7..90203553e0a 100644 --- a/paddle/gserver/evaluators/Evaluator.h +++ b/paddle/gserver/evaluators/Evaluator.h @@ -211,6 +211,7 @@ public: *err = Error("Not implemented"); return .0f; } + std::string getType(const std::string& name, Error* err) const { *err = Error("Not implemented"); return ""; @@ -331,6 +332,7 @@ private: protected: std::string getTypeImpl() const; }; + /** * @brief precision, recall and f1 score Evaluator * \f[ @@ -358,6 +360,12 @@ public: virtual void distributeEval(ParameterClient2* client); + void getNames(std::vector* names); + + real getValue(const std::string& name, Error* err) const; + + std::string getType(const std::string& name, Error* err) const; + struct StatsInfo { /// numbers of true positives double TP; @@ -428,11 +436,6 @@ private: mutable std::unordered_map values_; void storeLocalValues() const; - // Evaluator interface -public: - void getNames(std::vector* names); - real getValue(const std::string& name, Error* err) const; - std::string getType(const std::string& name, Error* err) const; }; /* -- GitLab