diff --git a/paddle/gserver/evaluators/Evaluator.h b/paddle/gserver/evaluators/Evaluator.h index d5b2a63e35e15752e3ba918dde273671b4c3c0d1..a5694a088c3c445422081f505137c96a9f5c5152 100644 --- a/paddle/gserver/evaluators/Evaluator.h +++ b/paddle/gserver/evaluators/Evaluator.h @@ -118,33 +118,55 @@ public: static ClassRegistrar registrar_; + /** + * @brief getNames will return all field names of current evaluator. + * + * The format of name is `evaluator_name.evaluator_fields`. If the evaluator + * has multiple field, the name could be `evaluator_name.field1`. For example + * the PrecisionRecallEvaluator contains `precision`, `recall` fields. The get + * names will return `precision_recall_evaluator.precision`, + * `precision_recall.recal`, etc. + * + * Also, if current Evaluator is a combined evaluator. getNames will return + * all names of all evaluators inside the combined evaluator. + * + * @param names [out]: the field names of current evaluator. + * @note Never clear the names parameter inside getNames. + */ virtual void getNames(std::vector* names) { names->push_back(config_.name()); } + /** + * @brief getValue will return the current evaluate value of one field. + * + * @param name: The field name of current evaluator. + * @param err [out]: The error state. nullptr means don't care. + * + * @return The evaluate value(metric). + */ virtual real getValue(const std::string& name, paddle::Error* err = nullptr) const { - if (name != config_.name() && err != nullptr) { - *err = paddle::Error("no such name of evaluator %s", name.c_str()); + if (name != config_.name()) { + if (err != nullptr) { + *err = paddle::Error("no such name of evaluator %s", name.c_str()); + } return .0f; } return this->getValueImpl(); } - virtual std::string getValueStr(const std::string& name, - paddle::Error* err = nullptr) const { - paddle::Error localErr; - if (err == nullptr) { - err = &localErr; - } - real result = this->getValue(name, err); - if (!err->isOK()) { - return ""; - } else { - return std::to_string(result); - } - } - + /** + * @brief getType will return the evaluator type by field name. + * + * Evaluate Type is the current type of evaluator in string. Such as 'auc', + * 'precision_recall'. In combined evaluator, different name may get different + * evaluate type because it could be evaluated by different evaluator inside. + * + * @param name: The field name of current Evaluator. + * @param err: The error state. nullptr means don't care. + * @return the evaluator type string. + */ virtual std::string getType(const std::string& name, paddle::Error* err = nullptr) const { if (name != config_.name() && err != nullptr) { @@ -155,10 +177,22 @@ public: } protected: + /** + * @brief getValueImpl The simplest way to define getValue result. If this + * evaluator doesn't contain multiple fields, and do not throw any error, just + * implemented this method to get the evaluate result(metric). + * @return Evaluate result(metric). + */ virtual real getValueImpl() const { return numSamples_ != .0 ? totalScore_ / numSamples_ : .0; } + /** + * @brief getTypeImpl The simplest way to define getType result. If this + * evaluator doesn't combine many evaluators, the get type should only return + * itself type. + * @return Evaluator type. + */ virtual std::string getTypeImpl() const { return "base"; } protected: @@ -167,6 +201,11 @@ protected: double totalScore_; }; +/** + * @brief The NotGetableEvaluator class is the base class of evaluator that + * cannot get value in runtime. The most NotGetableEvaluator is Printer + * Evaluator, which is only used to debug network configuration. + */ class NotGetableEvaluator : public Evaluator { // Evaluator interface public: diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index 277fb6d8db13ce73e34eb83fef0437470dfb99c1..4d2bdf0dc992da32659de0f18f685fbcbbeb5656 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -348,24 +348,29 @@ protected: // Evaluator interface public: + /** + * @brief getNames will return all inside evaluators' names. + * @param names [out]: return names. + */ void getNames(std::vector* names) { for (auto& eval : evaluators_) { eval->getNames(names); } } + /** + * @brief getValue could get all inside evaluators' value. + */ real getValue(const std::string& name, Error* err) const { return this->getMethodHelper( name, err, [&name, err](const std::unique_ptr& eval) { return eval->getValue(name, err); }); } - std::string getValueStr(const std::string& name, Error* err) const { - return this->getMethodHelper( - name, err, [&name, err](const std::unique_ptr& eval) { - return eval->getValueStr(name, err); - }); - } + + /** + * @brief getType could get all inside evaluators' type. + */ std::string getType(const std::string& name, Error* err) const { return this->getMethodHelper( name, err, [&name, err](const std::unique_ptr& eval) { diff --git a/paddle/gserver/tests/test_Evaluator.cpp b/paddle/gserver/tests/test_Evaluator.cpp index 100cf0780950921ddf060be4193172ead771abae..07f486b1f4511ba210256b5a21021e8ca0265eb8 100644 --- a/paddle/gserver/tests/test_Evaluator.cpp +++ b/paddle/gserver/tests/test_Evaluator.cpp @@ -114,7 +114,7 @@ void testEvaluator(TestConfig testConf, testEvaluator->getNames(&names); paddle::Error err; for (auto& name : names) { - auto value = testEvaluator->getValueStr(name, &err); + auto value = testEvaluator->getValue(name, &err); ASSERT_TRUE(err.isOK()); LOG(INFO) << name << " " << value; auto tp = testEvaluator->getType(name, &err);