From 269101323233c681a58c331baa16c2f15fca2d6e Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Fri, 8 Oct 2021 03:24:51 +0000
Subject: [PATCH] ctcloss can work w/ paddle2.1.2, but loss larger than before

---
 deepspeech/__init__.py     | 28 ++++++++++++++++++++++++++++
 deepspeech/modules/loss.py |  9 +++++----
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 5505ecbf..493f10a6 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -353,3 +353,31 @@ if not hasattr(paddle.Tensor, 'tolist'):
     logger.debug(
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
+
+
+# hack loss
+def ctc_loss(logits,
+             labels,
+             input_lengths,
+             label_lengths,
+             blank=0,
+             reduction='mean',
+             norm_by_times=True):
+    #logger.info("my ctc loss with norm by times")
+    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
+    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
+                                           input_lengths, label_lengths)
+
+    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
+    assert reduction in ['mean', 'sum', 'none']
+    if reduction == 'mean':
+        loss_out = paddle.mean(loss_out / label_lengths)
+    elif reduction == 'sum':
+        loss_out = paddle.sum(loss_out)
+    return loss_out
+
+
+logger.debug(
+    "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
+)
+F.ctc_loss = ctc_loss
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index df5298ea..71ecd266 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -67,10 +67,10 @@ class CTCLoss(nn.Layer):
         except ValueError:
             # Some function, e.g. built-in function, are failed
             param = {}
-        _kwargs = {k: v for k, v in self.kwargs.items() if k in param}
+        self._kwargs = {k: v for k, v in self.kwargs.items() if k in param}
         _notin = {k: v for k, v in self.kwargs.items() if k not in param}
-        logger.info(f"{self.loss} kwargs:{_kwargs}, not support: {_notin}")
-        self.loss_fn = partial(self.loss.forward, **_kwargs)
+        logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
+        #self.loss_fn = partial(self.loss.forward, **_kwargs)

     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
@@ -90,7 +90,8 @@ class CTCLoss(nn.Layer):
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
         ys_pad = ys_pad.astype(paddle.int32)
-        loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
+        #loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
+        loss = self.loss(logits, ys_pad, hlens, ys_lens)
         if self.batch_average:
             # Batch-size average
             loss = loss / B
--
GitLab
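
Note (reviewer sketch, not part of the patch): a minimal usage example of the
overridden F.ctc_loss, assuming the deepspeech package is importable so that
deepspeech/__init__.py installs the override at import time. warpctc consumes
time-major logits of shape (T, B, D), the same layout CTCLoss.forward produces
with its (B, L, D) -> (L, B, D) transpose. All tensor sizes below are
illustrative assumptions.

    import paddle
    import paddle.nn.functional as F

    import deepspeech  # importing runs the F.ctc_loss override in __init__.py

    T, B, D = 20, 4, 10  # time steps, batch size, vocab size (blank at index 0)
    logits = paddle.randn([T, B, D])                            # time-major activations
    labels = paddle.randint(1, D, shape=[B, 5], dtype='int32')  # non-blank label ids
    input_lengths = paddle.full([B], T, dtype='int64')
    label_lengths = paddle.full([B], 5, dtype='int64')

    # Dispatches to the warpctc-backed override: norm_by_times defaults to True,
    # and 'mean' divides each utterance loss by its label length before averaging.
    loss = F.ctc_loss(logits, labels, input_lengths, label_lengths,
                      blank=0, reduction='mean')
    print(float(loss))

Unlike paddle's built-in F.ctc_loss, this override passes norm_by_times=True to
warpctc, which may account for the different loss scale the subject line notes.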