From f7d7e70cb24338e61e921240c18c20bc88456150 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 27 Sep 2021 10:50:20 +0000 Subject: [PATCH] more ctc check; valid dataloader with num workers --- deepspeech/exps/u2/model.py | 4 +++- deepspeech/modules/ctc.py | 2 +- deepspeech/modules/loss.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 5cb0962a..5cf8866c 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -243,6 +243,7 @@ class U2Trainer(Trainer): self.visualizer.add_scalars( 'epoch', {'cv_loss': cv_loss, 'lr': self.lr_scheduler()}, self.epoch) + self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.new_epoch() @@ -291,7 +292,8 @@ class U2Trainer(Trainer): batch_size=config.collator.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn_dev) + collate_fn=collate_fn_dev, + num_workers=config.collator.num_workers, ) # test dataset, return raw text config.data.manifest = config.data.test_manifest diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 11ce871f..551bbf67 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -49,7 +49,7 @@ class CTCDecoder(nn.Layer): dropout_rate (float): dropout rate (0.0 ~ 1.0) reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none' batch_average (bool): do batch dim wise average. - grad_norm_type (str): one of 'instance', 'batchsize', 'frame', None. + grad_norm_type (str): one of 'instance', 'batch', 'frame', None. """ assert check_argument_types() super().__init__() diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index 7d24e170..1f33e512 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -49,6 +49,8 @@ class CTCLoss(nn.Layer): self.norm_by_batchsize = True elif grad_norm_type == 'frame': self.norm_by_total_logits_len = True + else: + raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}") def forward(self, logits, ys_pad, hlens, ys_lens): """Compute CTC loss. -- GitLab