diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index cde47f572092f5523b406e1a8f44246ba566b8fd..29b5284fe59b5d6e111b33354bb2142078a8fc03 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -102,6 +102,10 @@ public: virtual void distributeEval(ParameterClient2* client) { mergeResultsOfAllClients(client); } + + // Evaluator interface +protected: + std::string getTypeImpl() const { return "classification_error"; } }; /** @@ -140,6 +144,10 @@ public: virtual void distributeEval(ParameterClient2* client) { mergeResultsOfAllClients(client); } + + // Evaluator interface +protected: + std::string getTypeImpl() const { return "seq_classification_error"; } }; REGISTER_EVALUATOR(seq_classification_error, SequenceClassificationErrorEvaluator); @@ -230,6 +238,10 @@ public: private: IVectorPtr cpuLabel_; MatrixPtr cpuWeight_; + + // Evaluator interface +protected: + std::string getTypeImpl() const { return "sum"; } }; /** * @brief column sum Evaluator @@ -337,10 +349,18 @@ public: } private: - ColumnSumEvaluator() {} int32_t colIdx_; size_t colNum_; MatrixPtr sum_; /* cpu matrix */ + + // Evaluator interface +protected: + std::string getTypeImpl() const { + if (colIdx_ == -1) + return "last-column-sum"; + else + return "column-sum"; + } }; void AucEvaluator::start() { @@ -791,7 +811,6 @@ void PrecisionRecallEvaluator::storeLocalValues() const { void PrecisionRecallEvaluator::getNames(std::vector* names) { this->storeLocalValues(); - names->clear(); names->reserve(this->values_.size()); for (auto it = this->values_.begin(); it != this->values_.end(); ++it) { names->push_back(this->config_.name() + "." + it->first); @@ -1080,12 +1099,13 @@ public: } }; REGISTER_EVALUATOR(value_printer, ValuePrinter); + /** * @brief print gradient of each layer. * * The config file api is gradient_printer_evaluator. */ -class GradientPrinter : public Evaluator { +class GradientPrinter : public NotGetableEvaluator { public: virtual void eval(const NeuralNetwork& nn) { for (const std::string& name : config_.input_layers()) { @@ -1108,7 +1128,7 @@ REGISTER_EVALUATOR(gradient_printer, GradientPrinter); * * The config file api is maxid_printer_evaluator. */ -class MaxIdPrinter : public Evaluator { +class MaxIdPrinter : public NotGetableEvaluator { private: IVectorPtr maxIds_; MatrixPtr maxValues_; @@ -1150,7 +1170,7 @@ REGISTER_EVALUATOR(max_id_printer, MaxIdPrinter); * * The config file api is maxframe_printer_evaluator. */ -class MaxFramePrinter : public Evaluator { +class MaxFramePrinter : public NotGetableEvaluator { private: IVectorPtr maxIds_; MatrixPtr maxValues_; @@ -1237,7 +1257,7 @@ REGISTER_EVALUATOR(max_frame_printer, MaxFramePrinter); * The config file api is seqtext_printer_evaluator. * */ -class SequenceTextPrinter : public Evaluator { +class SequenceTextPrinter : public NotGetableEvaluator { private: /// dict_file, which contains a list of tokens std::vector dict_; diff --git a/paddle/gserver/evaluators/Evaluator.h b/paddle/gserver/evaluators/Evaluator.h index bec3d6e66820fa37168e975f053047474b280838..d5b2a63e35e15752e3ba918dde273671b4c3c0d1 100644 --- a/paddle/gserver/evaluators/Evaluator.h +++ b/paddle/gserver/evaluators/Evaluator.h @@ -119,7 +119,6 @@ public: static ClassRegistrar registrar_; virtual void getNames(std::vector* names) { - names->clear(); names->push_back(config_.name()); } @@ -168,6 +167,25 @@ protected: double totalScore_; }; +class NotGetableEvaluator : public Evaluator { + // Evaluator interface +public: + void getNames(std::vector* names) {} + + real getValue(const std::string& name, Error* err) const { + if (err != nullptr) { + *err = Error("Not implemented"); + } + return .0f; + } + std::string getType(const std::string& name, Error* err) const { + if (err != nullptr) { + *err = Error("Not implemented"); + } + return ""; + } +}; + class DummyEvaluator : public Evaluator { public: DummyEvaluator() {} diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index 22051e07ee0026bc3c44a8767e265a56b415b8e4..277fb6d8db13ce73e34eb83fef0437470dfb99c1 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -306,7 +306,6 @@ void NeuralNetwork::onPassEnd() { class CombinedEvaluator : public Evaluator { public: - CombinedEvaluator() {} void addEvaluator(std::unique_ptr&& evaluator) { evaluators_.emplace_back(std::move(evaluator)); } @@ -346,6 +345,50 @@ public: protected: std::vector> evaluators_; + + // Evaluator interface +public: + void getNames(std::vector* names) { + for (auto& eval : evaluators_) { + eval->getNames(names); + } + } + + real getValue(const std::string& name, Error* err) const { + return this->getMethodHelper( + name, err, [&name, err](const std::unique_ptr& eval) { + return eval->getValue(name, err); + }); + } + std::string getValueStr(const std::string& name, Error* err) const { + return this->getMethodHelper( + name, err, [&name, err](const std::unique_ptr& eval) { + return eval->getValueStr(name, err); + }); + } + std::string getType(const std::string& name, Error* err) const { + return this->getMethodHelper( + name, err, [&name, err](const std::unique_ptr& eval) { + return eval->getType(name, err); + }); + } + +private: + template + T getMethodHelper(const std::string& name, + Error* err, + const std::function&)>& + callback) const { + for (auto& eval : evaluators_) { + std::vector names; + eval->getNames(&names); + if (std::find(names.begin(), names.end(), name) != names.end()) { + return callback(eval); + } + } + if (err != nullptr) *err = Error("No such key %s", name.c_str()); + return T(); + } }; Evaluator* NeuralNetwork::makeEvaluator() const {