# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn


class CTCLoss(nn.Layer):
    """CTC loss with optional focal-style per-sample reweighting.

    Args:
        use_focal_loss: if True, each per-sample CTC loss is multiplied by
            ``focal_loss_alpha * (1 - exp(-loss)) ** 2`` before averaging.
        focal_loss_alpha: scalar factor applied to the focal weight. It is only
            used when ``use_focal_loss`` is True. The default of 1.0 leaves the
            focal weight unscaled.
        **kwargs: accepted and ignored. NOTE(review): presumably so a shared
            config dict can be splatted into the constructor; confirm with
            callers.
    """

    def __init__(self, use_focal_loss=False, focal_loss_alpha=1.0, **kwargs):
        super(CTCLoss, self).__init__()
        # Blank label is index 0. Use reduction='none' so per-sample losses are
        # kept, letting the optional focal weighting act before reduction.
        self.loss_func = nn.CTCLoss(blank=0, reduction='none')
        self.use_focal_loss = use_focal_loss
        # Must always be assigned: forward() reads it whenever use_focal_loss
        # is True, and a missing attribute would raise AttributeError there.
        self.focal_loss_alpha = focal_loss_alpha

    def forward(self, predicts, batch):
        """Compute the mean CTC loss for a batch.

        Args:
            predicts: model output tensor, or a list/tuple whose last element
                is used. It is transposed with axes (1, 0, 2), so it is assumed
                to be (batch, time, classes); TODO confirm against the model
                head.
            batch: indexable. ``batch[1]`` holds the label ids (cast to int32)
                and ``batch[2]`` the label lengths (cast to int64).

        Returns:
            dict with key ``'loss'`` holding the mean of the (optionally
            focal-weighted) per-sample CTC losses.
        """
        if isinstance(predicts, (list, tuple)):
            predicts = predicts[-1]
        # After the transpose, the leading axis is time (N) and the second is
        # batch (B).
        predicts = predicts.transpose((1, 0, 2))
        N, B, _ = predicts.shape
        # Every sample is treated as using the full N time steps.
        preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
        labels = batch[1].astype("int32")
        label_lengths = batch[2].astype('int64')
        loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
        if self.use_focal_loss:
            # Focal-style weight: alpha * (1 - exp(-loss))^2. Samples with a
            # higher loss receive a larger weight.
            weight = paddle.exp(-loss)
            weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
            weight = paddle.square(weight) * self.focal_loss_alpha
            loss = paddle.multiply(loss, weight)
        loss = loss.mean()  # mean over the batch
        return {'loss': loss}