提交 26910132 编写于 作者: H Hui Zhang

ctcloss can work w/ paddle2.1.2, but loss larger than before

上级 86a221d5
......@@ -353,3 +353,31 @@ if not hasattr(paddle.Tensor, 'tolist'):
logger.debug(
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
# hack loss
def ctc_loss(logits,
labels,
input_lengths,
label_lengths,
blank=0,
reduction='mean',
norm_by_times=True):
#logger.info("my ctc loss with norm by times")
## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
input_lengths, label_lengths)
loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
assert reduction in ['mean', 'sum', 'none']
if reduction == 'mean':
loss_out = paddle.mean(loss_out / label_lengths)
elif reduction == 'sum':
loss_out = paddle.sum(loss_out)
return loss_out
logger.debug(
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
)
F.ctc_loss = ctc_loss
......@@ -67,10 +67,10 @@ class CTCLoss(nn.Layer):
except ValueError:
# Some function, e.g. built-in function, are failed
param = {}
_kwargs = {k: v for k, v in self.kwargs.items() if k in param}
self._kwargs = {k: v for k, v in self.kwargs.items() if k in param}
_notin = {k: v for k, v in self.kwargs.items() if k not in param}
logger.info(f"{self.loss} kwargs:{_kwargs}, not support: {_notin}")
self.loss_fn = partial(self.loss.forward, **_kwargs)
logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
#self.loss_fn = partial(self.loss.forward, **_kwargs)
def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss.
......@@ -90,7 +90,8 @@ class CTCLoss(nn.Layer):
# logits: (B, L, D) -> (L, B, D)
logits = logits.transpose([1, 0, 2])
ys_pad = ys_pad.astype(paddle.int32)
loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
#loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
loss = self.loss(logits, ys_pad, hlens, ys_lens)
if self.batch_average:
# Batch-size average
loss = loss / B
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册