提交 d5cfa6fc 编写于 作者: Y Yu Yang

Stash

上级 6089b7c6
...@@ -449,6 +449,21 @@ double AucEvaluator::calcAuc() const { ...@@ -449,6 +449,21 @@ double AucEvaluator::calcAuc() const {
} }
} }
real AucEvaluator::getValueImpl() const { return calcAuc(); }
std::string AucEvaluator::getTypeImpl() const {
if (colIdx_ == -1) {
return "last-column-auc";
} else {
return "auc";
}
}
static InitFunction __reg_type_auc__([]() {
Evaluator::registrar_.registerClass("last-column-auc",
[] { return new AucEvaluator(-1); });
});
// class RankAucEvaluator // class RankAucEvaluator
REGISTER_EVALUATOR(rankauc, RankAucEvaluator); REGISTER_EVALUATOR(rankauc, RankAucEvaluator);
...@@ -873,8 +888,6 @@ Evaluator* Evaluator::create(const EvaluatorConfig& config) { ...@@ -873,8 +888,6 @@ Evaluator* Evaluator::create(const EvaluatorConfig& config) {
evaluator = new SumEvaluator(); evaluator = new SumEvaluator();
} else if (config.type() == "last-column-sum") { } else if (config.type() == "last-column-sum") {
evaluator = new ColumnSumEvaluator(-1); evaluator = new ColumnSumEvaluator(-1);
} else if (config.type() == "last-column-auc") {
evaluator = new AucEvaluator(-1);
} else { } else {
evaluator = registrar_.createByType(config.type()); evaluator = registrar_.createByType(config.type());
} }
...@@ -1253,4 +1266,6 @@ public: ...@@ -1253,4 +1266,6 @@ public:
}; };
REGISTER_EVALUATOR(classification_error_printer, ClassificationErrorPrinter); REGISTER_EVALUATOR(classification_error_printer, ClassificationErrorPrinter);
std::string DummyEvaluator::getTypeImpl() const { return "dummy"; }
} // namespace paddle } // namespace paddle
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/parameter/Argument.h" #include "paddle/parameter/Argument.h"
#include "paddle/pserver/ParameterClient2.h" #include "paddle/pserver/ParameterClient2.h"
#include "paddle/utils/ClassRegistrar.h" #include "paddle/utils/ClassRegistrar.h"
#include "paddle/utils/Error.h"
namespace paddle { namespace paddle {
...@@ -117,6 +118,34 @@ public: ...@@ -117,6 +118,34 @@ public:
static ClassRegistrar<Evaluator> registrar_; static ClassRegistrar<Evaluator> registrar_;
virtual void getNames(std::vector<std::string>* names) {
names->clear();
names->push_back(config_.name());
}
virtual real getValue(const std::string& name,
paddle::Error* err = nullptr) const {
if (name != config_.name() && err != nullptr) {
*err = paddle::Error("no such name of evaluator %s", name.c_str());
return .0f;
}
return this->getValueImpl();
}
virtual std::string getType(const std::string& name,
paddle::Error* err = nullptr) const {
if (name != config_.name() && err != nullptr) {
*err = paddle::Error("no such name of evaluator %s", name.c_str());
return std::string();
}
return this->getTypeImpl();
}
protected:
virtual real getValueImpl() const { return .0f; }
virtual std::string getTypeImpl() const { return "base"; }
protected: protected:
EvaluatorConfig config_; EvaluatorConfig config_;
double numSamples_; double numSamples_;
...@@ -135,6 +164,10 @@ public: ...@@ -135,6 +164,10 @@ public:
} }
virtual void finish() {} virtual void finish() {}
virtual void printStats(std::ostream&) const {} virtual void printStats(std::ostream&) const {}
// Evaluator interface
protected:
std::string getTypeImpl() const;
}; };
/** /**
* @brief evaluate AUC using colIdx-th column as prediction. * @brief evaluate AUC using colIdx-th column as prediction.
...@@ -191,6 +224,11 @@ private: ...@@ -191,6 +224,11 @@ private:
} }
double calcAuc() const; double calcAuc() const;
// Evaluator interface
protected:
real getValueImpl() const;
std::string getTypeImpl() const;
}; };
/** /**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册