mlm_loss.py 1.8 KB
Newer Older
小湉湉's avatar
小湉湉 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
import paddle
from paddle import nn


class MLMLoss(nn.Layer):
    """Masked reconstruction loss for speech, with an optional masked text loss.

    The speech term is a per-frame reconstruction loss (L1, or MSE when
    ``lsm_weight > 50``) summed over the feature axis and averaged over the
    frames where ``masked_pos > 0``. When ``text_masking`` is enabled, a
    per-token cross-entropy averaged over the positions weighted by
    ``text_masked_pos`` is returned as a second value.

    Args:
        lsm_weight: Only used here as a switch. Values above 50 select MSE
            instead of L1 for the speech term.
        ignore_id: Label id ignored by the text cross-entropy.
        text_masking: Whether to compute and return the text loss.
        odim: Feature size used to flatten ``speech`` / ``before_outs`` /
            ``after_outs``. If ``None``, it is taken from the last axis of
            ``speech`` at call time.
        vocab_size: Number of classes used to flatten ``text_outs``. If
            ``None``, it is taken from the last axis of ``text_outs`` at
            call time.
    """

    def __init__(self,
                 lsm_weight: float=0.1,
                 ignore_id: int=-1,
                 text_masking: bool=False,
                 odim: int=None,
                 vocab_size: int=None):
        super().__init__()
        if text_masking:
            # reduction='none' keeps one loss value per token so that the
            # text mask can select which tokens contribute to the average.
            self.text_mlm_loss = nn.CrossEntropyLoss(
                ignore_index=ignore_id, reduction='none')
        # NOTE(review): lsm_weight is reused as a loss-type switch; the
        # threshold of 50 is preserved from the original code.
        if lsm_weight > 50:
            # reduction='none' so the frame mask is applied per position, as in
            # the L1 branch (the default 'mean' collapsed the loss to a scalar
            # before masking).
            self.l1_loss_func = nn.MSELoss(reduction='none')
        else:
            self.l1_loss_func = nn.L1Loss(reduction='none')
        self.text_masking = text_masking
        # Previously read in forward() but never assigned; None means "infer
        # from the input tensors".
        self.odim = odim
        self.vocab_size = vocab_size

    @staticmethod
    def _masked_mean(loss: paddle.Tensor,
                     mask: paddle.Tensor) -> paddle.Tensor:
        """Average a flat per-position ``loss``, weighted by the flattened ``mask``."""
        # Cast so that bool/int masks can be multiplied with the float loss.
        weights = paddle.cast(paddle.reshape(mask, [-1]), loss.dtype)
        # Epsilon is added once to the total weight, guarding against an
        # all-zero mask.
        return paddle.sum(loss * weights) / (paddle.sum(weights) + 1e-10)

    def _frame_loss(self,
                    pred: paddle.Tensor,
                    target: paddle.Tensor,
                    odim: int) -> paddle.Tensor:
        """Return the reconstruction loss per flattened row, summed over ``odim`` features."""
        return paddle.sum(
            self.l1_loss_func(
                paddle.reshape(pred, [-1, odim]),
                paddle.reshape(target, [-1, odim])),
            axis=-1)

    def forward(self,
                speech: paddle.Tensor,
                before_outs: paddle.Tensor,
                after_outs: paddle.Tensor,
                masked_pos: paddle.Tensor,
                text: paddle.Tensor=None,
                text_outs: paddle.Tensor=None,
                text_masked_pos: paddle.Tensor=None):
        """Compute the masked speech loss and, optionally, the masked text loss.

        Args:
            speech: Target features; flattened to ``[-1, odim]``.
            before_outs: Predicted features, flattened like ``speech``.
            after_outs: Optional second prediction (e.g. post-net output);
                its loss is added when not ``None``.
            masked_pos: Frame positions to score; entries ``> 0`` count.
                Must flatten to one entry per row of the flattened ``speech``.
            text: Target token ids (required if ``text_masking``).
            text_outs: Token logits, flattened to ``[-1, vocab_size]``
                (required if ``text_masking``).
            text_masked_pos: Per-token weights/mask for the text loss
                (required if ``text_masking``).

        Returns:
            ``loss_mlm`` alone, or ``(loss_mlm, loss_text)`` when
            ``text_masking`` is enabled.

        Raises:
            ValueError: If ``text_masking`` is enabled but any text input is
                missing.
        """
        odim = self.odim if self.odim is not None else speech.shape[-1]

        loss = self._frame_loss(before_outs, speech, odim)
        if after_outs is not None:
            loss = loss + self._frame_loss(after_outs, speech, odim)
        loss_mlm = self._masked_mean(loss, masked_pos > 0)

        if not self.text_masking:
            return loss_mlm

        if text is None or text_outs is None or text_masked_pos is None:
            raise ValueError("text, text_outs and text_masked_pos are required "
                             "when text_masking=True")
        vocab_size = (self.vocab_size
                      if self.vocab_size is not None else text_outs.shape[-1])
        # Flatten to [-1] so that the per-token loss lines up with the
        # flattened mask regardless of whether cross_entropy keeps a trailing
        # singleton axis.
        token_loss = paddle.reshape(
            self.text_mlm_loss(
                paddle.reshape(text_outs, [-1, vocab_size]),
                paddle.reshape(text, [-1])), [-1])
        loss_text = self._masked_mean(token_loss, text_masked_pos)
        return loss_mlm, loss_text