# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
import inspect

import paddle
from paddle import nn
from paddle.nn import functional as F

from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ['CTCLoss', "LabelSmoothingLoss"]


class CTCLoss(nn.Layer):
    def __init__(self,
                 blank=0,
                 reduction='sum',
                 batch_average=False,
                 grad_norm_type=None):
        super().__init__()
        # blank id for CTC (0 by default; some recipes use the last token id)
        self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
        self.batch_average = batch_average

        logger.info(
            f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
        logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")

        assert grad_norm_type in ('instance', 'batch', 'frame', None)
        self.norm_by_times = False
        self.norm_by_batchsize = False
        self.norm_by_total_logits_len = False
        if grad_norm_type is None:
            # no grad norm
            pass
        elif grad_norm_type == 'instance':
            self.norm_by_times = True
        elif grad_norm_type == 'batch':
            self.norm_by_batchsize = True
        elif grad_norm_type == 'frame':
            self.norm_by_total_logits_len = True
        else:
            raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
        kwargs = {
            "norm_by_times": self.norm_by_times,
            "norm_by_batchsize": self.norm_by_batchsize,
            "norm_by_total_logits_len": self.norm_by_total_logits_len,
        }

        # Keep only the kwargs that the underlying forward() actually accepts
        try:
            param = inspect.signature(self.loss.forward).parameters
        except ValueError:
            # Some callables, e.g. built-in functions, cannot be inspected
            param = {}
        self._kwargs = {k: v for k, v in kwargs.items() if k in param}
        _notin = {k: v for k, v in kwargs.items() if k not in param}
        logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")

    def forward(self, logits, ys_pad, hlens, ys_lens):
        """Compute CTC loss.

        Args:
            logits ([paddle.Tensor]): [B, Tmax, D]
            ys_pad ([paddle.Tensor]): [B, Tmax]
            hlens ([paddle.Tensor]): [B]
            ys_lens ([paddle.Tensor]): [B]

        Returns:
            loss (paddle.Tensor): scalar loss; if reduction is 'none', the shape is [B] (batch size).
        """
        B = paddle.shape(logits)[0]
        # warp-ctc need logits, and do softmax on logits by itself
        # warp-ctc need activation with shape [T, B, V + 1]
        # logits: (B, L, D) -> (L, B, D)
        logits = logits.transpose([1, 0, 2])
        ys_pad = ys_pad.astype(paddle.int32)
        loss = self.loss(logits, ys_pad, hlens, ys_lens, **self._kwargs)
        if self.batch_average:
            # Batch-size average
            loss = loss / B
        return loss
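
# Example usage of CTCLoss (an illustrative sketch; the shapes, vocabulary size
# and lengths below are made up, not taken from any PaddleSpeech recipe):
#
#     ctc_loss = CTCLoss(blank=0, reduction='sum', batch_average=True,
#                        grad_norm_type=None)
#     logits = paddle.randn([4, 50, 30])                     # [B, Tmax, D]
#     ys_pad = paddle.randint(1, 30, shape=[4, 12])          # [B, Umax]
#     hlens = paddle.to_tensor([50, 48, 45, 40], 'int64')    # [B]
#     ys_lens = paddle.to_tensor([12, 10, 9, 8], 'int64')    # [B]
#     loss = ctc_loss(logits, ys_pad, hlens, ys_lens)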


class LabelSmoothingLoss(nn.Layer):
    """Label-smoothing loss.
    In a standard CE loss, the label's data distribution is:
        [0,1,2] ->
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
    In the label-smoothed CE loss, some probability mass
    is taken from the true label's probability (1.0) and
    distributed among the other labels.
        e.g.
        smoothing=0.1
        [0,1,2] ->
        [
            [0.9, 0.05, 0.05],
            [0.05, 0.9, 0.05],
            [0.05, 0.05, 0.9],
        ]

    """

    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool=False):
        """Label-smoothing loss.

        Args:
            size (int): the number of classes
            padding_idx (int): padding class id, which is ignored when computing the loss
            smoothing (float): smoothing rate (0.0 means conventional CE)
            normalize_length (bool):
                True, normalize loss by sequence length;
                False, normalize loss by batch size.
                Defaults to False.
        """
        super().__init__()
        self.size = size
        self.padding_idx = padding_idx
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.normalize_length = normalize_length
        self.criterion = nn.KLDivLoss(reduction="none")

    def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:
        """Compute loss between x and target.
        The model output and label tensors are flattened to
        (batch*seqlen, class) shape, and a mask is applied to the
        padding positions, which should not contribute to the loss.

        Args:
            x (paddle.Tensor): prediction (batch, seqlen, class)
            target (paddle.Tensor):
                target signal masked with self.padding_id (batch, seqlen)
        Returns:
            loss (paddle.Tensor): the KL loss, a scalar float value
        """
        B, T, D = paddle.shape(x)
        assert D == self.size
        x = x.reshape((-1, self.size))
        target = target.reshape([-1])

        # build true_dist with full_like instead of in-place ops under
        # no_grad(), since no_grad() cannot be exported by JIT
        true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B*T,)

        #TODO(Hui Zhang): target = target * (1 - ignore)  # avoid -1 index
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        # true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        target_mask = F.one_hot(target, self.size)
        true_dist *= (1 - target_mask)
        true_dist += target_mask * self.confidence
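        # true_dist now holds `confidence` at each target index and
        # smoothing / (size - 1) everywhere else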

        kl = self.criterion(F.log_softmax(x, axis=1), true_dist)

        #TODO(Hui Zhang): sum not support bool type
        #total = len(target) - int(ignore.sum())
        total = len(target) - int(ignore.type_as(target).sum())
        denom = total if self.normalize_length else B
        #numer = (kl * (1 - ignore)).sum()
        numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
        return numer / denom
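

if __name__ == "__main__":
    # Minimal smoke test for LabelSmoothingLoss (an illustrative sketch, not
    # part of the original training code): the vocabulary size, shapes and
    # padding layout are made up, and running it assumes the usual paddlespeech
    # environment, where paddle.Tensor provides the masked_fill/type_as helpers
    # used in forward() above.
    B, U, D = 4, 12, 30
    criterion = LabelSmoothingLoss(size=D, padding_idx=-1, smoothing=0.1)

    dec_out = paddle.randn([B, U, D])             # decoder logits [B, U, D]
    target = paddle.randint(0, D, shape=[B, U])   # token ids [B, U]
    # mark the last two positions of every sequence as padding (-1)
    pad = paddle.full([B, 2], -1, dtype=target.dtype)
    target = paddle.concat([target[:, :-2], pad], axis=1)

    loss = criterion(dec_out, target)
    print("label smoothing loss:", float(loss))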