diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index b854a996966efc608e3a2ce08019eacf216f677c..e84de61574be3040db674d21215fbee89e5f0bc9 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -235,16 +235,18 @@ class DeepSpeech2Trainer(Trainer): num_workers=config.collator.num_workers) self.valid_loader = DataLoader( dev_dataset, - batch_size=int(config.collator.batch_size / 4), + batch_size=int(config.collator.batch_size), shuffle=False, drop_last=False, - collate_fn=collate_fn_dev) + collate_fn=collate_fn_dev, + num_workers=config.collator.num_workers) self.test_loader = DataLoader( test_dataset, batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn_test) + collate_fn=collate_fn_test, + num_workers=config.collator.num_workers) logger.info("Setup train/valid/test Dataloader!") diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 1afd9b101889086d2d806f0a15e03f89c19284aa..c30f324b9ae6eddfb3b7a1a4a1101745e44520bd 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -292,7 +292,8 @@ class U2Trainer(Trainer): batch_size=config.collator.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn_dev) + collate_fn=collate_fn_dev, + num_workers=config.collator.num_workers, ) # test dataset, return raw text config.data.manifest = config.data.test_manifest @@ -314,7 +315,8 @@ class U2Trainer(Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=SpeechCollator.from_config(config)) + collate_fn=SpeechCollator.from_config(config), + num_workers=config.collator.num_workers, ) # return text token id config.collator.keep_transcription_text = False self.align_loader = DataLoader( @@ -322,7 +324,8 @@ class U2Trainer(Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=SpeechCollator.from_config(config)) + collate_fn=SpeechCollator.from_config(config), + num_workers=config.collator.num_workers, ) logger.info("Setup train/valid/test/align Dataloader!") def setup_model(self): diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py index 9a34cbdc2948a6df6ab9322f9a5eff31ce6b4917..c480499c7b397ff12eb4f7d5047101168a33d149 100644 --- a/deepspeech/exps/u2_st/model.py +++ b/deepspeech/exps/u2_st/model.py @@ -292,7 +292,8 @@ class U2STTrainer(Trainer): batch_size=config.collator.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn_dev) + collate_fn=collate_fn_dev, + num_workers=config.collator.num_workers, ) # test dataset, return raw text config.data.manifest = config.data.test_manifest @@ -313,7 +314,8 @@ class U2STTrainer(Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=TestCollator.from_config(config)) + collate_fn=TestCollator.from_config(config), + num_workers=config.collator.num_workers, ) # return text token id config.collator.keep_transcription_text = False self.align_loader = DataLoader( @@ -321,7 +323,8 @@ class U2STTrainer(Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=TestCollator.from_config(config)) + collate_fn=TestCollator.from_config(config), + num_workers=config.collator.num_workers, ) logger.info("Setup train/valid/test/align Dataloader!") def setup_model(self):