提交 c706dfec 编写于 作者: H Haoxin Ma

fix bug

上级 279348d7
...@@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer): ...@@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.data.sortagrad, sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method) shuffle_method=config.data.shuffle_method)
collate_fn = SpeechCollator(config, keep_transcription_text=False) collate_fn = SpeechCollator(config=config, keep_transcription_text=False)
self.train_loader = DataLoader( self.train_loader = DataLoader(
train_dataset, train_dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
...@@ -342,7 +342,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -342,7 +342,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True)) collate_fn=SpeechCollator(config=config, keep_transcription_text=True))
logger.info("Setup test Dataloader!") logger.info("Setup test Dataloader!")
def setup_output_dir(self): def setup_output_dir(self):
......
...@@ -23,6 +23,8 @@ from deepspeech.frontend.speech import SpeechSegment ...@@ -23,6 +23,8 @@ from deepspeech.frontend.speech import SpeechSegment
import io import io
import time import time
from collections import namedtuple
__all__ = ["SpeechCollator"] __all__ = ["SpeechCollator"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
...@@ -50,7 +52,7 @@ class SpeechCollator(): ...@@ -50,7 +52,7 @@ class SpeechCollator():
aug_file = config.data.augmentation_config aug_file = config.data.augmentation_config
assert isinstance(aug_file, io.StringIO) assert isinstance(aug_file, io.StringIO)
self._local_data = TarLocalData(tar2info={}, tar2object={} self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline( self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(), augmentation_config=aug_file.read(),
random_seed=config.data.random_seed) random_seed=config.data.random_seed)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册