提交 df28da76 编写于 作者: Q qingqing01 提交者: emailweixu

try to fix bug for CTCErrorEvaluator.cpp when batch_size > 1 (#82)

* try to fix bug for ctc_error_evaluator
上级 703cce35
...@@ -207,7 +207,7 @@ public: ...@@ -207,7 +207,7 @@ public:
real err = 0; real err = 0;
err = editDistance( err = editDistance(
output.value->getData() + output.value->getWidth() * outputStarts[i], output.value->getData() + output.value->getWidth() * outputStarts[i],
output.value->getHeight(), output.value->getWidth(), outputStarts[i+1] - outputStarts[i], output.value->getWidth(),
label.ids->getData() + labelStarts[i], label.ids->getData() + labelStarts[i],
labelStarts[i + 1] - labelStarts[i]); labelStarts[i + 1] - labelStarts[i]);
...@@ -224,6 +224,9 @@ public: ...@@ -224,6 +224,9 @@ public:
for (const std::string& name : config_.input_layers()) { for (const std::string& name : config_.input_layers()) {
arguments.push_back(nn.getLayer(name)->getOutput()); arguments.push_back(nn.getLayer(name)->getOutput());
} }
}
virtual void updateSamplesNum(const std::vector<Argument>& arguments) {
numSequences_ += arguments[1].getNumSequences(); numSequences_ += arguments[1].getNumSequences();
} }
......
...@@ -87,18 +87,31 @@ void testEvaluator(TestConfig testConf, string testEvaluatorName, ...@@ -87,18 +87,31 @@ void testEvaluator(TestConfig testConf, string testEvaluatorName,
return; return;
} }
ICpuGpuVectorPtr sequenceStartPositions;
if (testConf.inputDefs[i].inputType == INPUT_SEQUENCE_DATA ||
testConf.inputDefs[i].inputType == INPUT_SEQUENCE_LABEL) {
if (!sequenceStartPositions) {
generateSequenceStartPositions(batchSize, sequenceStartPositions);
}
data.sequenceStartPositions = sequenceStartPositions;
}
arguments.push_back(data); arguments.push_back(data);
} }
Evaluator* testEvaluator = Evaluator::create(testConf.evaluatorConfig); Evaluator* testEvaluator = Evaluator::create(testConf.evaluatorConfig);
double totalScore = 0.0; double totalScore = 0.0;
testEvaluator->start();
totalScore += testEvaluator->evalImp(arguments); totalScore += testEvaluator->evalImp(arguments);
testEvaluator->updateSamplesNum(arguments); testEvaluator->updateSamplesNum(arguments);
testEvaluator->finish();
LOG(INFO) << *testEvaluator; LOG(INFO) << *testEvaluator;
double totalScore2 = 0.0; double totalScore2 = 0.0;
if (testConf.testAccumulate) { if (testConf.testAccumulate) {
testEvaluator->start();
totalScore2 += testEvaluator->evalImp(arguments); totalScore2 += testEvaluator->evalImp(arguments);
testEvaluator->finish();
EXPECT_LE(fabs(totalScore - totalScore2), 1.0e-5); EXPECT_LE(fabs(totalScore - totalScore2), 1.0e-5);
} }
} }
...@@ -202,6 +215,15 @@ TEST(Evaluator, precision_recall) { ...@@ -202,6 +215,15 @@ TEST(Evaluator, precision_recall) {
false); false);
} }
TEST(Evaluator, ctc_error_evaluator) {
TestConfig config;
config.evaluatorConfig.set_type("ctc_edit_distance");
config.inputDefs.push_back({INPUT_SEQUENCE_DATA, "output", 32});
config.inputDefs.push_back({INPUT_SEQUENCE_LABEL, "label", 1});
testEvaluatorAll(config, "ctc_error_evaluator", 100);
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
initMain(argc, argv); initMain(argc, argv);
FLAGS_thread_local_rand_use_global_seed = true; FLAGS_thread_local_rand_use_global_seed = true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册