diff --git a/infer.py b/infer.py index 5f0f268a84b331d1dc2c3516c1a8683cb9e3baf5..686f2822c2b6ffa264f4305bc04f745f3c77ed43 100644 --- a/infer.py +++ b/infer.py @@ -63,7 +63,7 @@ parser.add_argument( help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='checkpoints/params.tar.gz.41', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( diff --git a/tests/test_decoders.py b/tests/test_decoders.py index 4435355cc10069b9de5159010705b0e8b7fe7928..a5e19b08b8622621496cd628ccbe2f37f3d149da 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -81,7 +81,8 @@ class TestDecoders(unittest.TestCase): probs_split=[self.probs_seq1, self.probs_seq2], beam_size=self.beam_size, vocabulary=self.vocab_list, - blank_id=len(self.vocab_list)) + blank_id=len(self.vocab_list), + num_processes=24) self.assertEqual(beam_results[0][0][1], self.beam_search_result[0]) self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])