diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 4f6ff4cb9fd64e34f27a9024b5f8fa65481f67c2..46e5b4d95779d886e156b0a4a4e2b1b9e6e4dc3a 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -228,7 +228,7 @@ class U2Trainer(Trainer): maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, - mini_batch_size=1, + mini_batch_size=self.args.nprocs, batch_count='auto', batch_bins=0, batch_frames_in=0, @@ -247,7 +247,7 @@ class U2Trainer(Trainer): maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, - mini_batch_size=1, + mini_batch_size=self.args.nprocs, batch_count='auto', batch_bins=0, batch_frames_in=0, @@ -263,7 +263,7 @@ class U2Trainer(Trainer): json_file=config.data.test_manifest, train_mode=False, sortagrad=False, - batch_size=config.collator.batch_size, + batch_size=config.decoding.batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0, @@ -282,7 +282,7 @@ class U2Trainer(Trainer): json_file=config.data.test_manifest, train_mode=False, sortagrad=False, - batch_size=config.collator.batch_size, + batch_size=config.decoding.batch_size, maxlen_in=float('inf'), maxlen_out=float('inf'), minibatches=0,