Commit 30499a76 authored by Hui Zhang

not change ctc grad manual

Parent 67555cb9
@@ -355,37 +355,7 @@ if not hasattr(paddle.Tensor, 'tolist'):
     setattr(paddle.Tensor, 'tolist', tolist)

-########### hcak paddle.nn.functional #############
-# hack loss
-def ctc_loss(logits,
-             labels,
-             input_lengths,
-             label_lengths,
-             blank=0,
-             reduction='mean',
-             norm_by_times=True):
-    #logger.info("my ctc loss with norm by times")
-    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
-    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
-                                           input_lengths, label_lengths)
-    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
-    assert reduction in ['mean', 'sum', 'none']
-    if reduction == 'mean':
-        loss_out = paddle.mean(loss_out / label_lengths)
-    elif reduction == 'sum':
-        loss_out = paddle.sum(loss_out)
-    return loss_out
-
-logger.debug(
-    "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
-)
-F.ctc_loss = ctc_loss
-
-########### hcak paddle.nn #############
+########### hack paddle.nn #############
 from paddle.nn import Layer
 from typing import Optional
 from typing import Mapping
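With the monkey-patched warpctc wrapper removed, callers fall back to the built-in `paddle.nn.functional.ctc_loss`. A minimal sketch of that API for reference (not part of this commit; shapes, sizes, and values below are illustrative assumptions):

```python
# Sketch only: the standard Paddle API that the removed override used to shadow.
import paddle
import paddle.nn.functional as F

T, B, C = 50, 4, 30                      # time steps, batch size, classes (blank = 0); assumed sizes
log_probs = F.log_softmax(paddle.randn([T, B, C]), axis=-1)        # (T, B, C)
labels = paddle.randint(low=1, high=C, shape=[B, 12], dtype='int32')
input_lengths = paddle.full([B], T, dtype='int64')
label_lengths = paddle.full([B], 12, dtype='int64')

loss = F.ctc_loss(log_probs, labels, input_lengths, label_lengths,
                  blank=0, reduction='mean')
```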
@@ -532,3 +502,5 @@ if not hasattr(paddle.nn, 'LayerDict'):
     logger.debug(
         "register user LayerDict to paddle.nn, remove this when fixed!")
     setattr(paddle.nn, 'LayerDict', LayerDict)
@@ -13,6 +13,7 @@
 # limitations under the License.
 import paddle
 from paddle import nn
+from typing import Union
 from paddle.nn import functional as F
 from typeguard import check_argument_types
@@ -40,7 +41,7 @@ class CTCDecoderBase(nn.Layer):
                  dropout_rate: float=0.0,
                  reduction: bool=True,
                  batch_average: bool=True,
-                 grad_norm_type: str="instance"):
+                 grad_norm_type: Union[str, None]=None):
         """CTC decoder

         Args:
@@ -49,7 +50,7 @@ class CTCDecoderBase(nn.Layer):
             dropout_rate (float): dropout rate (0.0 ~ 1.0)
             reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
             batch_average (bool): do batch dim wise average.
-            grad_norm_type (str): one of 'instance', 'batch', 'frame', None.
+            grad_norm_type (str): Default, None. one of 'instance', 'batch', 'frame', None.
         """
         assert check_argument_types()
         super().__init__()
......
@@ -54,7 +54,7 @@ class CTCLoss(nn.Layer):
             self.norm_by_total_logits_len = True
         else:
             raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
-        self.kwargs = {
+        kwargs = {
             "norm_by_times": self.norm_by_times,
             "norm_by_batchsize": self.norm_by_batchsize,
             "norm_by_total_logits_len": self.norm_by_total_logits_len,
@@ -66,10 +66,9 @@
         except ValueError:
             # Some function, e.g. built-in function, are failed
             param = {}
-        self._kwargs = {k: v for k, v in self.kwargs.items() if k in param}
-        _notin = {k: v for k, v in self.kwargs.items() if k not in param}
+        self._kwargs = {k: v for k, v in kwargs.items() if k in param}
+        _notin = {k: v for k, v in kwargs.items() if k not in param}
         logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
-        #self.loss_fn = partial(self.loss.forward, **_kwargs)

     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
@@ -89,8 +88,7 @@
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
         ys_pad = ys_pad.astype(paddle.int32)
-        #loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
-        loss = self.loss(logits, ys_pad, hlens, ys_lens)
+        loss = self.loss(logits, ys_pad, hlens, ys_lens, **self._kwargs)
         if self.batch_average:
             # Batch-size average
             loss = loss / B
......
@@ -68,7 +68,7 @@ model:
     model_conf:
         ctc_weight: 0.3
         ctc_dropoutrate: 0.0
-        ctc_grad_norm_type: instance
+        ctc_grad_norm_type: null
         lsm_weight: 0.1     # label smoothing option
        length_normalized_loss: false
......
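In the config, `null` deserializes to Python `None`, which reaches `grad_norm_type` and turns the extra CTC gradient normalization off. A small sketch of that round trip (the loader and inline snippet are assumptions, not taken from this commit):

```python
import yaml

# Hypothetical snippet mirroring the model_conf block above.
conf_text = """
model_conf:
    ctc_weight: 0.3
    ctc_dropoutrate: 0.0
    ctc_grad_norm_type: null
"""
conf = yaml.safe_load(conf_text)
assert conf["model_conf"]["ctc_grad_norm_type"] is None   # YAML null -> Python None
```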