diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 50ff3c17b66bb04df5484c2ca1210357f305a192..bcd66d19ed5b5cc245a536ef3894c2e9b427c5d0 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer): sortagrad=config.data.sortagrad, shuffle_method=config.data.shuffle_method) - collate_fn = SpeechCollator(config, keep_transcription_text=False) + collate_fn = SpeechCollator(config=config, keep_transcription_text=False) self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -342,7 +342,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=SpeechCollator(keep_transcription_text=True)) + collate_fn=SpeechCollator(config=config, keep_transcription_text=True)) logger.info("Setup test Dataloader!") def setup_output_dir(self): diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index d725b0b1e69d8eeadfc5ea28e091c3733147f35b..0f86b8e7262357da0620deb387e0ac7ff312b1e2 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -23,6 +23,8 @@ from deepspeech.frontend.speech import SpeechSegment import io import time +from collections import namedtuple + __all__ = ["SpeechCollator"] logger = Log(__name__).getlog() @@ -50,7 +52,7 @@ class SpeechCollator(): aug_file = config.data.augmentation_config assert isinstance(aug_file, io.StringIO) - self._local_data = TarLocalData(tar2info={}, tar2object={}) + self._local_data = TarLocalData(tar2info={}, tar2object={}) self._augmentation_pipeline = AugmentationPipeline( augmentation_config=aug_file.read(), random_seed=config.data.random_seed)