diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index ae7508e2bb117a60492e0c28230f2fbb4b14915e..4689222e3ade62f20bfb4fcac5edb3ba62f48cfd 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -449,6 +449,21 @@ double AucEvaluator::calcAuc() const { } } +real AucEvaluator::getValueImpl() const { return calcAuc(); } + +std::string AucEvaluator::getTypeImpl() const { + if (colIdx_ == -1) { + return "last-column-auc"; + } else { + return "auc"; + } +} + +static InitFunction __reg_type_auc__([]() { + Evaluator::registrar_.registerClass("last-column-auc", + [] { return new AucEvaluator(-1); }); +}); + // class RankAucEvaluator REGISTER_EVALUATOR(rankauc, RankAucEvaluator); @@ -873,8 +888,6 @@ Evaluator* Evaluator::create(const EvaluatorConfig& config) { evaluator = new SumEvaluator(); } else if (config.type() == "last-column-sum") { evaluator = new ColumnSumEvaluator(-1); - } else if (config.type() == "last-column-auc") { - evaluator = new AucEvaluator(-1); } else { evaluator = registrar_.createByType(config.type()); } @@ -1253,4 +1266,6 @@ public: }; REGISTER_EVALUATOR(classification_error_printer, ClassificationErrorPrinter); +std::string DummyEvaluator::getTypeImpl() const { return "dummy"; } + } // namespace paddle diff --git a/paddle/gserver/evaluators/Evaluator.h b/paddle/gserver/evaluators/Evaluator.h index 5770847309670ef1856cfb9255fa847c24513b56..bf08aa07f0ce3108cbf395f3d52ded1845bac9a0 100644 --- a/paddle/gserver/evaluators/Evaluator.h +++ b/paddle/gserver/evaluators/Evaluator.h @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/parameter/Argument.h" #include "paddle/pserver/ParameterClient2.h" #include "paddle/utils/ClassRegistrar.h" +#include "paddle/utils/Error.h" namespace paddle { @@ -117,6 +118,34 @@ public: static ClassRegistrar registrar_; + virtual void getNames(std::vector* names) { + names->clear(); + names->push_back(config_.name()); + } + + virtual real getValue(const std::string& name, + paddle::Error* err = nullptr) const { + if (name != config_.name() && err != nullptr) { + *err = paddle::Error("no such name of evaluator %s", name.c_str()); + return .0f; + } + return this->getValueImpl(); + } + + virtual std::string getType(const std::string& name, + paddle::Error* err = nullptr) const { + if (name != config_.name() && err != nullptr) { + *err = paddle::Error("no such name of evaluator %s", name.c_str()); + return std::string(); + } + return this->getTypeImpl(); + } + +protected: + virtual real getValueImpl() const { return .0f; } + + virtual std::string getTypeImpl() const { return "base"; } + protected: EvaluatorConfig config_; double numSamples_; @@ -135,6 +164,10 @@ public: } virtual void finish() {} virtual void printStats(std::ostream&) const {} + + // Evaluator interface +protected: + std::string getTypeImpl() const; }; /** * @brief evaluate AUC using colIdx-th column as prediction. @@ -191,6 +224,11 @@ private: } double calcAuc() const; + + // Evaluator interface +protected: + real getValueImpl() const; + std::string getTypeImpl() const; }; /**