diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index 5f9ba007e3e3c81b356a5a120dcf5485bfbce727..6ed1177a51ee439566c6faee8bbc8b5c9848a362 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -355,37 +355,7 @@ if not hasattr(paddle.Tensor, 'tolist'): setattr(paddle.Tensor, 'tolist', tolist) - -########### hcak paddle.nn.functional ############# -# hack loss -def ctc_loss(logits, - labels, - input_lengths, - label_lengths, - blank=0, - reduction='mean', - norm_by_times=True): - #logger.info("my ctc loss with norm by times") - ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403 - loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times, - input_lengths, label_lengths) - - loss_out = paddle.fluid.layers.squeeze(loss_out, [-1]) - assert reduction in ['mean', 'sum', 'none'] - if reduction == 'mean': - loss_out = paddle.mean(loss_out / label_lengths) - elif reduction == 'sum': - loss_out = paddle.sum(loss_out) - return loss_out - - -logger.debug( - "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!" -) -F.ctc_loss = ctc_loss - - -########### hcak paddle.nn ############# +########### hack paddle.nn ############# from paddle.nn import Layer from typing import Optional from typing import Mapping @@ -532,3 +502,5 @@ if not hasattr(paddle.nn, 'LayerDict'): logger.debug( "register user LayerDict to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'LayerDict', LayerDict) + + diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 565a11e15c5d078e07e84a7c60a0b2baaf6b19f1..e0c8006d17272dde54ac948b9f112fa1d67ba92a 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -13,6 +13,7 @@ # limitations under the License. import paddle from paddle import nn +from typing import Union from paddle.nn import functional as F from typeguard import check_argument_types @@ -40,7 +41,7 @@ class CTCDecoderBase(nn.Layer): dropout_rate: float=0.0, reduction: bool=True, batch_average: bool=True, - grad_norm_type: str="instance"): + grad_norm_type: Union[str, None]=None): """CTC decoder Args: @@ -49,7 +50,7 @@ class CTCDecoderBase(nn.Layer): dropout_rate (float): dropout rate (0.0 ~ 1.0) reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none' batch_average (bool): do batch dim wise average. - grad_norm_type (str): one of 'instance', 'batch', 'frame', None. + grad_norm_type (str): Default, None. one of 'instance', 'batch', 'frame', None. """ assert check_argument_types() super().__init__() diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index e06f26f8166325b4e115918e3971a5dc4247257b..e11388107f5a9fc3199192f6f7881817cb3680a5 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -54,7 +54,7 @@ class CTCLoss(nn.Layer): self.norm_by_total_logits_len = True else: raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}") - self.kwargs = { + kwargs = { "norm_by_times": self.norm_by_times, "norm_by_batchsize": self.norm_by_batchsize, "norm_by_total_logits_len": self.norm_by_total_logits_len, @@ -66,10 +66,9 @@ class CTCLoss(nn.Layer): except ValueError: # Some function, e.g. built-in function, are failed param = {} - self._kwargs = {k: v for k, v in self.kwargs.items() if k in param} - _notin = {k: v for k, v in self.kwargs.items() if k not in param} + self._kwargs = {k: v for k, v in kwargs.items() if k in param} + _notin = {k: v for k, v in kwargs.items() if k not in param} logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}") - #self.loss_fn = partial(self.loss.forward, **_kwargs) def forward(self, logits, ys_pad, hlens, ys_lens): """Compute CTC loss. @@ -89,8 +88,7 @@ class CTCLoss(nn.Layer): # logits: (B, L, D) -> (L, B, D) logits = logits.transpose([1, 0, 2]) ys_pad = ys_pad.astype(paddle.int32) - #loss = self.loss_fn(logits, ys_pad, hlens, ys_lens) - loss = self.loss(logits, ys_pad, hlens, ys_lens) + loss = self.loss(logits, ys_pad, hlens, ys_lens, **self._kwargs) if self.batch_average: # Batch-size average loss = loss / B diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index c9dc1413b36f02c673dcd942323af666fcf5ff35..3cc17004c0ac103efd16e5d4899b910805faeda5 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -68,7 +68,7 @@ model: model_conf: ctc_weight: 0.3 ctc_dropoutrate: 0.0 - ctc_grad_norm_type: instance + ctc_grad_norm_type: null lsm_weight: 0.1 # label smoothing option length_normalized_loss: false