diff --git a/ernie-sat/inference.py b/ernie-sat/inference.py index 52e00f443bf0d061ffc23bc57c410932aec2da91..c7fde9f558c26dcd1e38f9ec6364e15a286ed8c4 100644 --- a/ernie-sat/inference.py +++ b/ernie-sat/inference.py @@ -546,19 +546,9 @@ def decode_with_model(mlm_model: nn.Layer, text_seg_pos=feats['text_seg_pos'], span_bdy=new_span_bdy, use_teacher_forcing=use_teacher_forcing) - if 0 in output[0].shape and 0 not in output[-1].shape: - output_feat = paddle.concat( - output[1:-1] + [output[-1].squeeze()], axis=0) - elif 0 not in output[0].shape and 0 in output[-1].shape: - output_feat = paddle.concat( - [output[0].squeeze()] + output[1:-1], axis=0) - elif 0 in output[0].shape and 0 in output[-1].shape: - output_feat = paddle.concat(output[1:-1], axis=0) - else: - output_feat = paddle.concat( - [output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)], - axis=0) + # 拼接音频 + output_feat = paddle.concat(x=output, axis=0) wav_org, _ = librosa.load(wav_path, sr=fs) return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length diff --git a/ernie-sat/mlm.py b/ernie-sat/mlm.py index 682e3481fc84ce86dcaef51be63b66407b02cd8b..2cf7921f6bfb5b190e47da8a25513efbc25f7dee 100644 --- a/ernie-sat/mlm.py +++ b/ernie-sat/mlm.py @@ -7,7 +7,6 @@ from typing import Optional from typing import Tuple from typing import Union -import numpy as np import paddle import yaml from paddle import nn @@ -395,13 +394,13 @@ class MLM(nn.Layer): use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]: ''' Args: - speech (paddle.Tensor): input speech (B, Tmax, D). - text (paddle.Tensor): input text (B, Tmax2). - masked_pos (paddle.Tensor): masked position of input speech (B, Tmax) - speech_mask (paddle.Tensor): mask of speech (B, 1, Tmax). - text_mask (paddle.Tensor): mask of text (B, 1, Tmax2). - speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax). - text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2). + speech (paddle.Tensor): input speech (1, Tmax, D). + text (paddle.Tensor): input text (1, Tmax2). + masked_pos (paddle.Tensor): masked position of input speech (1, Tmax) + speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax). + text_mask (paddle.Tensor): mask of text (1, 1, Tmax2). + speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax). + text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2). span_bdy (List[int]): masked mel boundary of input speech (2,) use_teacher_forcing (bool): whether to use teacher forcing Returns: @@ -410,7 +409,6 @@ class MLM(nn.Layer): [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])] ''' - outs = [speech[:, :span_bdy[0]]] z_cache = None if use_teacher_forcing: before_outs, zs, *_ = self.forward( @@ -423,8 +421,11 @@ class MLM(nn.Layer): text_seg_pos=text_seg_pos) if zs is None: zs = before_outs + + speech = speech.squeeze(0) + outs = [speech[:span_bdy[0]]] outs += [zs[0][span_bdy[0]:span_bdy[1]]] - outs += [speech[:, span_bdy[1]:]] + outs += [speech[span_bdy[1]:]] return outs return None