提交 6deb33d9 编写于 作者: Y Yu Yang

Complete combined evaluator

上级 9a484425
...@@ -102,6 +102,10 @@ public: ...@@ -102,6 +102,10 @@ public:
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client); mergeResultsOfAllClients(client);
} }
// Evaluator interface
protected:
std::string getTypeImpl() const { return "classification_error"; }
}; };
/** /**
...@@ -140,6 +144,10 @@ public: ...@@ -140,6 +144,10 @@ public:
virtual void distributeEval(ParameterClient2* client) { virtual void distributeEval(ParameterClient2* client) {
mergeResultsOfAllClients(client); mergeResultsOfAllClients(client);
} }
// Evaluator interface
protected:
std::string getTypeImpl() const { return "seq_classification_error"; }
}; };
REGISTER_EVALUATOR(seq_classification_error, REGISTER_EVALUATOR(seq_classification_error,
SequenceClassificationErrorEvaluator); SequenceClassificationErrorEvaluator);
...@@ -230,6 +238,10 @@ public: ...@@ -230,6 +238,10 @@ public:
private: private:
IVectorPtr cpuLabel_; IVectorPtr cpuLabel_;
MatrixPtr cpuWeight_; MatrixPtr cpuWeight_;
// Evaluator interface
protected:
std::string getTypeImpl() const { return "sum"; }
}; };
/** /**
* @brief column sum Evaluator * @brief column sum Evaluator
...@@ -337,10 +349,18 @@ public: ...@@ -337,10 +349,18 @@ public:
} }
private: private:
ColumnSumEvaluator() {}
int32_t colIdx_; int32_t colIdx_;
size_t colNum_; size_t colNum_;
MatrixPtr sum_; /* cpu matrix */ MatrixPtr sum_; /* cpu matrix */
// Evaluator interface
protected:
std::string getTypeImpl() const {
if (colIdx_ == -1)
return "last-column-sum";
else
return "column-sum";
}
}; };
void AucEvaluator::start() { void AucEvaluator::start() {
...@@ -791,7 +811,6 @@ void PrecisionRecallEvaluator::storeLocalValues() const { ...@@ -791,7 +811,6 @@ void PrecisionRecallEvaluator::storeLocalValues() const {
void PrecisionRecallEvaluator::getNames(std::vector<std::string>* names) { void PrecisionRecallEvaluator::getNames(std::vector<std::string>* names) {
this->storeLocalValues(); this->storeLocalValues();
names->clear();
names->reserve(this->values_.size()); names->reserve(this->values_.size());
for (auto it = this->values_.begin(); it != this->values_.end(); ++it) { for (auto it = this->values_.begin(); it != this->values_.end(); ++it) {
names->push_back(this->config_.name() + "." + it->first); names->push_back(this->config_.name() + "." + it->first);
...@@ -1080,12 +1099,13 @@ public: ...@@ -1080,12 +1099,13 @@ public:
} }
}; };
REGISTER_EVALUATOR(value_printer, ValuePrinter); REGISTER_EVALUATOR(value_printer, ValuePrinter);
/** /**
* @brief print gradient of each layer. * @brief print gradient of each layer.
* *
* The config file api is gradient_printer_evaluator. * The config file api is gradient_printer_evaluator.
*/ */
class GradientPrinter : public Evaluator { class GradientPrinter : public NotGetableEvaluator {
public: public:
virtual void eval(const NeuralNetwork& nn) { virtual void eval(const NeuralNetwork& nn) {
for (const std::string& name : config_.input_layers()) { for (const std::string& name : config_.input_layers()) {
...@@ -1108,7 +1128,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter); ...@@ -1108,7 +1128,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter);
* *
* The config file api is maxid_printer_evaluator. * The config file api is maxid_printer_evaluator.
*/ */
class MaxIdPrinter : public Evaluator { class MaxIdPrinter : public NotGetableEvaluator {
private: private:
IVectorPtr maxIds_; IVectorPtr maxIds_;
MatrixPtr maxValues_; MatrixPtr maxValues_;
...@@ -1150,7 +1170,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter); ...@@ -1150,7 +1170,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter);
* *
* The config file api is maxframe_printer_evaluator. * The config file api is maxframe_printer_evaluator.
*/ */
class MaxFramePrinter : public Evaluator { class MaxFramePrinter : public NotGetableEvaluator {
private: private:
IVectorPtr maxIds_; IVectorPtr maxIds_;
MatrixPtr maxValues_; MatrixPtr maxValues_;
...@@ -1237,7 +1257,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter); ...@@ -1237,7 +1257,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter);
* The config file api is seqtext_printer_evaluator. * The config file api is seqtext_printer_evaluator.
* *
*/ */
class SequenceTextPrinter : public Evaluator { class SequenceTextPrinter : public NotGetableEvaluator {
private: private:
/// dict_file, which contains a list of tokens /// dict_file, which contains a list of tokens
std::vector<std::string> dict_; std::vector<std::string> dict_;
......
...@@ -119,7 +119,6 @@ public: ...@@ -119,7 +119,6 @@ public:
static ClassRegistrar<Evaluator> registrar_; static ClassRegistrar<Evaluator> registrar_;
virtual void getNames(std::vector<std::string>* names) { virtual void getNames(std::vector<std::string>* names) {
names->clear();
names->push_back(config_.name()); names->push_back(config_.name());
} }
...@@ -168,6 +167,25 @@ protected: ...@@ -168,6 +167,25 @@ protected:
double totalScore_; double totalScore_;
}; };
class NotGetableEvaluator : public Evaluator {
// Evaluator interface
public:
void getNames(std::vector<std::string>* names) {}
real getValue(const std::string& name, Error* err) const {
if (err != nullptr) {
*err = Error("Not implemented");
}
return .0f;
}
std::string getType(const std::string& name, Error* err) const {
if (err != nullptr) {
*err = Error("Not implemented");
}
return "";
}
};
class DummyEvaluator : public Evaluator { class DummyEvaluator : public Evaluator {
public: public:
DummyEvaluator() {} DummyEvaluator() {}
......
...@@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() { ...@@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() {
class CombinedEvaluator : public Evaluator { class CombinedEvaluator : public Evaluator {
public: public:
CombinedEvaluator() {}
void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) { void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) {
evaluators_.emplace_back(std::move(evaluator)); evaluators_.emplace_back(std::move(evaluator));
} }
...@@ -346,6 +345,50 @@ public: ...@@ -346,6 +345,50 @@ public:
protected: protected:
std::vector<std::unique_ptr<Evaluator>> evaluators_; std::vector<std::unique_ptr<Evaluator>> evaluators_;
// Evaluator interface
public:
void getNames(std::vector<std::string>* names) {
for (auto& eval : evaluators_) {
eval->getNames(names);
}
}
real getValue(const std::string& name, Error* err) const {
return this->getMethodHelper<real>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getValue(name, err);
});
}
std::string getValueStr(const std::string& name, Error* err) const {
return this->getMethodHelper<std::string>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getValueStr(name, err);
});
}
std::string getType(const std::string& name, Error* err) const {
return this->getMethodHelper<std::string>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getType(name, err);
});
}
private:
template <typename T>
T getMethodHelper(const std::string& name,
Error* err,
const std::function<T(const std::unique_ptr<Evaluator>&)>&
callback) const {
for (auto& eval : evaluators_) {
std::vector<std::string> names;
eval->getNames(&names);
if (std::find(names.begin(), names.end(), name) != names.end()) {
return callback(eval);
}
}
if (err != nullptr) *err = Error("No such key %s", name.c_str());
return T();
}
}; };
Evaluator* NeuralNetwork::makeEvaluator() const { Evaluator* NeuralNetwork::makeEvaluator() const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册