提交 9db0d25f 编写于 作者: Y Yibing Liu

pass unittest for deprecated decoders

上级 cfecaa8a
...@@ -4,7 +4,7 @@ from __future__ import division ...@@ -4,7 +4,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from model_utils import decoder from decoders import decoders_deprecated as decoder
class TestDecoders(unittest.TestCase): class TestDecoders(unittest.TestCase):
...@@ -66,16 +66,14 @@ class TestDecoders(unittest.TestCase): ...@@ -66,16 +66,14 @@ class TestDecoders(unittest.TestCase):
beam_result = decoder.ctc_beam_search_decoder( beam_result = decoder.ctc_beam_search_decoder(
probs_seq=self.probs_seq1, probs_seq=self.probs_seq1,
beam_size=self.beam_size, beam_size=self.beam_size,
vocabulary=self.vocab_list, vocabulary=self.vocab_list)
blank_id=len(self.vocab_list))
self.assertEqual(beam_result[0][1], self.beam_search_result[0]) self.assertEqual(beam_result[0][1], self.beam_search_result[0])
def test_beam_search_decoder_2(self): def test_beam_search_decoder_2(self):
beam_result = decoder.ctc_beam_search_decoder( beam_result = decoder.ctc_beam_search_decoder(
probs_seq=self.probs_seq2, probs_seq=self.probs_seq2,
beam_size=self.beam_size, beam_size=self.beam_size,
vocabulary=self.vocab_list, vocabulary=self.vocab_list)
blank_id=len(self.vocab_list))
self.assertEqual(beam_result[0][1], self.beam_search_result[1]) self.assertEqual(beam_result[0][1], self.beam_search_result[1])
def test_beam_search_decoder_batch(self): def test_beam_search_decoder_batch(self):
...@@ -83,7 +81,6 @@ class TestDecoders(unittest.TestCase): ...@@ -83,7 +81,6 @@ class TestDecoders(unittest.TestCase):
probs_split=[self.probs_seq1, self.probs_seq2], probs_split=[self.probs_seq1, self.probs_seq2],
beam_size=self.beam_size, beam_size=self.beam_size,
vocabulary=self.vocab_list, vocabulary=self.vocab_list,
blank_id=len(self.vocab_list),
num_processes=24) num_processes=24)
self.assertEqual(beam_results[0][0][1], self.beam_search_result[0]) self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
self.assertEqual(beam_results[1][0][1], self.beam_search_result[1]) self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册