diff --git a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp index cd4ed19c2ca45c310032c834da4cad56fb1cbdff..e397c71c877dce8c34aefac12481373a037510f6 100644 --- a/paddle/gserver/evaluators/CTCErrorEvaluator.cpp +++ b/paddle/gserver/evaluators/CTCErrorEvaluator.cpp @@ -207,7 +207,7 @@ public: real err = 0; err = editDistance( output.value->getData() + output.value->getWidth() * outputStarts[i], - output.value->getHeight(), output.value->getWidth(), + outputStarts[i+1] - outputStarts[i], output.value->getWidth(), label.ids->getData() + labelStarts[i], labelStarts[i + 1] - labelStarts[i]); @@ -224,6 +224,9 @@ public: for (const std::string& name : config_.input_layers()) { arguments.push_back(nn.getLayer(name)->getOutput()); } + } + + virtual void updateSamplesNum(const std::vector& arguments) { numSequences_ += arguments[1].getNumSequences(); } diff --git a/paddle/gserver/tests/test_Evaluator.cpp b/paddle/gserver/tests/test_Evaluator.cpp index 8e857781468fed694dbd061d896263bf05303260..3a591a316b8bafccac9c59ff28e57b4e27f8377a 100644 --- a/paddle/gserver/tests/test_Evaluator.cpp +++ b/paddle/gserver/tests/test_Evaluator.cpp @@ -87,18 +87,31 @@ void testEvaluator(TestConfig testConf, string testEvaluatorName, return; } + ICpuGpuVectorPtr sequenceStartPositions; + if (testConf.inputDefs[i].inputType == INPUT_SEQUENCE_DATA || + testConf.inputDefs[i].inputType == INPUT_SEQUENCE_LABEL) { + if (!sequenceStartPositions) { + generateSequenceStartPositions(batchSize, sequenceStartPositions); + } + data.sequenceStartPositions = sequenceStartPositions; + } + arguments.push_back(data); } Evaluator* testEvaluator = Evaluator::create(testConf.evaluatorConfig); double totalScore = 0.0; + testEvaluator->start(); totalScore += testEvaluator->evalImp(arguments); testEvaluator->updateSamplesNum(arguments); + testEvaluator->finish(); LOG(INFO) << *testEvaluator; double totalScore2 = 0.0; if (testConf.testAccumulate) { + testEvaluator->start(); totalScore2 += testEvaluator->evalImp(arguments); + testEvaluator->finish(); EXPECT_LE(fabs(totalScore - totalScore2), 1.0e-5); } } @@ -202,6 +215,15 @@ TEST(Evaluator, precision_recall) { false); } +TEST(Evaluator, ctc_error_evaluator) { + TestConfig config; + config.evaluatorConfig.set_type("ctc_edit_distance"); + + config.inputDefs.push_back({INPUT_SEQUENCE_DATA, "output", 32}); + config.inputDefs.push_back({INPUT_SEQUENCE_LABEL, "label", 1}); + testEvaluatorAll(config, "ctc_error_evaluator", 100); +} + int main(int argc, char** argv) { initMain(argc, argv); FLAGS_thread_local_rand_use_global_seed = true;