diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 7b929f8b701b57ce399013b36fcda867d1287fff..6424cfdf389ea2d62f80a6577e4f4270ccebca46 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -167,6 +167,11 @@ class DeepSpeech2Trainer(Trainer):
         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)

+        self.model = model
+        logger.info("Setup model!")
+
+        if not self.train:
+            return

         grad_clip = ClipGradByGlobalNormWithLog(
             config.training.global_grad_clip)
@@ -180,74 +185,77 @@ class DeepSpeech2Trainer(Trainer):
             weight_decay=paddle.regularizer.L2Decay(
                 config.training.weight_decay),
             grad_clip=grad_clip)
-
-        self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
-        logger.info("Setup model/optimizer/lr_scheduler!")
+        logger.info("Setup optimizer/lr_scheduler!")
+
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
-        config.collator.keep_transcription_text = False
-
-        config.data.manifest = config.data.train_manifest
-        train_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.dev_manifest
-        dev_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.test_manifest
-        test_dataset = ManifestDataset.from_config(config)
-
-        if self.parallel:
-            batch_sampler = SortagradDistributedBatchSampler(
+        if self.train:
+            # train
+            config.data.manifest = config.data.train_manifest
+            train_dataset = ManifestDataset.from_config(config)
+            if self.parallel:
+                batch_sampler = SortagradDistributedBatchSampler(
+                    train_dataset,
+                    batch_size=config.collator.batch_size,
+                    num_replicas=None,
+                    rank=None,
+                    shuffle=True,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+            else:
+                batch_sampler = SortagradBatchSampler(
+                    train_dataset,
+                    shuffle=True,
+                    batch_size=config.collator.batch_size,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+
+            config.collator.keep_transcription_text = False
+            collate_fn_train = SpeechCollator.from_config(config)
+            self.train_loader = DataLoader(
                 train_dataset,
-                batch_size=config.collator.batch_size,
-                num_replicas=None,
-                rank=None,
-                shuffle=True,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
+                batch_sampler=batch_sampler,
+                collate_fn=collate_fn_train,
+                num_workers=config.collator.num_workers)
+
+            # dev
+            config.data.manifest = config.data.dev_manifest
+            dev_dataset = ManifestDataset.from_config(config)
+
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = False
+            collate_fn_dev = SpeechCollator.from_config(config)
+            self.valid_loader = DataLoader(
+                dev_dataset,
+                batch_size=int(config.collator.batch_size),
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_dev,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup train/valid Dataloader!")
         else:
-            batch_sampler = SortagradBatchSampler(
-                train_dataset,
-                shuffle=True,
-                batch_size=config.collator.batch_size,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-
-        collate_fn_train = SpeechCollator.from_config(config)
-
-        config.collator.augmentation_config = ""
-        collate_fn_dev = SpeechCollator.from_config(config)
-
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        collate_fn_test = SpeechCollator.from_config(config)
-
-        self.train_loader = DataLoader(
-            train_dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=collate_fn_train,
-            num_workers=config.collator.num_workers)
-        self.valid_loader = DataLoader(
-            dev_dataset,
-            batch_size=int(config.collator.batch_size),
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_dev,
-            num_workers=config.collator.num_workers)
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_test,
-            num_workers=config.collator.num_workers)
-        logger.info("Setup train/valid/test Dataloader!")
+            # test
+            config.data.manifest = config.data.test_manifest
+            test_dataset = ManifestDataset.from_config(config)
+
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = True
+            collate_fn_test = SpeechCollator.from_config(config)
+
+            self.test_loader = DataLoader(
+                test_dataset,
+                batch_size=config.decoding.batch_size,
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_test,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup test Dataloader!")


 class DeepSpeech2Tester(DeepSpeech2Trainer):
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 7806aaa491bcde1c26969ead6e4c8032e6aec665..e47a59edaf0435578b57edfc37222acca7df2de2 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -172,7 +172,7 @@ class U2Trainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts

-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index f86243269ef5eec4955e82a155897c5882c45938..663c36d8b41f01d73cac5f9cabfee3fe99021144 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -173,7 +173,7 @@ class U2Trainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts

-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py
index c5df44c6704678c0642a995b44c324693bd4e4b4..1f638e64c082f8e8bb7bd9fc8c4be7a2b53f529d 100644
--- a/deepspeech/exps/u2_st/model.py
+++ b/deepspeech/exps/u2_st/model.py
@@ -184,7 +184,7 @@ class U2STTrainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts

-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 2c2389203a8989331d43d21edae80f24f55d6c8f..2da838047206fc34011986daf33aa4203c5e92dd 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -134,6 +134,10 @@ class Trainer():
             logger.info(
                 f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")

+    @property
+    def train(self):
+        return self._train
+
     @contextmanager
     def eval(self):
         self._train = False
@@ -248,7 +252,7 @@ class Trainer():
                 sys.exit(
                     f"Reach benchmark-max-step: {self.args.benchmark_max_step}")

-    def train(self):
+    def do_train(self):
         """The training process control by epoch."""
         self.before_train()

@@ -321,7 +325,7 @@ class Trainer():
         """
         try:
             with Timer("Training Done: {}"):
-                self.train()
+                self.do_train()
         except KeyboardInterrupt:
             exit(-1)
         finally:
@@ -432,7 +436,7 @@ class Trainer():
         beginning of the experiment.
         """
         config_file = self.config_dir / "config.yaml"
-        if self._train and config_file.exists():
+        if self.train and config_file.exists():
             time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime())
             target_path = self.config_dir / ".".join(
                 [time_stamp, "config.yaml"])