提交 856d641c 作者: Hui Zhang

Use multiple workers for the valid/test/align dataloaders

上级 b7b1bda3
...@@ -235,16 +235,18 @@ class DeepSpeech2Trainer(Trainer): ...@@ -235,16 +235,18 @@ class DeepSpeech2Trainer(Trainer):
num_workers=config.collator.num_workers) num_workers=config.collator.num_workers)
self.valid_loader = DataLoader( self.valid_loader = DataLoader(
dev_dataset, dev_dataset,
batch_size=int(config.collator.batch_size / 4), batch_size=int(config.collator.batch_size),
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=collate_fn_dev) collate_fn=collate_fn_dev,
num_workers=config.collator.num_workers)
self.test_loader = DataLoader( self.test_loader = DataLoader(
test_dataset, test_dataset,
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=collate_fn_test) collate_fn=collate_fn_test,
num_workers=config.collator.num_workers)
logger.info("Setup train/valid/test Dataloader!") logger.info("Setup train/valid/test Dataloader!")
......
...@@ -292,7 +292,8 @@ class U2Trainer(Trainer): ...@@ -292,7 +292,8 @@ class U2Trainer(Trainer):
batch_size=config.collator.batch_size, batch_size=config.collator.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=collate_fn_dev) collate_fn=collate_fn_dev,
num_workers=config.collator.num_workers, )
# test dataset, return raw text # test dataset, return raw text
config.data.manifest = config.data.test_manifest config.data.manifest = config.data.test_manifest
...@@ -314,7 +315,8 @@ class U2Trainer(Trainer): ...@@ -314,7 +315,8 @@ class U2Trainer(Trainer):
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator.from_config(config)) collate_fn=SpeechCollator.from_config(config),
num_workers=config.collator.num_workers, )
# return text token id # return text token id
config.collator.keep_transcription_text = False config.collator.keep_transcription_text = False
self.align_loader = DataLoader( self.align_loader = DataLoader(
...@@ -322,7 +324,8 @@ class U2Trainer(Trainer): ...@@ -322,7 +324,8 @@ class U2Trainer(Trainer):
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator.from_config(config)) collate_fn=SpeechCollator.from_config(config),
num_workers=config.collator.num_workers, )
logger.info("Setup train/valid/test/align Dataloader!") logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self): def setup_model(self):
......
...@@ -292,7 +292,8 @@ class U2STTrainer(Trainer): ...@@ -292,7 +292,8 @@ class U2STTrainer(Trainer):
batch_size=config.collator.batch_size, batch_size=config.collator.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=collate_fn_dev) collate_fn=collate_fn_dev,
num_workers=config.collator.num_workers, )
# test dataset, return raw text # test dataset, return raw text
config.data.manifest = config.data.test_manifest config.data.manifest = config.data.test_manifest
...@@ -313,7 +314,8 @@ class U2STTrainer(Trainer): ...@@ -313,7 +314,8 @@ class U2STTrainer(Trainer):
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=TestCollator.from_config(config)) collate_fn=TestCollator.from_config(config),
num_workers=config.collator.num_workers, )
# return text token id # return text token id
config.collator.keep_transcription_text = False config.collator.keep_transcription_text = False
self.align_loader = DataLoader( self.align_loader = DataLoader(
...@@ -321,7 +323,8 @@ class U2STTrainer(Trainer): ...@@ -321,7 +323,8 @@ class U2STTrainer(Trainer):
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=TestCollator.from_config(config)) collate_fn=TestCollator.from_config(config),
num_workers=config.collator.num_workers, )
logger.info("Setup train/valid/test/align Dataloader!") logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self): def setup_model(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册