From 5ffccbd5c923ed726f92a960eadf9d1a183d8e79 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 26 Oct 2021 09:38:54 +0000
Subject: [PATCH] exp with eval mode

---
 deepspeech/exps/deepspeech2/model.py | 128 ++++++++++++++-------
 deepspeech/exps/u2/model.py          |   2 +-
 deepspeech/exps/u2_kaldi/model.py    |   2 +-
 deepspeech/exps/u2_st/model.py       |   2 +-
 deepspeech/training/trainer.py       |  10 ++-
 5 files changed, 78 insertions(+), 66 deletions(-)

diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 7b929f8b..6424cfdf 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -167,6 +167,11 @@ class DeepSpeech2Trainer(Trainer):
         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)
 
+        self.model = model
+        logger.info("Setup model!")
+
+        if not self.train:
+            return
 
         grad_clip = ClipGradByGlobalNormWithLog(
             config.training.global_grad_clip)
@@ -180,74 +185,77 @@ class DeepSpeech2Trainer(Trainer):
             weight_decay=paddle.regularizer.L2Decay(
                 config.training.weight_decay),
             grad_clip=grad_clip)
-
-        self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
-        logger.info("Setup model/optimizer/lr_scheduler!")
+        logger.info("Setup optimizer/lr_scheduler!")
+
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
-        config.collator.keep_transcription_text = False
-
-        config.data.manifest = config.data.train_manifest
-        train_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.dev_manifest
-        dev_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.test_manifest
-        test_dataset = ManifestDataset.from_config(config)
-
-        if self.parallel:
-            batch_sampler = SortagradDistributedBatchSampler(
+        if self.train:
+            # train
+            config.data.manifest = config.data.train_manifest
+            train_dataset = ManifestDataset.from_config(config)
+            if self.parallel:
+                batch_sampler = SortagradDistributedBatchSampler(
+                    train_dataset,
+                    batch_size=config.collator.batch_size,
+                    num_replicas=None,
+                    rank=None,
+                    shuffle=True,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+            else:
+                batch_sampler = SortagradBatchSampler(
+                    train_dataset,
+                    shuffle=True,
+                    batch_size=config.collator.batch_size,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+
+            config.collator.keep_transcription_text = False
+            collate_fn_train = SpeechCollator.from_config(config)
+            self.train_loader = DataLoader(
                 train_dataset,
-                batch_size=config.collator.batch_size,
-                num_replicas=None,
-                rank=None,
-                shuffle=True,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
+                batch_sampler=batch_sampler,
+                collate_fn=collate_fn_train,
+                num_workers=config.collator.num_workers)
+
+            # dev
+            config.data.manifest = config.data.dev_manifest
+            dev_dataset = ManifestDataset.from_config(config)
+
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = False
+            collate_fn_dev = SpeechCollator.from_config(config)
+            self.valid_loader = DataLoader(
+                dev_dataset,
+                batch_size=int(config.collator.batch_size),
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_dev,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup train/valid Dataloader!")
         else:
-            batch_sampler = SortagradBatchSampler(
-                train_dataset,
-                shuffle=True,
-                batch_size=config.collator.batch_size,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-
-        collate_fn_train = SpeechCollator.from_config(config)
-
-        config.collator.augmentation_config = ""
-        collate_fn_dev = SpeechCollator.from_config(config)
-
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        collate_fn_test = SpeechCollator.from_config(config)
-
-        self.train_loader = DataLoader(
-            train_dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=collate_fn_train,
-            num_workers=config.collator.num_workers)
-        self.valid_loader = DataLoader(
-            dev_dataset,
-            batch_size=int(config.collator.batch_size),
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_dev,
-            num_workers=config.collator.num_workers)
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_test,
-            num_workers=config.collator.num_workers)
-        logger.info("Setup train/valid/test Dataloader!")
+            # test
+            config.data.manifest = config.data.test_manifest
+            test_dataset = ManifestDataset.from_config(config)
+
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = True
+            collate_fn_test = SpeechCollator.from_config(config)
+
+            self.test_loader = DataLoader(
+                test_dataset,
+                batch_size=config.decoding.batch_size,
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_test,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup test Dataloader!")
 
 
 class DeepSpeech2Tester(DeepSpeech2Trainer):
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 7806aaa4..e47a59ed 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -172,7 +172,7 @@ class U2Trainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index f8624326..663c36d8 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -173,7 +173,7 @@ class U2Trainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py
index c5df44c6..1f638e64 100644
--- a/deepspeech/exps/u2_st/model.py
+++ b/deepspeech/exps/u2_st/model.py
@@ -184,7 +184,7 @@ class U2STTrainer(Trainer):
                         dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 2c238920..2da83804 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -134,6 +134,10 @@ class Trainer():
         logger.info(
             f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
 
+    @property
+    def train(self):
+        return self._train
+
     @contextmanager
     def eval(self):
         self._train = False
@@ -248,7 +252,7 @@ class Trainer():
                 sys.exit(
                     f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
 
-    def train(self):
+    def do_train(self):
         """The training process control by epoch."""
         self.before_train()
 
@@ -321,7 +325,7 @@ class Trainer():
         """
         try:
             with Timer("Training Done: {}"):
-                self.train()
+                self.do_train()
         except KeyboardInterrupt:
             exit(-1)
         finally:
@@ -432,7 +436,7 @@ class Trainer():
         beginning of the experiment.
         """
         config_file = self.config_dir / "config.yaml"
-        if self._train and config_file.exists():
+        if self.train and config_file.exists():
             time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime())
             target_path = self.config_dir / ".".join(
                 [time_stamp, "config.yaml"])
-- 
GitLab
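
Note (not part of the patch): a minimal, runnable sketch of the train/eval-mode
gating this commit introduces. The Trainer below is a simplified stand-in, not
the project's real class (the real setup methods read the config and build
Paddle models, samplers, and dataloaders); only the control flow mirrors the
diff above: a read-only `train` property over `_train`, an `eval()` context
manager that flips the flag, and the old `train()` entry point renamed to
`do_train()` so the name is free for the property.

from contextlib import contextmanager


class Trainer:
    """Simplified stand-in for deepspeech.training.trainer.Trainer."""

    def __init__(self):
        self._train = True  # default to train mode

    @property
    def train(self):
        # Read-only mode flag; renaming train() to do_train() avoids a
        # name clash between the old method and this property.
        return self._train

    @contextmanager
    def eval(self):
        # Switch to eval mode for the duration of the with-block.
        self._train = False
        yield

    def setup_model(self):
        print("build model")
        if not self.train:
            return  # eval mode: optimizer/lr_scheduler are not needed
        print("build optimizer and lr_scheduler")

    def setup_dataloader(self):
        if self.train:
            print("build train/valid dataloaders")
        else:
            print("build test dataloader only")

    def do_train(self):
        print("run training loop")


# Trainer path: full setup, then the renamed training entry point.
exp = Trainer()
exp.setup_model()
exp.setup_dataloader()
exp.do_train()

# Tester path: eval() flips the mode, so setup skips the optimizer and
# builds only the test dataloader, matching the `if not self.train: return`
# and `if self.train: ... else: ...` branches in the diff.
tester = Trainer()
with tester.eval():
    tester.setup_model()
    tester.setup_dataloader()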