diff --git a/paddlespeech/t2s/exps/ernie_sat/train.py b/paddlespeech/t2s/exps/ernie_sat/train.py
index 977b8fc52b426f2779a19225067717a4f5efbedb..020b0d0fa06f34e280d269adb5624dd416352af7 100644
--- a/paddlespeech/t2s/exps/ernie_sat/train.py
+++ b/paddlespeech/t2s/exps/ernie_sat/train.py
@@ -154,6 +154,7 @@ def train_sp(args, config):
         dataloader=train_dataloader,
         text_masking=config["model"]["text_masking"],
         odim=odim,
+        vocab_size=vocab_size,
         output_dir=output_dir)

     trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
@@ -163,6 +164,7 @@ def train_sp(args, config):
         dataloader=dev_dataloader,
         text_masking=config["model"]["text_masking"],
         odim=odim,
+        vocab_size=vocab_size,
         output_dir=output_dir, )

     if dist.get_rank() == 0:
diff --git a/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py b/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
index 17cfaae966bf53e1a2ac508fd6e367cb337d35c8..219341c8860f1b7e85ea1136523bb6d54e40f4ce 100644
--- a/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
+++ b/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
@@ -40,11 +40,13 @@ class ErnieSATUpdater(StandardUpdater):
                  init_state=None,
                  text_masking: bool=False,
                  odim: int=80,
+                 vocab_size: int=100,
                  output_dir: Path=None):
         super().__init__(model, optimizer, dataloader, init_state=None)
         self.scheduler = scheduler

-        self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
+        self.criterion = MLMLoss(
+            text_masking=text_masking, odim=odim, vocab_size=vocab_size)

         log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
         self.filehandler = logging.FileHandler(str(log_file))
@@ -104,6 +106,7 @@
                  dataloader: DataLoader,
                  text_masking: bool=False,
                  odim: int=80,
+                 vocab_size: int=100,
                  output_dir: Path=None):
         super().__init__(model, dataloader)

@@ -113,7 +116,8 @@
         self.logger = logger
         self.msg = ""

-        self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
+        self.criterion = MLMLoss(
+            text_masking=text_masking, odim=odim, vocab_size=vocab_size)

     def evaluate_core(self, batch):
         self.msg = "Evaluate: "
diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py
index 95f2ff8649bc67a58785e417c50aa1fd937aa243..b3cf45aafd920ec8bc5e3d7a77249358dc57cab5 100644
--- a/paddlespeech/t2s/modules/losses.py
+++ b/paddlespeech/t2s/modules/losses.py
@@ -1013,6 +1013,7 @@ class KLDivergenceLoss(nn.Layer):
 class MLMLoss(nn.Layer):
     def __init__(self,
                  odim: int,
+                 vocab_size: int=0,
                  lsm_weight: float=0.1,
                  ignore_id: int=-1,
                  text_masking: bool=False):
@@ -1025,6 +1026,7 @@
         self.l1_loss_func = nn.L1Loss(reduction='none')
         self.text_masking = text_masking
         self.odim = odim
+        self.vocab_size = vocab_size

     def forward(
             self,
@@ -1059,10 +1061,12 @@
             assert text is not None
             assert text_outs is not None
             assert text_masked_pos is not None
-            text_mlm_loss = paddle.sum((self.text_mlm_loss(
-                paddle.reshape(text_outs, (-1, self.vocab_size)),
-                paddle.reshape(text, (-1))) * paddle.reshape(
-                    text_masked_pos,
-                    (-1)))) / paddle.sum((text_masked_pos) + 1e-10)
+            text_outs = paddle.reshape(text_outs, [-1, self.vocab_size])
+            text = paddle.reshape(text, [-1])
+            text_mlm_loss = self.text_mlm_loss(text_outs, text)
+            text_masked_pos_reshape = paddle.reshape(text_masked_pos, [-1])
+            text_mlm_loss = paddle.sum(
+                text_mlm_loss *
+                text_masked_pos_reshape) / paddle.sum((text_masked_pos) + 1e-10)

         return mlm_loss, text_mlm_loss
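Reviewer note on the losses.py hunks above: the rewrite only names the intermediates; numerically it is the same masked cross-entropy as before. It also makes the branch runnable at all, since the old one-liner read self.vocab_size, which MLMLoss never stored until this patch adds the vocab_size argument and the self.vocab_size = vocab_size assignment. A minimal standalone sketch of the computation, with assumed shapes and a plain CrossEntropyLoss standing in for self.text_mlm_loss (the real criterion also passes ignore_index=ignore_id):

import paddle
import paddle.nn as nn

# Illustrative shapes only; vocab_size mirrors the new MLMLoss argument.
batch, seq_len, vocab_size = 2, 5, 100
text_outs = paddle.randn([batch, seq_len, vocab_size])  # token logits
text = paddle.randint(0, vocab_size, [batch, seq_len])  # target token ids
text_masked_pos = paddle.cast(
    paddle.randint(0, 2, [batch, seq_len]), 'float32')  # 1.0 = masked

# Same computation as the refactored branch: score every position with
# per-token cross-entropy, then average over the masked positions only.
ce = nn.CrossEntropyLoss(reduction='none')
loss_all = ce(
    paddle.reshape(text_outs, [-1, vocab_size]), paddle.reshape(text, [-1]))
mask_flat = paddle.reshape(text_masked_pos, [-1])
# The 1e-10 term guards against division by zero when nothing is masked.
text_mlm_loss = paddle.sum(loss_all * mask_flat) / (
    paddle.sum(mask_flat) + 1e-10)
print(float(text_mlm_loss))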
diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py
index 608a47421d7ec2b94df1b3dbc5ab84d026f78dd3..1490ae836b442c692f5c57af6818ff8c1bb8611d 100644
--- a/paddlespeech/t2s/modules/nets_utils.py
+++ b/paddlespeech/t2s/modules/nets_utils.py
@@ -464,14 +464,15 @@ def phones_text_masking(xs_pad: paddle.Tensor,
                 set(range(length)) - set(masked_phn_idxs[0].tolist()))
             np.random.shuffle(unmasked_phn_idxs)
             masked_text_idxs = unmasked_phn_idxs[:text_mask_num_lower]
-            text_masked_pos[idx][masked_text_idxs] = 1
+            text_masked_pos[idx, masked_text_idxs] = 1
             masked_start = align_start[idx][masked_phn_idxs].tolist()
             masked_end = align_end[idx][masked_phn_idxs].tolist()
             for s, e in zip(masked_start, masked_end):
                 masked_pos[idx, s:e] = 1
-    non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
+    non_eos_mask = paddle.reshape(src_mask, shape=paddle.shape(xs_pad)[:2])
     masked_pos = masked_pos * non_eos_mask
-    non_eos_text_mask = paddle.reshape(text_mask, paddle.shape(xs_pad)[:2])
+    non_eos_text_mask = paddle.reshape(
+        text_mask, shape=paddle.shape(text_pad)[:2])
     text_masked_pos = text_masked_pos * non_eos_text_mask
     masked_pos = paddle.cast(masked_pos, 'bool')
     text_masked_pos = paddle.cast(text_masked_pos, 'bool')
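Reviewer note on the nets_utils.py hunk: the substantive fix is the last one. The text mask must be reshaped against text_pad, whose token axis generally differs from xs_pad's frame axis, so the old reshape target was simply the wrong shape. The earlier change from text_masked_pos[idx][masked_text_idxs] to text_masked_pos[idx, masked_text_idxs] also matters, since chained indexing in Paddle can assign into a temporary rather than the original tensor. A small sketch of the reshape fix, with assumed shapes:

import paddle

# Assumed shapes for illustration: speech frames and text tokens differ in
# length, so the text mask must be reshaped against text_pad, not xs_pad.
xs_pad = paddle.zeros([2, 50, 80])   # (batch, n_frames, odim)
text_pad = paddle.zeros([2, 12])     # (batch, n_tokens)
text_mask = paddle.ones([2, 1, 12])  # non-<eos> mask over text positions

# paddle.shape(xs_pad)[:2] is [2, 50]: reshaping the 2 * 12 mask elements to
# it would raise a shape error. Using text_pad's leading dims works:
non_eos_text_mask = paddle.reshape(
    text_mask, shape=paddle.shape(text_pad)[:2])
print(non_eos_text_mask.shape)  # [2, 12]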