diff --git a/deep_speech_2/models/tests/test_decoders.py b/deep_speech_2/models/tests/test_decoders.py index fa43879b8741c4f9f62d9a2b648e105fc5d51d37..acce46af81c0168903fa57d5d756dcfd911aa15f 100644 --- a/deep_speech_2/models/tests/test_decoders.py +++ b/deep_speech_2/models/tests/test_decoders.py @@ -4,7 +4,7 @@ from __future__ import division from __future__ import print_function import unittest -from decoder import * +from models import decoder class TestDecoders(unittest.TestCase): @@ -53,15 +53,17 @@ class TestDecoders(unittest.TestCase): self.beam_search_result = ['acdc', "b'a"] def test_greedy_decoder_1(self): - bst_result = ctc_greedy_decoder(self.probs_seq1, self.vocab_list) + bst_result = decoder.ctc_greedy_decoder(self.probs_seq1, + self.vocab_list) self.assertEqual(bst_result, self.greedy_result[0]) def test_greedy_decoder_2(self): - bst_result = ctc_greedy_decoder(self.probs_seq2, self.vocab_list) + bst_result = decoder.ctc_greedy_decoder(self.probs_seq2, + self.vocab_list) self.assertEqual(bst_result, self.greedy_result[1]) def test_beam_search_decoder_1(self): - beam_result = ctc_beam_search_decoder( + beam_result = decoder.ctc_beam_search_decoder( probs_seq=self.probs_seq1, beam_size=self.beam_size, vocabulary=self.vocab_list, @@ -69,7 +71,7 @@ class TestDecoders(unittest.TestCase): self.assertEqual(beam_result[0][1], self.beam_search_result[0]) def test_beam_search_decoder_2(self): - beam_result = ctc_beam_search_decoder( + beam_result = decoder.ctc_beam_search_decoder( probs_seq=self.probs_seq2, beam_size=self.beam_size, vocabulary=self.vocab_list, @@ -77,7 +79,7 @@ class TestDecoders(unittest.TestCase): self.assertEqual(beam_result[0][1], self.beam_search_result[1]) def test_beam_search_decoder_batch(self): - beam_results = ctc_beam_search_decoder_batch( + beam_results = decoder.ctc_beam_search_decoder_batch( probs_split=[self.probs_seq1, self.probs_seq2], beam_size=self.beam_size, vocabulary=self.vocab_list, diff --git a/deep_speech_2/utils/tests/test_error_rate.py b/deep_speech_2/utils/tests/test_error_rate.py index 99e137a9a190cba8f2d99001cbf3c22ce8d53b56..d6bc7442e1f55bcea1f16234301785a884f2249a 100644 --- a/deep_speech_2/utils/tests/test_error_rate.py +++ b/deep_speech_2/utils/tests/test_error_rate.py @@ -5,7 +5,7 @@ from __future__ import division from __future__ import print_function import unittest -import error_rate +from utils import error_rate class TestParse(unittest.TestCase):