diff --git a/lm/lm_scorer.py b/lm/lm_scorer.py
index 1c029e97f63bd9b4db73430993d41c7b67c23ea1..de41754f9b8f5c83978a8c01a97223db03ca0614 100644
--- a/lm/lm_scorer.py
+++ b/lm/lm_scorer.py
@@ -42,6 +42,11 @@ class LmScorer(object):
         words = sentence.strip().split(' ')
         return len(words)
 
+    # reset alpha and beta
+    def reset_params(self, alpha, beta):
+        self._alpha = alpha
+        self._beta = beta
+
     # execute evaluation
     def __call__(self, sentence, log=False):
         """Evaluation function, gathering all the different scores
diff --git a/tests/test_decoders.py b/tests/test_decoders.py
index a5e19b08b8622621496cd628ccbe2f37f3d149da..99d8a8289d93574c58ced50923716c39cfb96558 100644
--- a/tests/test_decoders.py
+++ b/tests/test_decoders.py
@@ -76,7 +76,7 @@ class TestDecoders(unittest.TestCase):
             blank_id=len(self.vocab_list))
         self.assertEqual(beam_result[0][1], self.beam_search_result[1])
 
-    def test_beam_search_nproc_decoder(self):
+    def test_beam_search_decoder_batch(self):
         beam_results = ctc_beam_search_decoder_batch(
             probs_split=[self.probs_seq1, self.probs_seq2],
             beam_size=self.beam_size,
diff --git a/tune.py b/tune.py
index 9cea66b90fa7874f5a61499506a11dd6afeae6e9..e26bc45cee2f84d3336e8a4387487db99e8d95c9 100644
--- a/tune.py
+++ b/tune.py
@@ -12,6 +12,7 @@ from model import deep_speech2
 from decoder import *
 from lm.lm_scorer import LmScorer
 from error_rate import wer
+import utils
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
@@ -180,10 +181,13 @@ def tune():
     params_grid = [(alpha, beta) for alpha in cand_alphas
                    for beta in cand_betas]
 
+    ext_scorer = LmScorer(args.alpha_from, args.beta_from,
+                          args.language_model_path)
     ## tune parameters in loop
-    for (alpha, beta) in params_grid:
+    for alpha, beta in params_grid:
         wer_sum, wer_counter = 0, 0
-        ext_scorer = LmScorer(alpha, beta, args.language_model_path)
+        # reset scorer
+        ext_scorer.reset_params(alpha, beta)
         # beam search using multiple processes
         beam_search_results = ctc_beam_search_decoder_batch(
             probs_split=probs_split,
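
The tune.py hunk follows a simple reuse pattern: construct the LmScorer once before the grid loop (loading the language model is the expensive step) and only call reset_params(alpha, beta) for each grid point. A minimal sketch of that pattern is below, assuming the LmScorer and reset_params interfaces shown in the diff; decode_and_score is a hypothetical callback standing in for the beam-search decoding and WER averaging done in tune.py, not something from the repo.

import itertools

from lm.lm_scorer import LmScorer


def grid_search(alphas, betas, language_model_path, decode_and_score):
    """Return the (alpha, beta) pair with the lowest score (e.g. average WER)."""
    # Build the scorer once so the language model is loaded only once.
    ext_scorer = LmScorer(alphas[0], betas[0], language_model_path)

    best_params, best_score = None, float('inf')
    for alpha, beta in itertools.product(alphas, betas):
        # Cheap per-grid-point update: only the interpolation weights change.
        ext_scorer.reset_params(alpha, beta)
        score = decode_and_score(ext_scorer)
        if score < best_score:
            best_params, best_score = (alpha, beta), score
    return best_params, best_score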