提交 d5cfa6fc 编写于 作者: Y Yu Yang

Stash

上级 6089b7c6
......@@ -449,6 +449,21 @@ double AucEvaluator::calcAuc() const {
}
}
real AucEvaluator::getValueImpl() const { return calcAuc(); }
std::string AucEvaluator::getTypeImpl() const {
if (colIdx_ == -1) {
return "last-column-auc";
} else {
return "auc";
}
}
static InitFunction __reg_type_auc__([]() {
Evaluator::registrar_.registerClass("last-column-auc",
[] { return new AucEvaluator(-1); });
});
// class RankAucEvaluator
REGISTER_EVALUATOR(rankauc, RankAucEvaluator);
......@@ -873,8 +888,6 @@ Evaluator* Evaluator::create(const EvaluatorConfig& config) {
evaluator = new SumEvaluator();
} else if (config.type() == "last-column-sum") {
evaluator = new ColumnSumEvaluator(-1);
} else if (config.type() == "last-column-auc") {
evaluator = new AucEvaluator(-1);
} else {
evaluator = registrar_.createByType(config.type());
}
......@@ -1253,4 +1266,6 @@ public:
};
REGISTER_EVALUATOR(classification_error_printer, ClassificationErrorPrinter);
std::string DummyEvaluator::getTypeImpl() const { return "dummy"; }
} // namespace paddle
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/parameter/Argument.h"
#include "paddle/pserver/ParameterClient2.h"
#include "paddle/utils/ClassRegistrar.h"
#include "paddle/utils/Error.h"
namespace paddle {
......@@ -117,6 +118,34 @@ public:
static ClassRegistrar<Evaluator> registrar_;
virtual void getNames(std::vector<std::string>* names) {
names->clear();
names->push_back(config_.name());
}
virtual real getValue(const std::string& name,
paddle::Error* err = nullptr) const {
if (name != config_.name() && err != nullptr) {
*err = paddle::Error("no such name of evaluator %s", name.c_str());
return .0f;
}
return this->getValueImpl();
}
virtual std::string getType(const std::string& name,
paddle::Error* err = nullptr) const {
if (name != config_.name() && err != nullptr) {
*err = paddle::Error("no such name of evaluator %s", name.c_str());
return std::string();
}
return this->getTypeImpl();
}
protected:
virtual real getValueImpl() const { return .0f; }
virtual std::string getTypeImpl() const { return "base"; }
protected:
EvaluatorConfig config_;
double numSamples_;
......@@ -135,6 +164,10 @@ public:
}
virtual void finish() {}
virtual void printStats(std::ostream&) const {}
// Evaluator interface
protected:
std::string getTypeImpl() const;
};
/**
* @brief evaluate AUC using colIdx-th column as prediction.
......@@ -191,6 +224,11 @@ private:
}
double calcAuc() const;
// Evaluator interface
protected:
real getValueImpl() const;
std::string getTypeImpl() const;
};
/**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册