提交 e522009d 编写于 作者: 小湉湉's avatar 小湉湉

fix wav concat method

上级 9224659c
...@@ -546,19 +546,9 @@ def decode_with_model(mlm_model: nn.Layer, ...@@ -546,19 +546,9 @@ def decode_with_model(mlm_model: nn.Layer,
text_seg_pos=feats['text_seg_pos'], text_seg_pos=feats['text_seg_pos'],
span_bdy=new_span_bdy, span_bdy=new_span_bdy,
use_teacher_forcing=use_teacher_forcing) use_teacher_forcing=use_teacher_forcing)
if 0 in output[0].shape and 0 not in output[-1].shape:
output_feat = paddle.concat(
output[1:-1] + [output[-1].squeeze()], axis=0)
elif 0 not in output[0].shape and 0 in output[-1].shape:
output_feat = paddle.concat(
[output[0].squeeze()] + output[1:-1], axis=0)
elif 0 in output[0].shape and 0 in output[-1].shape:
output_feat = paddle.concat(output[1:-1], axis=0)
else:
output_feat = paddle.concat(
[output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)],
axis=0)
# 拼接音频
output_feat = paddle.concat(x=output, axis=0)
wav_org, _ = librosa.load(wav_path, sr=fs) wav_org, _ = librosa.load(wav_path, sr=fs)
return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
......
...@@ -7,7 +7,6 @@ from typing import Optional ...@@ -7,7 +7,6 @@ from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
import numpy as np
import paddle import paddle
import yaml import yaml
from paddle import nn from paddle import nn
...@@ -395,13 +394,13 @@ class MLM(nn.Layer): ...@@ -395,13 +394,13 @@ class MLM(nn.Layer):
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]: use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
''' '''
Args: Args:
speech (paddle.Tensor): input speech (B, Tmax, D). speech (paddle.Tensor): input speech (1, Tmax, D).
text (paddle.Tensor): input text (B, Tmax2). text (paddle.Tensor): input text (1, Tmax2).
masked_pos (paddle.Tensor): masked position of input speech (B, Tmax) masked_pos (paddle.Tensor): masked position of input speech (1, Tmax)
speech_mask (paddle.Tensor): mask of speech (B, 1, Tmax). speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax).
text_mask (paddle.Tensor): mask of text (B, 1, Tmax2). text_mask (paddle.Tensor): mask of text (1, 1, Tmax2).
speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax). speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax).
text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2). text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2).
span_bdy (List[int]): masked mel boundary of input speech (2,) span_bdy (List[int]): masked mel boundary of input speech (2,)
use_teacher_forcing (bool): whether to use teacher forcing use_teacher_forcing (bool): whether to use teacher forcing
Returns: Returns:
...@@ -410,7 +409,6 @@ class MLM(nn.Layer): ...@@ -410,7 +409,6 @@ class MLM(nn.Layer):
[Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])] [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
''' '''
outs = [speech[:, :span_bdy[0]]]
z_cache = None z_cache = None
if use_teacher_forcing: if use_teacher_forcing:
before_outs, zs, *_ = self.forward( before_outs, zs, *_ = self.forward(
...@@ -423,8 +421,11 @@ class MLM(nn.Layer): ...@@ -423,8 +421,11 @@ class MLM(nn.Layer):
text_seg_pos=text_seg_pos) text_seg_pos=text_seg_pos)
if zs is None: if zs is None:
zs = before_outs zs = before_outs
speech = speech.squeeze(0)
outs = [speech[:span_bdy[0]]]
outs += [zs[0][span_bdy[0]:span_bdy[1]]] outs += [zs[0][span_bdy[0]:span_bdy[1]]]
outs += [speech[:, span_bdy[1]:]] outs += [speech[span_bdy[1]:]]
return outs return outs
return None return None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册