提交 df28da76 编写于 作者: Q qingqing01 提交者: emailweixu

try to fix bug for CTCErrorEvaluator.cpp when batch_size > 1 (#82)

* try to fix bug for ctc_error_evaluator
上级 703cce35
......@@ -207,7 +207,7 @@ public:
real err = 0;
err = editDistance(
output.value->getData() + output.value->getWidth() * outputStarts[i],
output.value->getHeight(), output.value->getWidth(),
outputStarts[i+1] - outputStarts[i], output.value->getWidth(),
label.ids->getData() + labelStarts[i],
labelStarts[i + 1] - labelStarts[i]);
......@@ -224,6 +224,9 @@ public:
for (const std::string& name : config_.input_layers()) {
arguments.push_back(nn.getLayer(name)->getOutput());
}
}
virtual void updateSamplesNum(const std::vector<Argument>& arguments) {
numSequences_ += arguments[1].getNumSequences();
}
......
......@@ -87,18 +87,31 @@ void testEvaluator(TestConfig testConf, string testEvaluatorName,
return;
}
ICpuGpuVectorPtr sequenceStartPositions;
if (testConf.inputDefs[i].inputType == INPUT_SEQUENCE_DATA ||
testConf.inputDefs[i].inputType == INPUT_SEQUENCE_LABEL) {
if (!sequenceStartPositions) {
generateSequenceStartPositions(batchSize, sequenceStartPositions);
}
data.sequenceStartPositions = sequenceStartPositions;
}
arguments.push_back(data);
}
Evaluator* testEvaluator = Evaluator::create(testConf.evaluatorConfig);
double totalScore = 0.0;
testEvaluator->start();
totalScore += testEvaluator->evalImp(arguments);
testEvaluator->updateSamplesNum(arguments);
testEvaluator->finish();
LOG(INFO) << *testEvaluator;
double totalScore2 = 0.0;
if (testConf.testAccumulate) {
testEvaluator->start();
totalScore2 += testEvaluator->evalImp(arguments);
testEvaluator->finish();
EXPECT_LE(fabs(totalScore - totalScore2), 1.0e-5);
}
}
......@@ -202,6 +215,15 @@ TEST(Evaluator, precision_recall) {
false);
}
TEST(Evaluator, ctc_error_evaluator) {
TestConfig config;
config.evaluatorConfig.set_type("ctc_edit_distance");
config.inputDefs.push_back({INPUT_SEQUENCE_DATA, "output", 32});
config.inputDefs.push_back({INPUT_SEQUENCE_LABEL, "label", 1});
testEvaluatorAll(config, "ctc_error_evaluator", 100);
}
int main(int argc, char** argv) {
initMain(argc, argv);
FLAGS_thread_local_rand_use_global_seed = true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册