From 1100366803c10aa9cbd7265616d286e1b38681c0 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 19 Feb 2017 20:43:04 +0800 Subject: [PATCH] Add unittests & fix some bugs. --- paddle/gserver/evaluators/CTCErrorEvaluator.cpp | 2 +- paddle/gserver/evaluators/Evaluator.cpp | 6 ++++-- paddle/gserver/tests/test_Evaluator.cpp | 12 ++++++++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp index 05aa6c012ae..132119015f9 100644 --- a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp +++ b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp @@ -20,7 +20,7 @@ namespace paddle { /** * calculate sequence-to-sequence edit distance */ -class CTCErrorEvaluator : public Evaluator { +class CTCErrorEvaluator : public NotGetableEvaluator { private: MatrixPtr outActivations_; int numTimes_, numClasses_, numSequences_, blank_; diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index 29b5284fe59..42a877954b6 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -823,8 +823,10 @@ real PrecisionRecallEvaluator::getValue(const std::string& name, std::vector buffers; paddle::str::split(name, '.', &buffers); auto it = this->values_.find(buffers[buffers.size() - 1]); - if (it != this->values_.end() && err != nullptr) { - *err = Error("No such key %s", name.c_str()); + if (it == this->values_.end()) { // not found + if (err != nullptr) { + *err = Error("No such key %s", name.c_str()); + } return .0f; } diff --git a/paddle/gserver/tests/test_Evaluator.cpp b/paddle/gserver/tests/test_Evaluator.cpp index 8165eb82693..100cf078095 100644 --- a/paddle/gserver/tests/test_Evaluator.cpp +++ b/paddle/gserver/tests/test_Evaluator.cpp @@ -110,6 +110,18 @@ void testEvaluator(TestConfig testConf, testEvaluator->finish(); LOG(INFO) << *testEvaluator; + std::vector names; + testEvaluator->getNames(&names); + paddle::Error err; + for (auto& name : names) { + auto value = testEvaluator->getValueStr(name, &err); + ASSERT_TRUE(err.isOK()); + LOG(INFO) << name << " " << value; + auto tp = testEvaluator->getType(name, &err); + ASSERT_TRUE(err.isOK()); + ASSERT_EQ(testConf.evaluatorConfig.type(), tp); + } + double totalScore2 = 0.0; if (testConf.testAccumulate) { testEvaluator->start(); -- GitLab