From d5cfa6fcce622982624274e4e86a6670b29b634d Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 19 Feb 2017 15:17:40 +0800 Subject: [PATCH] Stash --- paddle/gserver/evaluators/Evaluator.cpp | 19 +++++++++++-- paddle/gserver/evaluators/Evaluator.h | 38 +++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index ae7508e2bb1..4689222e3ad 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -449,6 +449,21 @@ double AucEvaluator::calcAuc() const { } } +real AucEvaluator::getValueImpl() const { return calcAuc(); } + +std::string AucEvaluator::getTypeImpl() const { + if (colIdx_ == -1) { + return "last-column-auc"; + } else { + return "auc"; + } +} + +static InitFunction __reg_type_auc__([]() { + Evaluator::registrar_.registerClass("last-column-auc", + [] { return new AucEvaluator(-1); }); +}); + // class RankAucEvaluator REGISTER_EVALUATOR(rankauc, RankAucEvaluator); @@ -873,8 +888,6 @@ Evaluator* Evaluator::create(const EvaluatorConfig& config) { evaluator = new SumEvaluator(); } else if (config.type() == "last-column-sum") { evaluator = new ColumnSumEvaluator(-1); - } else if (config.type() == "last-column-auc") { - evaluator = new AucEvaluator(-1); } else { evaluator = registrar_.createByType(config.type()); } @@ -1253,4 +1266,6 @@ public: }; REGISTER_EVALUATOR(classification_error_printer, ClassificationErrorPrinter); +std::string DummyEvaluator::getTypeImpl() const { return "dummy"; } + } // namespace paddle diff --git a/paddle/gserver/evaluators/Evaluator.h b/paddle/gserver/evaluators/Evaluator.h index 57708473096..bf08aa07f0c 100644 --- a/paddle/gserver/evaluators/Evaluator.h +++ b/paddle/gserver/evaluators/Evaluator.h @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/parameter/Argument.h" #include "paddle/pserver/ParameterClient2.h" #include "paddle/utils/ClassRegistrar.h" +#include "paddle/utils/Error.h" namespace paddle { @@ -117,6 +118,34 @@ public: static ClassRegistrar registrar_; + virtual void getNames(std::vector* names) { + names->clear(); + names->push_back(config_.name()); + } + + virtual real getValue(const std::string& name, + paddle::Error* err = nullptr) const { + if (name != config_.name() && err != nullptr) { + *err = paddle::Error("no such name of evaluator %s", name.c_str()); + return .0f; + } + return this->getValueImpl(); + } + + virtual std::string getType(const std::string& name, + paddle::Error* err = nullptr) const { + if (name != config_.name() && err != nullptr) { + *err = paddle::Error("no such name of evaluator %s", name.c_str()); + return std::string(); + } + return this->getTypeImpl(); + } + +protected: + virtual real getValueImpl() const { return .0f; } + + virtual std::string getTypeImpl() const { return "base"; } + protected: EvaluatorConfig config_; double numSamples_; @@ -135,6 +164,10 @@ public: } virtual void finish() {} virtual void printStats(std::ostream&) const {} + + // Evaluator interface +protected: + std::string getTypeImpl() const; }; /** * @brief evaluate AUC using colIdx-th column as prediction. @@ -191,6 +224,11 @@ private: } double calcAuc() const; + + // Evaluator interface +protected: + real getValueImpl() const; + std::string getTypeImpl() const; }; /** -- GitLab