From d776ce9bd71d1878bd51c2a795bd4373dd0119fb Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Wed, 6 Sep 2017 16:02:22 +0800 Subject: [PATCH] Fix import errors in unitests. --- models/tests/test_decoders.py | 14 ++++++++------ utils/tests/test_error_rate.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/models/tests/test_decoders.py b/models/tests/test_decoders.py index fa43879b..acce46af 100644 --- a/models/tests/test_decoders.py +++ b/models/tests/test_decoders.py @@ -4,7 +4,7 @@ from __future__ import division from __future__ import print_function import unittest -from decoder import * +from models import decoder class TestDecoders(unittest.TestCase): @@ -53,15 +53,17 @@ class TestDecoders(unittest.TestCase): self.beam_search_result = ['acdc', "b'a"] def test_greedy_decoder_1(self): - bst_result = ctc_greedy_decoder(self.probs_seq1, self.vocab_list) + bst_result = decoder.ctc_greedy_decoder(self.probs_seq1, + self.vocab_list) self.assertEqual(bst_result, self.greedy_result[0]) def test_greedy_decoder_2(self): - bst_result = ctc_greedy_decoder(self.probs_seq2, self.vocab_list) + bst_result = decoder.ctc_greedy_decoder(self.probs_seq2, + self.vocab_list) self.assertEqual(bst_result, self.greedy_result[1]) def test_beam_search_decoder_1(self): - beam_result = ctc_beam_search_decoder( + beam_result = decoder.ctc_beam_search_decoder( probs_seq=self.probs_seq1, beam_size=self.beam_size, vocabulary=self.vocab_list, @@ -69,7 +71,7 @@ class TestDecoders(unittest.TestCase): self.assertEqual(beam_result[0][1], self.beam_search_result[0]) def test_beam_search_decoder_2(self): - beam_result = ctc_beam_search_decoder( + beam_result = decoder.ctc_beam_search_decoder( probs_seq=self.probs_seq2, beam_size=self.beam_size, vocabulary=self.vocab_list, @@ -77,7 +79,7 @@ class TestDecoders(unittest.TestCase): self.assertEqual(beam_result[0][1], self.beam_search_result[1]) def test_beam_search_decoder_batch(self): - beam_results = ctc_beam_search_decoder_batch( + beam_results = decoder.ctc_beam_search_decoder_batch( probs_split=[self.probs_seq1, self.probs_seq2], beam_size=self.beam_size, vocabulary=self.vocab_list, diff --git a/utils/tests/test_error_rate.py b/utils/tests/test_error_rate.py index 99e137a9..d6bc7442 100644 --- a/utils/tests/test_error_rate.py +++ b/utils/tests/test_error_rate.py @@ -5,7 +5,7 @@ from __future__ import division from __future__ import print_function import unittest -import error_rate +from utils import error_rate class TestParse(unittest.TestCase): -- GitLab