提交 6deb33d9 编写于 作者: Y Yu Yang

Complete combined evaluator

上级 9a484425
......@@ -102,6 +102,10 @@ public:
virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client);
}
// Evaluator interface
protected:
std::string getTypeImpl() const { return "classification_error"; }
};
/**
......@@ -140,6 +144,10 @@ public:
virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client);
}
// Evaluator interface
protected:
std::string getTypeImpl() const { return "seq_classification_error"; }
};
REGISTER_EVALUATOR(seq_classification_error,
SequenceClassificationErrorEvaluator);
......@@ -230,6 +238,10 @@ public:
private:
IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_;
// Evaluator interface
protected:
std::string getTypeImpl() const { return "sum"; }
};
/**
* @brief column sum Evaluator
......@@ -337,10 +349,18 @@ public:
}
private:
ColumnSumEvaluator() {}
int32_t colIdx_;
size_t colNum_;
MatrixPtr sum_; /* cpu matrix */
// Evaluator interface
protected:
std::string getTypeImpl() const {
if (colIdx_ == -1)
return "last-column-sum";
else
return "column-sum";
}
};
void AucEvaluator::start() {
......@@ -791,7 +811,6 @@ void PrecisionRecallEvaluator::storeLocalValues() const {
void PrecisionRecallEvaluator::getNames(std::vector<std::string>* names) {
this->storeLocalValues();
names->clear();
names->reserve(this->values_.size());
for (auto it = this->values_.begin(); it != this->values_.end(); ++it) {
names->push_back(this->config_.name() + "." + it->first);
......@@ -1080,12 +1099,13 @@ public:
}
};
REGISTER_EVALUATOR(value_printer, ValuePrinter);
/**
* @brief print gradient of each layer.
*
* The config file api is gradient_printer_evaluator.
*/
class GradientPrinter : public Evaluator {
class GradientPrinter : public NotGetableEvaluator {
public:
virtual void eval(const NeuralNetwork& nn) {
for (const std::string& name : config_.input_layers()) {
......@@ -1108,7 +1128,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter);
*
* The config file api is maxid_printer_evaluator.
*/
class MaxIdPrinter : public Evaluator {
class MaxIdPrinter : public NotGetableEvaluator {
private:
IVectorPtr maxIds_;
MatrixPtr maxValues_;
......@@ -1150,7 +1170,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter);
*
* The config file api is maxframe_printer_evaluator.
*/
class MaxFramePrinter : public Evaluator {
class MaxFramePrinter : public NotGetableEvaluator {
private:
IVectorPtr maxIds_;
MatrixPtr maxValues_;
......@@ -1237,7 +1257,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter);
* The config file api is seqtext_printer_evaluator.
*
*/
class SequenceTextPrinter : public Evaluator {
class SequenceTextPrinter : public NotGetableEvaluator {
private:
/// dict_file, which contains a list of tokens
std::vector<std::string> dict_;
......
......@@ -119,7 +119,6 @@ public:
static ClassRegistrar<Evaluator> registrar_;
virtual void getNames(std::vector<std::string>* names) {
names->clear();
names->push_back(config_.name());
}
......@@ -168,6 +167,25 @@ protected:
double totalScore_;
};
class NotGetableEvaluator : public Evaluator {
// Evaluator interface
public:
void getNames(std::vector<std::string>* names) {}
real getValue(const std::string& name, Error* err) const {
if (err != nullptr) {
*err = Error("Not implemented");
}
return .0f;
}
std::string getType(const std::string& name, Error* err) const {
if (err != nullptr) {
*err = Error("Not implemented");
}
return "";
}
};
class DummyEvaluator : public Evaluator {
public:
DummyEvaluator() {}
......
......@@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() {
class CombinedEvaluator : public Evaluator {
public:
CombinedEvaluator() {}
void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) {
evaluators_.emplace_back(std::move(evaluator));
}
......@@ -346,6 +345,50 @@ public:
protected:
std::vector<std::unique_ptr<Evaluator>> evaluators_;
// Evaluator interface
public:
void getNames(std::vector<std::string>* names) {
for (auto& eval : evaluators_) {
eval->getNames(names);
}
}
real getValue(const std::string& name, Error* err) const {
return this->getMethodHelper<real>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getValue(name, err);
});
}
std::string getValueStr(const std::string& name, Error* err) const {
return this->getMethodHelper<std::string>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getValueStr(name, err);
});
}
std::string getType(const std::string& name, Error* err) const {
return this->getMethodHelper<std::string>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getType(name, err);
});
}
private:
template <typename T>
T getMethodHelper(const std::string& name,
Error* err,
const std::function<T(const std::unique_ptr<Evaluator>&)>&
callback) const {
for (auto& eval : evaluators_) {
std::vector<std::string> names;
eval->getNames(&names);
if (std::find(names.begin(), names.end(), name) != names.end()) {
return callback(eval);
}
}
if (err != nullptr) *err = Error("No such key %s", name.c_str());
return T();
}
};
Evaluator* NeuralNetwork::makeEvaluator() const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册