Commit 9224659c authored by 小湉湉

add docstring

Parent 76b654cb
@@ -39,9 +39,9 @@ In ERNIE-SAT we propose two innovations:
### 2. Pretrained Models
The pretrained ERNIE-SAT models are listed below:
- [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz)
- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz)
Create a pretrained_model folder, download the ERNIE-SAT pretrained models above, and extract them into it:
@@ -108,7 +108,7 @@ prompt/dev
3. `--voc` is the vocoder; its format must follow {model_name}_{dataset}
4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are the vocoder's parameters, corresponding to the 3 files of the parallel wavegan pretrained model.
5. `--lang` is the language of the model, which can be `zh` or `en`
6. `--ngpu` is the number of GPUs to use; if ngpu == 0, the CPU is used.
7. `--model_name` is the model name
8. `--uid` is the id of the specific prompt utterance
9. `--new_str` is the input text (temporarily fixed to specific text for this release)
@@ -125,4 +125,3 @@ sh run_sedit_en.sh # speech editing task (English)
sh run_gen_en.sh # personalized speech synthesis task (English)
sh run_clone_en_to_zh.sh # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
```
#!/usr/bin/env python
""" Usage:
align.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
import sys
from tqdm import tqdm
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
@@ -15,6 +11,142 @@ HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'
def get_unk_phns(word_str: str):
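    '''Predict phones for an out-of-dictionary word with the english2phoneme
    tool and convert the output to CMUdict-style (ARPABET) symbols.'''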
tmpbase = '/tmp/tp.'
f = open(tmpbase + 'temp.words', 'w')
f.write(word_str)
f.close()
os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
'temp.phons')
f = open(tmpbase + 'temp.phons', 'r')
lines2 = f.readline().strip().split()
f.close()
phns = []
for phn in lines2:
phons = phn.replace('\n', '').replace(' ', '')
seq = []
j = 0
while (j < len(phons)):
if (phons[j] > 'Z'):
if (phons[j] == 'j'):
seq.append('JH')
elif (phons[j] == 'h'):
seq.append('HH')
else:
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
seq.append(p)
elif (p == 'AX'):
seq.append('AH0')
else:
seq.append(p + '1')
j += 2
phns.extend(seq)
return phns
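# Example (sketch): get_unk_phns('covers') might return
# ['K', 'AH1', 'V', 'ER0', 'Z'], depending on the english2phoneme output.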
def words2phns(line: str):
'''
Args:
line (str): input text.
eg: for that reason cover is impossible to be given.
Returns:
List[str]: phones of input text.
eg:
['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
'G', 'IH1', 'V', 'AH0', 'N']
Dict(str, str): key - idx_word
value - phones
eg:
{'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],
'3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
'6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
'''
dictfile = MODEL_DIR_EN + '/dict'
line = line.strip()
words = []
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
phns.extend(get_unk_phns(wrd))
else:
wrd2phns[str(index) +
"_" + wrd.upper()] = word2phns_dict[wrd.upper()].split()
phns.extend(word2phns_dict[wrd.upper()].split())
return phns, wrd2phns
def words2phns_zh(line: str):
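    '''Chinese counterpart of words2phns(): strip punctuation and look each
    word up in the Mandarin aligner dict (no OOV fallback here).'''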
dictfile = MODEL_DIR_ZH + '/dict'
line = line.strip()
words = []
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
print("出现非法词错误,请输入正确的文本...")
else:
wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
phns.extend(word2phns_dict[wrd].split())
return phns, wrd2phns
def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
words = []
@@ -82,7 +214,7 @@ def prep_txt_en(line: str, tmpbase, dictfile):
try:
os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
'_unk.phons')
except:
except Exception:
print('english2phoneme error!')
sys.exit(1)
@@ -148,19 +280,22 @@ def _get_user():
def alignment(wav_path: str, text: str):
'''
intervals: List[phn, start, end]
'''
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
except:
except Exception:
print('sox error!')
return None
#prepare clean_transcript file
try:
prep_txt_en(text, tmpbase, MODEL_DIR_EN + '/dict')
except:
prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict')
except Exception:
print('prep_txt error!')
return None
@@ -169,7 +304,7 @@ def alignment(wav_path: str, text: str):
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except:
except Exception:
print('prep_mlf error!')
return None
@@ -177,7 +312,7 @@ def alignment(wav_path: str, text: str):
try:
os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
except Exception:
print('HCopy error!')
return None
@@ -188,7 +323,7 @@ def alignment(wav_path: str, text: str):
+ '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
'.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
except Exception:
print('HVite error!')
return None
@@ -200,7 +335,7 @@ def alignment(wav_path: str, text: str):
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
times2 = []
intervals = []
word2phns = {}
current_word = ''
index = 0
@@ -210,7 +345,7 @@ def alignment(wav_path: str, text: str):
phn = splited_line[2]
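        # HTK label times are in 100 ns units; this converts them to seconds,
        # plus a fixed offset of 125 / 10000 = 12.5 ms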
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
intervals.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
@@ -219,10 +354,10 @@ def alignment(wav_path: str, text: str):
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return times2, word2phns
return intervals, word2phns
def alignment_zh(wav_path, text_string):
def alignment_zh(wav_path: str, text: str):
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
@@ -230,18 +365,19 @@ def alignment_zh(wav_path, text_string):
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
'.wav remix -')
except:
except Exception:
print('sox error!')
return None
#prepare clean_transcript file
try:
unk_words = prep_txt_zh(text_string, tmpbase, MODEL_DIR_ZH + '/dict')
unk_words = prep_txt_zh(
line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict')
if unk_words:
print('Error! Please add the following words to dictionary:')
for unk in unk_words:
print("非法words: ", unk)
except:
except Exception:
print('prep_txt error!')
return None
@@ -250,7 +386,7 @@ def alignment_zh(wav_path, text_string):
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except:
except Exception:
print('prep_mlf error!')
return None
@@ -258,7 +394,7 @@ def alignment_zh(wav_path, text_string):
try:
os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
except Exception:
print('HCopy error!')
return None
@@ -270,7 +406,7 @@ def alignment_zh(wav_path, text_string):
+ '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
except Exception:
print('HVite error!')
return None
@@ -283,7 +419,7 @@ def alignment_zh(wav_path, text_string):
lines = fid.readlines()
i = 2
times2 = []
intervals = []
word2phns = {}
current_word = ''
index = 0
@@ -293,7 +429,7 @@ def alignment_zh(wav_path, text_string):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
intervals.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
@@ -302,4 +438,4 @@ def alignment_zh(wav_path, text_string):
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return times2, word2phns
return intervals, word2phns
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import paddle
from dataset import get_seg_pos
from dataset import phones_masking
from dataset import phones_text_masking
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list
class MLMCollateFn:
"""Functor class of common_collate_fn()"""
def __init__(self,
feats_extract,
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
attention_window: int=0,
pad_speech: bool=False,
seg_emb: bool=False,
text_masking: bool=False):
self.mlm_prob = mlm_prob
self.mean_phn_span = mean_phn_span
self.feats_extract = feats_extract
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
self.attention_window = attention_window
self.pad_speech = pad_speech
self.seg_emb = seg_emb
self.text_masking = text_masking
    def __repr__(self):
        return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
                f"int_pad_value={self.int_pad_value})")
def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
return mlm_collate_fn(
data,
float_pad_value=self.float_pad_value,
int_pad_value=self.int_pad_value,
not_sequence=self.not_sequence,
mlm_prob=self.mlm_prob,
mean_phn_span=self.mean_phn_span,
feats_extract=self.feats_extract,
attention_window=self.attention_window,
pad_speech=self.pad_speech,
seg_emb=self.seg_emb,
text_masking=self.text_masking)
def mlm_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
feats_extract=None,
attention_window: int=0,
pad_speech: bool=False,
seg_emb: bool=False,
text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(not k.endswith("_lens")
for k in data[0]), f"*_lens is reserved: {list(data[0])}"
output = {}
for key in data[0]:
        # Each model that finally consumes these values is responsible
        # for re-padding pad_value to whatever its own task requires.
if data[0][key].dtype.kind == "i":
pad_value = int_pad_value
else:
pad_value = float_pad_value
array_list = [d[key] for d in data]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list = [paddle.to_tensor(a) for a in array_list]
# tensor: (Batch, Length, ...)
tensor = pad_list(tensor_list, pad_value)
output[key] = tensor
# lens: (Batch,)
if key not in not_sequence:
lens = paddle.to_tensor(
[d[key].shape[0] for d in data], dtype=paddle.int64)
output[key + "_lens"] = lens
feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
feats = paddle.to_tensor(feats)
feats_lens = paddle.shape(feats)[0]
feats = paddle.unsqueeze(feats, 0)
text = output["text"]
text_lens = output["text_lens"]
align_start = output["align_start"]
align_start_lens = output["align_start_lens"]
align_end = output["align_end"]
max_tlen = max(text_lens)
max_slen = max(feats_lens)
speech_pad = feats[:, :max_slen]
text_pad = text
text_mask = make_non_pad_mask(
text_lens, text_pad, length_dim=1).unsqueeze(-2)
speech_mask = make_non_pad_mask(
feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
span_bdy = None
if 'span_bdy' in output.keys():
span_bdy = output['span_bdy']
    # dual mask: in mixed Chinese-English training, mask speech and text at the same time
    # ERNIE-SAT masks both when doing cross-lingual synthesis
if text_masking:
masked_pos, text_masked_pos = phones_text_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
text_pad=text_pad,
text_mask=text_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
    # for pure-Chinese / pure-English training -> A3T does not mask phonemes, only speech
    # the main difference between A3T and ERNIE-SAT is in how masking is done
else:
masked_pos = phones_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
text_masked_pos = paddle.zeros(paddle.shape(text_pad))
output_dict = {}
speech_seg_pos, text_seg_pos = get_seg_pos(
speech_pad=speech_pad,
text_pad=text_pad,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
seg_emb=seg_emb)
output_dict['speech'] = speech_pad
output_dict['text'] = text_pad
output_dict['masked_pos'] = masked_pos
output_dict['text_masked_pos'] = text_masked_pos
output_dict['speech_mask'] = speech_mask
output_dict['text_mask'] = text_mask
output_dict['speech_seg_pos'] = speech_seg_pos
output_dict['text_seg_pos'] = text_seg_pos
output = (uttids, output_dict)
return output
def build_collate_fn(
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
n_mels: int=80,
fmin: int=80,
fmax: int=7600,
mlm_prob: float=0.8,
mean_phn_span: int=8,
train: bool=False,
seg_emb: bool=False,
epoch: int=-1, ):
feats_extract_class = LogMelFBank
feats_extract = feats_extract_class(
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mels=n_mels,
fmin=fmin,
fmax=fmax)
pad_speech = False
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
return MLMCollateFn(
feats_extract=feats_extract,
float_pad_value=0.0,
int_pad_value=0,
mlm_prob=mlm_prob * mlm_prob_factor,
mean_phn_span=mean_phn_span,
pad_speech=pad_speech,
seg_emb=seg_emb)
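# A minimal usage sketch (values assumed, mirroring the inference configs in
# this repo; win_length=None is expected to fall back to a library default):
#
#     collate_fn = build_collate_fn(
#         sr=24000, n_fft=2048, hop_length=300, win_length=None, n_mels=80,
#         fmin=80, fmax=7600, mlm_prob=0.8, mean_phn_span=8, train=False,
#         seg_emb=True)
#     uttids, batch = collate_fn([(uid, feats_dict)])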
@@ -4,6 +4,68 @@ import numpy as np
import paddle
# mask phones
def phones_masking(xs_pad: paddle.Tensor,
src_mask: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
mlm_prob: float=0.8,
mean_phn_span: int=8,
span_bdy: paddle.Tensor=None):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
        mlm_prob (float): approximate fraction of positions to mask.
        mean_phn_span (int): mean length (in phones) of each masked span.
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
'''
bz, sent_len, _ = paddle.shape(xs_pad)
masked_pos = paddle.zeros((bz, sent_len))
if mlm_prob == 1.0:
masked_pos += 1
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_idxs = random_spans_noise_mask(
length=length, mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
masked_pos[:, masked_phn_idxs] = 1
else:
for idx in range(bz):
# for inference
if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1
# for training
else:
length = align_start_lens[idx]
if length < 2:
continue
masked_phn_idxs = random_spans_noise_mask(
length=length,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
masked_start = align_start[idx][masked_phn_idxs].tolist()
masked_end = align_end[idx][masked_phn_idxs].tolist()
for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
masked_pos = paddle.cast(masked_pos, 'bool')
return masked_pos
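# Example (sketch): at inference time span_bdy carries explicit boundaries,
# e.g. span_bdy[idx] = [92, 174] masks frames 92..173 of utterance idx.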
# mask speech and phones
def phones_text_masking(xs_pad: paddle.Tensor,
src_mask: paddle.Tensor,
text_pad: paddle.Tensor,
@@ -11,37 +73,56 @@ def phones_text_masking(xs_pad: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
mlm_prob: float,
mean_phn_span: float,
mlm_prob: float=0.8,
mean_phn_span: int=8,
span_bdy: paddle.Tensor=None):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
text_pad (paddle.Tensor): input text (B, Tmax2).
text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
        mlm_prob (float): approximate fraction of positions to mask.
        mean_phn_span (int): mean length (in phones) of each masked span.
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
paddle.Tensor[bool]: masked position of input text (B, Tmax2).
'''
bz, sent_len, _ = paddle.shape(xs_pad)
masked_pos = paddle.zeros((bz, sent_len))
_, text_len = paddle.shape(text_pad)
text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5)
text_masked_pos = paddle.zeros((bz, text_len))
y_masks = None
if mlm_prob == 1.0:
masked_pos += 1
# y_masks = tril_masks
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
mean_phn_span).nonzero()
masked_phn_idxs = random_spans_noise_mask(
length=length, mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
masked_pos[:, masked_phn_idxs] = 1
else:
for idx in range(bz):
# for inference
if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1
# for training
else:
length = align_start_lens[idx]
if length < 2:
continue
masked_phn_idxs = random_spans_noise_mask(
length, mlm_prob, mean_phn_span).nonzero()
length=length,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
unmasked_phn_idxs = list(
set(range(length)) - set(masked_phn_idxs[0].tolist()))
np.random.shuffle(unmasked_phn_idxs)
@@ -58,60 +139,76 @@ def phones_text_masking(xs_pad: paddle.Tensor,
masked_pos = paddle.cast(masked_pos, 'bool')
text_masked_pos = paddle.cast(text_masked_pos, 'bool')
return masked_pos, text_masked_pos, y_masks
return masked_pos, text_masked_pos
def get_seg_pos_reduce_duration(
speech_pad: paddle.Tensor,
text_pad: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
sega_emb: bool,
masked_pos: paddle.Tensor,
feats_lens: paddle.Tensor, ):
def get_seg_pos(speech_pad: paddle.Tensor,
text_pad: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
seg_emb: bool=False):
'''
Args:
speech_pad (paddle.Tensor): input speech (B, Tmax, D).
text_pad (paddle.Tensor): input text (B, Tmax2).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
seg_emb (bool): whether to use segment embedding.
Returns:
paddle.Tensor[int]: n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
eg:
Tensor(shape=[1, 328], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 4 , 4 ,
5 , 5 , 5 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 7 , 7 , 7 , 7 , 7 , 7 , 7 ,
7 , 8 , 8 , 8 , 8 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 10, 10, 10, 10, 10,
10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13,
13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15,
15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17,
17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20,
20, 20, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 23, 23,
23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25,
25, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29,
29, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 32,
32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 35, 35,
35, 35, 35, 35, 35, 35, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38,
38, 38, 0 , 0 ]])
paddle.Tensor[int]: n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
eg:
Tensor(shape=[1, 38], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38]])
'''
bz, speech_len, _ = paddle.shape(speech_pad)
text_seg_pos = paddle.zeros(paddle.shape(text_pad))
speech_seg_pos = paddle.zeros((bz, speech_len), dtype=text_pad.dtype)
_, text_len = paddle.shape(text_pad)
reordered_idx = paddle.zeros((bz, speech_len), dtype=align_start_lens.dtype)
text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')
durations = paddle.ones((bz, speech_len), dtype=align_start_lens.dtype)
max_reduced_length = 0
if not sega_emb:
return speech_pad, masked_pos, speech_seg_pos, text_seg_pos, durations
if not seg_emb:
return speech_seg_pos, text_seg_pos
for idx in range(bz):
first_idx = []
last_idx = []
align_length = align_start_lens[idx]
for j in range(align_length):
s, e = align_start[idx][j], align_end[idx][j]
if j == 0:
if paddle.sum(masked_pos[idx][0:s]) == 0:
first_idx.extend(range(0, s))
else:
first_idx.extend([0])
last_idx.extend(range(1, s))
if paddle.sum(masked_pos[idx][s:e]) == 0:
first_idx.extend(range(s, e))
else:
first_idx.extend([s])
last_idx.extend(range(s + 1, e))
durations[idx][s] = e - s
speech_seg_pos[idx][s:e] = j + 1
text_seg_pos[idx][j] = j + 1
max_reduced_length = max(
len(first_idx) + feats_lens[idx] - e, max_reduced_length)
first_idx.extend(range(e, speech_len))
reordered_idx[idx] = paddle.to_tensor(
(first_idx + last_idx), dtype=align_start_lens.dtype)
feats_lens[idx] = len(first_idx)
reordered_idx = reordered_idx[:, :max_reduced_length]
speech_seg_pos[idx, s:e] = j + 1
text_seg_pos[idx, j] = j + 1
return reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens
return speech_seg_pos, text_seg_pos
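# Example (sketch, seg_emb=True): with 2 phones aligned to frames [0, 3) and
# [3, 5) in a 6-frame utterance, speech_seg_pos is [1, 1, 1, 2, 2, 0] and
# text_seg_pos is [1, 2]; 0 marks positions outside any alignment.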
def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
# randomly select the range of speech and text to mask during training
def random_spans_noise_mask(length: int,
mlm_prob: float=0.8,
mean_phn_span: float=8):
"""This function is copy of `random_spans_helper
<https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens.
@@ -126,7 +223,7 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
        noise_density (mlm_prob here): a float - approximate density of output mask
        mean_noise_span_length (mean_phn_span here): a number
Returns:
a boolean tensor with shape [length]
np.ndarray: a boolean tensor with shape [length]
"""
orig_length = length
@@ -171,87 +268,3 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
is_noise = np.equal(span_num % 2, 1)
return is_noise[:orig_length]
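# Example (sketch, runnable with numpy alone):
# >>> np.random.seed(0)
# >>> mask = random_spans_noise_mask(length=20, mlm_prob=0.8, mean_phn_span=8)
# >>> mask.shape, mask.dtype
# ((20,), dtype('bool'))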
def pad_to_longformer_att_window(text: paddle.Tensor,
max_len: int,
max_tlen: int,
attention_window: int=0):
round = max_len % attention_window
if round != 0:
max_tlen += (attention_window - round)
n_batch = paddle.shape(text)[0]
text_pad = paddle.zeros(
(n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype)
for i in range(n_batch):
text_pad[i, :paddle.shape(text[i])[0]] = text[i]
else:
text_pad = text[:, :max_tlen]
return text_pad, max_tlen
def phones_masking(xs_pad: paddle.Tensor,
src_mask: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
mlm_prob: float,
mean_phn_span: int,
span_bdy: paddle.Tensor=None):
bz, sent_len, _ = paddle.shape(xs_pad)
masked_pos = paddle.zeros((bz, sent_len))
y_masks = None
if mlm_prob == 1.0:
masked_pos += 1
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
mean_phn_span).nonzero()
masked_pos[:, masked_phn_idxs] = 1
else:
for idx in range(bz):
if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1
else:
length = align_start_lens[idx]
if length < 2:
continue
masked_phn_idxs = random_spans_noise_mask(
length, mlm_prob, mean_phn_span).nonzero()
masked_start = align_start[idx][masked_phn_idxs].tolist()
masked_end = align_end[idx][masked_phn_idxs].tolist()
for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
masked_pos = paddle.cast(masked_pos, 'bool')
return masked_pos, y_masks
def get_seg_pos(speech_pad: paddle.Tensor,
text_pad: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
sega_emb: bool):
bz, speech_len, _ = paddle.shape(speech_pad)
_, text_len = paddle.shape(text_pad)
text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')
if not sega_emb:
return speech_seg_pos, text_seg_pos
for idx in range(bz):
align_length = align_start_lens[idx]
for j in range(align_length):
s, e = align_start[idx][j], align_end[idx][j]
speech_seg_pos[idx, s:e] = j + 1
text_seg_pos[idx, j] = j + 1
return speech_seg_pos, text_seg_pos
#!/usr/bin/env python3
import argparse
import os
import random
from pathlib import Path
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import librosa
import numpy as np
@@ -15,60 +11,42 @@ import paddle
import soundfile as sf
import torch
from paddle import nn
from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
from align import alignment
from align import alignment_zh
from dataset import get_seg_pos
from dataset import get_seg_pos_reduce_duration
from dataset import pad_to_longformer_att_window
from dataset import phones_masking
from dataset import phones_text_masking
from model_paddle import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2column_text
from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import pad_list
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from collect_fn import build_collate_fn
from mlm import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2col_text
# from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
random.seed(0)
np.random.seed(0)
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
def plot_mel_and_vocode_wav(uid: str,
wav_path: str,
prefix: str="./prompt/dev/",
def plot_mel_and_vocode_wav(wav_path: str,
source_lang: str='english',
target_lang: str='english',
model_name: str="conformer",
full_origin_str: str="",
model_name: str="paddle_checkpoint_en",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
use_pt_vocoder: bool=False,
sid: str=None,
non_autoreg: bool=True):
wav_org, input_feat, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
uid=uid,
prefix=prefix,
wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
source_lang=source_lang,
target_lang=target_lang,
model_name=model_name,
wav_path=wav_path,
old_str=old_str,
new_str=new_str,
duration_preditor_path=duration_preditor_path,
use_teacher_forcing=non_autoreg,
sid=sid)
use_teacher_forcing=non_autoreg)
masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
@@ -79,10 +57,10 @@ def plot_mel_and_vocode_wav(uid: str,
vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
replaced_wav = vocoder(output_feat).cpu().numpy()
else:
replaced_wav = get_voc_out(output_feat, target_lang)
replaced_wav = get_voc_out(output_feat)
elif target_lang == 'chinese':
replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, target_lang)
replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat)
old_time_bdy = [hop_length * x for x in old_span_bdy]
new_time_bdy = [hop_length * x for x in new_span_bdy]
@@ -109,125 +87,6 @@ def plot_mel_and_vocode_wav(uid: str,
return data_dict, old_span_bdy
def get_unk_phns(word_str: str):
tmpbase = '/tmp/tp.'
f = open(tmpbase + 'temp.words', 'w')
f.write(word_str)
f.close()
os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
'temp.phons')
f = open(tmpbase + 'temp.phons', 'r')
lines2 = f.readline().strip().split()
f.close()
phns = []
for phn in lines2:
phons = phn.replace('\n', '').replace(' ', '')
seq = []
j = 0
while (j < len(phons)):
if (phons[j] > 'Z'):
if (phons[j] == 'j'):
seq.append('JH')
elif (phons[j] == 'h'):
seq.append('HH')
else:
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
seq.append(p)
elif (p == 'AX'):
seq.append('AH0')
else:
seq.append(p + '1')
j += 2
phns.extend(seq)
return phns
def words2phns(line: str):
dictfile = MODEL_DIR_EN + '/dict'
line = line.strip()
words = []
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
phns.extend(get_unk_phns(wrd))
else:
wrd2phns[str(index) +
"_" + wrd.upper()] = word2phns_dict[wrd.upper()].split()
phns.extend(word2phns_dict[wrd.upper()].split())
return phns, wrd2phns
def words2phns_zh(line: str):
dictfile = MODEL_DIR_ZH + '/dict'
line = line.strip()
words = []
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
print("出现非法词错误,请输入正确的文本...")
else:
wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
phns.extend(word2phns_dict[wrd].split())
return phns, wrd2phns
def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
vocoder_file = download_pretrained_model(vocoder_tag)
@@ -236,50 +95,52 @@ def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
return vocoder
def load_model(model_name: str):
def load_model(model_name: str="paddle_checkpoint_en"):
config_path = './pretrained_model/{}/config.yaml'.format(model_name)
model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
mlm_model, args = build_model_from_file(
mlm_model, conf = build_model_from_file(
config_file=config_path, model_file=model_path)
return mlm_model, args
return mlm_model, conf
def read_data(uid: str, prefix: str):
mfa_text = read_2column_text(prefix + '/text')[uid]
mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid]
if 'mnt' not in mfa_wav_path:
mfa_wav_path = prefix.split('dump')[0] + mfa_wav_path
def read_data(uid: str, prefix: os.PathLike):
    # get the text for this uid
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # get the audio path for this uid
mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
if not os.path.isabs(mfa_wav_path):
mfa_wav_path = prefix + mfa_wav_path
return mfa_text, mfa_wav_path
def get_align_data(uid: str, prefix: str):
def get_align_data(uid: str, prefix: os.PathLike):
mfa_path = prefix + "mfa_"
mfa_text = read_2column_text(mfa_path + 'text')[uid]
mfa_text = read_2col_text(mfa_path + 'text')[uid]
mfa_start = load_num_sequence_text(
mfa_path + 'start', loader_type='text_float')[uid]
mfa_end = load_num_sequence_text(
mfa_path + 'end', loader_type='text_float')[uid]
mfa_wav_path = read_2column_text(mfa_path + 'wav.scp')[uid]
mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
return mfa_text, mfa_start, mfa_end, mfa_wav_path
# get the range of mel frames to be masked
def get_masked_mel_bdy(mfa_start: List[float],
mfa_end: List[float],
fs: int,
hop_length: int,
span_to_repl: List[List[int]]):
align_start = paddle.to_tensor(mfa_start).unsqueeze(0)
align_end = paddle.to_tensor(mfa_end).unsqueeze(0)
align_start = paddle.floor(fs * align_start / hop_length).int()
align_end = paddle.floor(fs * align_end / hop_length).int()
align_start = np.array(mfa_start)
align_end = np.array(mfa_end)
align_start = np.floor(fs * align_start / hop_length).astype('int')
align_end = np.floor(fs * align_end / hop_length).astype('int')
if span_to_repl[0] >= len(mfa_start):
span_bdy = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]]
span_bdy = [align_end[-1], align_end[-1]]
else:
span_bdy = [
align_start[0].tolist()[span_to_repl[0]],
align_end[0].tolist()[span_to_repl[1] - 1]
align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
]
return span_bdy
return span_bdy, align_start, align_end
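# Example (sketch): with fs=24000 and hop_length=300, a phone starting at
# 1.15 s maps to mel frame floor(24000 * 1.15 / 300) = 92.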
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
@@ -317,18 +178,22 @@ def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
return dic
def get_max_idx(dic):
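    # keys look like "8_GIVEN"; return the largest leading index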
return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]
def get_phns_and_spans(wav_path: str,
old_str: str="",
new_str: str="",
source_lang: str="english",
target_lang: str="english"):
append_new_str = (old_str == new_str[:len(old_str)])
is_append = (old_str == new_str[:len(old_str)])
old_phns, mfa_start, mfa_end = [], [], []
# source
if source_lang == "english":
times2, word2phns = alignment(wav_path, old_str)
intervals, word2phns = alignment(wav_path, old_str)
elif source_lang == "chinese":
times2, word2phns = alignment_zh(wav_path, old_str)
intervals, word2phns = alignment_zh(wav_path, old_str)
_, tp_word2phns = words2phns_zh(old_str)
for key, value in tp_word2phns.items():
@@ -337,51 +202,46 @@ def get_phns_and_spans(wav_path: str,
tp_word2phns[key] = cur_val
word2phns = recover_dict(word2phns, tp_word2phns)
else:
assert source_lang == "chinese" or source_lang == "english", "source_lang is wrong..."
assert source_lang == "chinese" or source_lang == "english", \
"source_lang is wrong..."
for item in times2:
for item in intervals:
old_phns.append(item[0])
mfa_start.append(float(item[1]))
mfa_end.append(float(item[2]))
old_phns.append(item[0])
if append_new_str and (source_lang != target_lang):
is_cross_lingual_clone = True
# target
if is_append and (source_lang != target_lang):
cross_lingual_clone = True
else:
is_cross_lingual_clone = False
cross_lingual_clone = False
if is_cross_lingual_clone:
new_str_origin = new_str[:len(old_str)]
new_str_append = new_str[len(old_str):]
if cross_lingual_clone:
str_origin = new_str[:len(old_str)]
str_append = new_str[len(old_str):]
if target_lang == "chinese":
new_phns_origin, new_origin_word2phns = words2phns(new_str_origin)
new_phns_append, temp_new_append_word2phns = words2phns_zh(
new_str_append)
phns_origin, origin_word2phns = words2phns(str_origin)
phns_append, append_word2phns_tmp = words2phns_zh(str_append)
elif target_lang == "english":
            # original sentence
new_phns_origin, new_origin_word2phns = words2phns_zh(
new_str_origin)
# clone句子
new_phns_append, temp_new_append_word2phns = words2phns(
new_str_append)
phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # the sentence to clone
phns_append, append_word2phns_tmp = words2phns(str_append)
else:
assert target_lang == "chinese" or target_lang == "english", \
"cloning is not support for this language, please check it."
new_phns = new_phns_origin + new_phns_append
new_phns = phns_origin + phns_append
new_append_word2phns = {}
length = len(new_origin_word2phns)
for key, value in temp_new_append_word2phns.items():
append_word2phns = {}
length = len(origin_word2phns)
for key, value in append_word2phns_tmp.items():
idx, wrd = key.split('_')
new_append_word2phns[str(int(idx) + length) + '_' + wrd] = value
new_word2phns = dict(
list(new_origin_word2phns.items()) + list(
new_append_word2phns.items()))
append_word2phns[str(int(idx) + length) + '_' + wrd] = value
new_word2phns = origin_word2phns.copy()
new_word2phns.update(append_word2phns)
else:
if source_lang == target_lang and target_lang == "english":
@@ -417,16 +277,17 @@ def get_phns_and_spans(wav_path: str,
right_idx = 0
new_phns_right = []
sp_count = 0
word2phns_max_idx = int(list(word2phns.keys())[-1].split('_')[0])
new_word2phns_max_idx = int(list(new_word2phns.keys())[-1].split('_')[0])
word2phns_max_idx = get_max_idx(word2phns)
new_word2phns_max_idx = get_max_idx(new_word2phns)
new_phns_mid = []
if append_new_str:
if is_append:
new_phns_right = []
new_phns_mid = new_phns[left_idx:]
span_to_repl[0] = len(new_phns_left)
span_to_add[0] = len(new_phns_left)
span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
span_to_repl[1] = len(old_phns) - len(new_phns_right)
# speech edit
else:
for key in list(word2phns.keys())[::-1]:
idx, wrd = key.split('_')
@@ -451,47 +312,57 @@ def get_phns_and_spans(wav_path: str,
len(old_phns))
break
new_phns = new_phns_left + new_phns_mid + new_phns_right
'''
For that reason cover should not be given.
For that reason cover is impossible to be given.
span_to_repl: [17, 23] "should not"
span_to_add: [17, 30] "is impossible to"
'''
return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
def duration_adjust_factor(original_dur: List[int],
# the durations from MFA and those from fs2's duration_predictor may differ
# compute a scaling factor between the predicted and ground-truth durations here
def duration_adjust_factor(orig_dur: List[int],
pred_dur: List[int],
phns: List[str]):
length = 0
factor_list = []
for ori, pred, phn in zip(original_dur, pred_dur, phns):
for orig, pred, phn in zip(orig_dur, pred_dur, phns):
if pred == 0 or phn == 'sp':
continue
else:
factor_list.append(ori / pred)
factor_list.append(orig / pred)
factor_list = np.array(factor_list)
factor_list.sort()
if len(factor_list) < 5:
return 1
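    # trimmed mean: drop the 2 smallest and the 2 largest factors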
length = 2
return np.average(factor_list[length:-length])
def prepare_features_with_duration(uid: str,
prefix: str,
wav_path: str,
mlm_model: nn.Layer,
source_lang: str="English",
target_lang: str="English",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
sid: str=None,
mask_reconstruct: bool=False,
duration_adjust: bool=True,
start_end_sp: bool=False,
train_args=None):
wav_org, rate = librosa.load(
wav_path, sr=train_args.feats_extract_conf['fs'])
fs = train_args.feats_extract_conf['fs']
hop_length = train_args.feats_extract_conf['hop_length']
avg = np.average(factor_list[length:-length])
return avg
def prep_feats_with_dur(wav_path: str,
mlm_model: nn.Layer,
source_lang: str="English",
target_lang: str="English",
old_str: str="",
new_str: str="",
mask_reconstruct: bool=False,
duration_adjust: bool=True,
start_end_sp: bool=False,
fs: int=24000,
hop_length: int=300):
'''
Returns:
np.ndarray: new wav, replace the part to be edited in original wav with 0
List[str]: new phones
List[float]: mfa start of new wav
List[float]: mfa end of new wav
List[int]: masked mel boundary of original wav
List[int]: masked mel boundary of new wav
'''
wav_org, _ = librosa.load(wav_path, sr=fs)
mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
wav_path=wav_path,
@@ -503,144 +374,130 @@ def prep_feats_with_dur(wav_path: str,
if start_end_sp:
if new_phns[-1] != 'sp':
new_phns = new_phns + ['sp']
if target_lang == "english":
old_durations = evaluate_durations(old_phns, target_lang=target_lang)
elif target_lang == "chinese":
if source_lang == "english":
old_durations = evaluate_durations(
old_phns, target_lang=source_lang)
elif source_lang == "chinese":
old_durations = evaluate_durations(
old_phns, target_lang=source_lang)
    # Chinese phones are not necessarily all in the fastspeech2 dict; use sp instead
if target_lang == "english" or target_lang == "chinese":
old_durs = evaluate_durations(old_phns, target_lang=source_lang)
else:
        assert target_lang == "chinese" or target_lang == "english", "duration prediction is not supported for this language..."
        assert target_lang == "chinese" or target_lang == "english", \
            "duration prediction is not supported for this language..."
original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)]
orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
if '[MASK]' in new_str:
new_phns = old_phns
span_to_add = span_to_repl
d_factor_left = duration_adjust_factor(
original_old_durations[:span_to_repl[0]],
old_durations[:span_to_repl[0]], old_phns[:span_to_repl[0]])
orig_dur=orig_old_durs[:span_to_repl[0]],
pred_dur=old_durs[:span_to_repl[0]],
phns=old_phns[:span_to_repl[0]])
d_factor_right = duration_adjust_factor(
original_old_durations[span_to_repl[1]:],
old_durations[span_to_repl[1]:], old_phns[span_to_repl[1]:])
orig_dur=orig_old_durs[span_to_repl[1]:],
pred_dur=old_durs[span_to_repl[1]:],
phns=old_phns[span_to_repl[1]:])
d_factor = (d_factor_left + d_factor_right) / 2
new_durations_adjusted = [d_factor * i for i in old_durations]
new_durs_adjusted = [d_factor * i for i in old_durs]
else:
if duration_adjust:
d_factor = duration_adjust_factor(original_old_durations,
old_durations, old_phns)
d_factor = duration_adjust_factor(
orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
print("d_factor:", d_factor)
d_factor = d_factor * 1.25
else:
d_factor = 1
if target_lang == "english":
new_durations = evaluate_durations(
new_phns, target_lang=target_lang)
elif target_lang == "chinese":
new_durations = evaluate_durations(
new_phns, target_lang=target_lang)
new_durations_adjusted = [d_factor * i for i in new_durations]
if span_to_repl[0] < len(old_phns) and old_phns[span_to_repl[
0]] == new_phns[span_to_add[0]]:
new_durations_adjusted[span_to_add[0]] = original_old_durations[
span_to_repl[0]]
if span_to_repl[1] < len(old_phns) and span_to_add[1] < len(new_phns):
if old_phns[span_to_repl[1]] == new_phns[span_to_add[1]]:
new_durations_adjusted[span_to_add[1]] = original_old_durations[
span_to_repl[1]]
new_span_duration_sum = sum(
new_durations_adjusted[span_to_add[0]:span_to_add[1]])
old_span_duration_sum = sum(
original_old_durations[span_to_repl[0]:span_to_repl[1]])
duration_offset = new_span_duration_sum - old_span_duration_sum
if target_lang == "english" or target_lang == "chinese":
new_durs = evaluate_durations(new_phns, target_lang=target_lang)
else:
        assert target_lang == "chinese" or target_lang == "english", \
            "duration prediction is not supported for this language..."
new_durs_adjusted = [d_factor * i for i in new_durs]
new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
dur_offset = new_span_dur_sum - old_span_dur_sum
new_mfa_start = mfa_start[:span_to_repl[0]]
new_mfa_end = mfa_end[:span_to_repl[0]]
for i in new_durations_adjusted[span_to_add[0]:span_to_add[1]]:
for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
if len(new_mfa_end) == 0:
new_mfa_start.append(0)
new_mfa_end.append(i)
else:
new_mfa_start.append(new_mfa_end[-1])
new_mfa_end.append(new_mfa_end[-1] + i)
new_mfa_start += [i + duration_offset for i in mfa_start[span_to_repl[1]:]]
new_mfa_end += [i + duration_offset for i in mfa_end[span_to_repl[1]:]]
new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]
    # 3. get new wav
    # append after the original sentence
if span_to_repl[0] >= len(mfa_start):
left_idx = len(wav_org)
right_idx = left_idx
    # replace inside the original sentence
else:
left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
new_blank_wav = np.zeros(
(int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype)
new_wav_org = np.concatenate(
[wav_org[:left_idx], new_blank_wav, wav_org[right_idx:]])
blank_wav = np.zeros(
(int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # in the original audio, replace the part to be edited with silence; the
    # length of the silence is decided by fs2's duration_predictor
new_wav = np.concatenate(
[wav_org[:left_idx], blank_wav, wav_org[right_idx:]])
# 4. get old and new mel span to be mask
# [92, 92]
old_span_bdy = get_masked_mel_bdy(mfa_start, mfa_end, fs, hop_length,
span_to_repl)
old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
mfa_start=mfa_start,
mfa_end=mfa_end,
fs=fs,
hop_length=hop_length,
span_to_repl=span_to_repl)
# [92, 174]
new_span_bdy = get_masked_mel_bdy(new_mfa_start, new_mfa_end, fs,
hop_length, span_to_add)
return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy
def prepare_features(uid: str,
mlm_model: nn.Layer,
processor,
wav_path: str,
prefix: str="./prompt/dev/",
source_lang: str="english",
target_lang: str="english",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
sid: str=None,
duration_adjust: bool=True,
start_end_sp: bool=False,
mask_reconstruct: bool=False,
train_args=None):
wav_org, phns_list, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prepare_features_with_duration(
uid=uid,
prefix=prefix,
    # new_mfa_start, new_mfa_end: time-level start/end -> frame level
new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
mfa_start=new_mfa_start,
mfa_end=new_mfa_end,
fs=fs,
hop_length=hop_length,
span_to_repl=span_to_add)
    # old_span_bdy and new_span_bdy are frame-level ranges
return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy
def prep_feats(mlm_model: nn.Layer,
wav_path: str,
source_lang: str="english",
target_lang: str="english",
old_str: str="",
new_str: str="",
duration_adjust: bool=True,
start_end_sp: bool=False,
mask_reconstruct: bool=False,
fs: int=24000,
hop_length: int=300,
token_list: List[str]=[]):
wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
source_lang=source_lang,
target_lang=target_lang,
mlm_model=mlm_model,
old_str=old_str,
new_str=new_str,
wav_path=wav_path,
duration_preditor_path=duration_preditor_path,
sid=sid,
duration_adjust=duration_adjust,
start_end_sp=start_end_sp,
mask_reconstruct=mask_reconstruct,
train_args=train_args)
speech = wav_org
align_start = np.array(mfa_start)
align_end = np.array(mfa_end)
token_to_id = {item: i for i, item in enumerate(train_args.token_list)}
text = np.array(
list(
map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns_list)))
fs=fs,
hop_length=hop_length)
token_to_id = {item: i for i, item in enumerate(token_list)}
text = np.array(
list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
span_bdy = np.array(new_span_bdy)
batch = [('1', {
"speech": speech,
"align_start": align_start,
"align_end": align_end,
"speech": wav,
"align_start": mfa_start,
"align_end": mfa_end,
"text": text,
"span_bdy": span_bdy
})]
@@ -648,375 +505,135 @@ def prepare_features(uid: str,
return batch, old_span_bdy, new_span_bdy
def decode_with_model(uid: str,
mlm_model: nn.Layer,
processor,
def decode_with_model(mlm_model: nn.Layer,
collate_fn,
wav_path: str,
prefix: str="./prompt/dev/",
source_lang: str="english",
target_lang: str="english",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
sid: str=None,
decoder: bool=False,
use_teacher_forcing: bool=False,
duration_adjust: bool=True,
start_end_sp: bool=False,
train_args=None):
fs, hop_length = train_args.feats_extract_conf[
'fs'], train_args.feats_extract_conf['hop_length']
batch, old_span_bdy, new_span_bdy = prepare_features(
uid=uid,
prefix=prefix,
fs: int=24000,
hop_length: int=300,
token_list: List[str]=[]):
batch, old_span_bdy, new_span_bdy = prep_feats(
source_lang=source_lang,
target_lang=target_lang,
mlm_model=mlm_model,
processor=processor,
wav_path=wav_path,
old_str=old_str,
new_str=new_str,
duration_preditor_path=duration_preditor_path,
sid=sid,
duration_adjust=duration_adjust,
start_end_sp=start_end_sp,
train_args=train_args)
fs=fs,
hop_length=hop_length,
token_list=token_list)
feats = collate_fn(batch)[1]
if 'text_masked_pos' in feats.keys():
feats.pop('text_masked_pos')
for k, v in feats.items():
feats[k] = paddle.to_tensor(v)
rtn = mlm_model.inference(
**feats, span_bdy=new_span_bdy, use_teacher_forcing=use_teacher_forcing)
output = rtn['feat_gen']
output = mlm_model.inference(
text=feats['text'],
speech=feats['speech'],
masked_pos=feats['masked_pos'],
speech_mask=feats['speech_mask'],
text_mask=feats['text_mask'],
speech_seg_pos=feats['speech_seg_pos'],
text_seg_pos=feats['text_seg_pos'],
span_bdy=new_span_bdy,
use_teacher_forcing=use_teacher_forcing)
if 0 in output[0].shape and 0 not in output[-1].shape:
output_feat = paddle.concat(
output[1:-1] + [output[-1].squeeze()], axis=0).cpu()
output[1:-1] + [output[-1].squeeze()], axis=0)
elif 0 not in output[0].shape and 0 in output[-1].shape:
output_feat = paddle.concat(
[output[0].squeeze()] + output[1:-1], axis=0).cpu()
[output[0].squeeze()] + output[1:-1], axis=0)
elif 0 in output[0].shape and 0 in output[-1].shape:
output_feat = paddle.concat(output[1:-1], axis=0).cpu()
output_feat = paddle.concat(output[1:-1], axis=0)
else:
output_feat = paddle.concat(
[output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)],
axis=0).cpu()
wav_org, _ = librosa.load(
wav_path, sr=train_args.feats_extract_conf['fs'])
return wav_org, None, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
class MLMCollateFn:
"""Functor class of common_collate_fn()"""
def __init__(self,
feats_extract,
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
attention_window: int=0,
pad_speech: bool=False,
sega_emb: bool=False,
duration_collect: bool=False,
text_masking: bool=False):
self.mlm_prob = mlm_prob
self.mean_phn_span = mean_phn_span
self.feats_extract = feats_extract
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
self.attention_window = attention_window
self.pad_speech = pad_speech
self.sega_emb = sega_emb
self.duration_collect = duration_collect
self.text_masking = text_masking
def __repr__(self):
return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
f"int_pad_value={self.float_pad_value})")
def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
return mlm_collate_fn(
data,
float_pad_value=self.float_pad_value,
int_pad_value=self.int_pad_value,
not_sequence=self.not_sequence,
mlm_prob=self.mlm_prob,
mean_phn_span=self.mean_phn_span,
feats_extract=self.feats_extract,
attention_window=self.attention_window,
pad_speech=self.pad_speech,
sega_emb=self.sega_emb,
duration_collect=self.duration_collect,
text_masking=self.text_masking)
def mlm_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
feats_extract=None,
attention_window: int=0,
pad_speech: bool=False,
sega_emb: bool=False,
duration_collect: bool=False,
text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(not k.endswith("_lens")
for k in data[0]), f"*_lens is reserved: {list(data[0])}"
output = {}
for key in data[0]:
# Each models, which accepts these values finally, are responsible
# to repaint the pad_value to the desired value for each tasks.
if data[0][key].dtype.kind == "i":
pad_value = int_pad_value
else:
pad_value = float_pad_value
array_list = [d[key] for d in data]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list = [paddle.to_tensor(a) for a in array_list]
# tensor: (Batch, Length, ...)
tensor = pad_list(tensor_list, pad_value)
output[key] = tensor
# lens: (Batch,)
if key not in not_sequence:
lens = paddle.to_tensor(
[d[key].shape[0] for d in data], dtype=paddle.int64)
output[key + "_lens"] = lens
feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
feats = paddle.to_tensor(feats)
feats_lens = paddle.shape(feats)[0]
feats = paddle.unsqueeze(feats, 0)
if 'text' not in output:
text = paddle.zeros(paddle.shape(feats_lens.unsqueeze(-1))) - 2
text_lens = paddle.zeros(paddle.shape(feats_lens)) + 1
max_tlen = 1
align_start = paddle.zeros(paddle.shape(text))
align_end = paddle.zeros(paddle.shape(text))
align_start_lens = paddle.zeros(paddle.shape(feats_lens))
sega_emb = False
mean_phn_span = 0
mlm_prob = 0.15
else:
text = output["text"]
text_lens = output["text_lens"]
align_start = output["align_start"]
align_start_lens = output["align_start_lens"]
align_end = output["align_end"]
align_start = paddle.floor(feats_extract.sr * align_start /
feats_extract.hop_length).int()
align_end = paddle.floor(feats_extract.sr * align_end /
feats_extract.hop_length).int()
max_tlen = max(text_lens)
max_slen = max(feats_lens)
speech_pad = feats[:, :max_slen]
if attention_window > 0 and pad_speech:
speech_pad, max_slen = pad_to_longformer_att_window(
speech_pad, max_slen, max_slen, attention_window)
max_len = max_slen + max_tlen
if attention_window > 0:
text_pad, max_tlen = pad_to_longformer_att_window(
text, max_len, max_tlen, attention_window)
else:
text_pad = text
text_mask = make_non_pad_mask(
text_lens, text_pad, length_dim=1).unsqueeze(-2)
if attention_window > 0:
text_mask = text_mask * 2
speech_mask = make_non_pad_mask(
feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
span_bdy = None
if 'span_bdy' in output.keys():
span_bdy = output['span_bdy']
if text_masking:
masked_pos, text_masked_pos, _ = phones_text_masking(
speech_pad, speech_mask, text_pad, text_mask, align_start,
align_end, align_start_lens, mlm_prob, mean_phn_span, span_bdy)
else:
text_masked_pos = paddle.zeros(paddle.shape(text_pad))
masked_pos, _ = phones_masking(speech_pad, speech_mask, align_start,
align_end, align_start_lens, mlm_prob,
mean_phn_span, span_bdy)
output_dict = {}
if duration_collect and 'text' in output:
reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens = get_seg_pos_reduce_duration(
speech_pad, text_pad, align_start, align_end, align_start_lens,
sega_emb, masked_pos, feats_lens)
speech_mask = make_non_pad_mask(
feats_lens, speech_pad[:, :reordered_idx.shape[1], 0],
length_dim=1).unsqueeze(-2)
output_dict['durations'] = durations
output_dict['reordered_idx'] = reordered_idx
else:
speech_seg_pos, text_seg_pos = get_seg_pos(speech_pad, text_pad,
align_start, align_end,
align_start_lens, sega_emb)
output_dict['speech'] = speech_pad
output_dict['text'] = text_pad
output_dict['masked_pos'] = masked_pos
output_dict['text_masked_pos'] = text_masked_pos
output_dict['speech_mask'] = speech_mask
output_dict['text_mask'] = text_mask
output_dict['speech_seg_pos'] = speech_seg_pos
output_dict['text_seg_pos'] = text_seg_pos
output_dict['speech_lens'] = output["speech_lens"]
output_dict['text_lens'] = text_lens
output = (uttids, output_dict)
return output
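For reference, a minimal sketch of the right-padding step the collate function above relies on. The real `pad_list` is imported elsewhere in this module; this stand-in only illustrates the assumed behaviour (a batch of variable-length tensors, right-padded with `pad_value`):

```python
import paddle

def pad_list_sketch(xs, pad_value):
    # Right-pad a list of (Length, ...) tensors to the longest length.
    n_batch = len(xs)
    max_len = max(x.shape[0] for x in xs)
    pad = paddle.full([n_batch, max_len] + list(xs[0].shape[1:]),
                      pad_value, dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        pad[i, :x.shape[0]] = x
    return pad

feats = [paddle.randn([5, 80]), paddle.randn([3, 80])]
print(pad_list_sketch(feats, 0.0).shape)  # [2, 5, 80]
```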
def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
# -> Callable[
# [Collection[Tuple[str, Dict[str, np.ndarray]]]],
# Tuple[List[str], Dict[str, Tensor]],
# ]:
# assert check_argument_types()
# return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
feats_extract_class = LogMelFBank
if args.feats_extract_conf['win_length'] is None:
args.feats_extract_conf['win_length'] = args.feats_extract_conf['n_fft']
args_dic = {}
for k, v in args.feats_extract_conf.items():
if k == 'fs':
args_dic['sr'] = v
else:
args_dic[k] = v
feats_extract = feats_extract_class(**args_dic)
    sega_emb = args.encoder_conf['input_layer'] == 'sega_mlm'
    if args.encoder_conf['selfattention_layer_type'] == 'longformer':
        attention_window = args.encoder_conf['attention_window']
        pad_speech = args.encoder_conf.get('pre_speech_layer', 0) > 0
    else:
        attention_window = 0
        pad_speech = False
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
    duration_collect = args.model_conf.get('duration_predictor_layers', 0) > 0
return MLMCollateFn(
feats_extract,
float_pad_value=0.0,
int_pad_value=0,
mlm_prob=args.model_conf['mlm_prob'] * mlm_prob_factor,
mean_phn_span=args.model_conf['mean_phn_span'],
attention_window=attention_window,
pad_speech=pad_speech,
sega_emb=sega_emb,
duration_collect=duration_collect)
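A hedged usage sketch of `build_collate_fn` (the nested config keys mirror the ones read above, and `MLMCollateFn`/`LogMelFBank` are assumed importable from this module; the concrete values are illustrative, not taken from a released config):

```python
import argparse

args = argparse.Namespace(
    feats_extract_conf=dict(fs=24000, n_fft=2048, hop_length=300,
                            win_length=None, n_mels=80, fmin=80, fmax=7600),
    encoder_conf=dict(input_layer='sega_mlm',
                      selfattention_layer_type='legacy_rel_selfattn'),
    model_conf=dict(mlm_prob=0.8, mean_phn_span=8))
collate_fn = build_collate_fn(args, train=False)
```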
def get_mlm_output(uid: str,
wav_path: str,
prefix: str="./prompt/dev/",
model_name: str="conformer",
                   axis=0):
wav_org, _ = librosa.load(wav_path, sr=fs)
return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
def get_mlm_output(wav_path: str,
model_name: str="paddle_checkpoint_en",
source_lang: str="english",
target_lang: str="english",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
sid: str=None,
decoder: bool=False,
use_teacher_forcing: bool=False,
duration_adjust: bool=True,
start_end_sp: bool=False):
mlm_model, train_args = load_model(model_name)
mlm_model, train_conf = load_model(model_name)
mlm_model.eval()
processor = None
collate_fn = build_collate_fn(train_args, False)
collate_fn = build_collate_fn(
sr=train_conf.feats_extract_conf['fs'],
n_fft=train_conf.feats_extract_conf['n_fft'],
hop_length=train_conf.feats_extract_conf['hop_length'],
win_length=train_conf.feats_extract_conf['win_length'],
n_mels=train_conf.feats_extract_conf['n_mels'],
fmin=train_conf.feats_extract_conf['fmin'],
fmax=train_conf.feats_extract_conf['fmax'],
mlm_prob=train_conf['mlm_prob'],
mean_phn_span=train_conf['mean_phn_span'],
train=False,
seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm')
return decode_with_model(
uid=uid,
prefix=prefix,
source_lang=source_lang,
target_lang=target_lang,
mlm_model=mlm_model,
processor=processor,
collate_fn=collate_fn,
wav_path=wav_path,
old_str=old_str,
new_str=new_str,
duration_preditor_path=duration_preditor_path,
sid=sid,
decoder=decoder,
use_teacher_forcing=use_teacher_forcing,
duration_adjust=duration_adjust,
start_end_sp=start_end_sp,
train_args=train_args)
fs=train_conf.feats_extract_conf['fs'],
hop_length=train_conf.feats_extract_conf['hop_length'],
token_list=train_conf.token_list)
def evaluate(uid: str,
source_lang: str="english",
target_lang: str="english",
use_pt_vocoder: bool=False,
prefix: str="./prompt/dev/",
model_name: str="conformer",
old_str: str="",
prefix: os.PathLike="./prompt/dev/",
model_name: str="paddle_checkpoint_en",
new_str: str="",
prompt_decoding: bool=False,
task_name: str=None):
duration_preditor_path = None
spemd = None
full_origin_str, wav_path = read_data(uid=uid, prefix=prefix)
# get origin text and path of origin wav
old_str, wav_path = read_data(uid=uid, prefix=prefix)
if task_name == 'edit':
new_str = new_str
elif task_name == 'synthesize':
new_str = full_origin_str + new_str
new_str = old_str + new_str
else:
new_str = full_origin_str + ' '.join(
[ch for ch in new_str if is_chinese(ch)])
new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])
print('new_str is ', new_str)
if not old_str:
old_str = full_origin_str
results_dict, old_span = plot_mel_and_vocode_wav(
uid=uid,
prefix=prefix,
source_lang=source_lang,
target_lang=target_lang,
model_name=model_name,
wav_path=wav_path,
full_origin_str=full_origin_str,
old_str=old_str,
new_str=new_str,
use_pt_vocoder=use_pt_vocoder,
duration_preditor_path=duration_preditor_path,
sid=spemd)
use_pt_vocoder=use_pt_vocoder)
return results_dict
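A hypothetical invocation of `evaluate` for the English speech-editing task (the `uid` and `new_str` are placeholders; real values come from the `prompt/dev` layout and the run scripts described earlier):

```python
results_dict = evaluate(
    uid='p243_new',          # placeholder utterance id under prompt/dev
    source_lang='english',
    target_lang='english',
    model_name='paddle_checkpoint_en',
    new_str='for that reason cover should not be given',
    task_name='edit')
```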
......
import argparse
import logging
import math
import os
import sys
from pathlib import Path
from typing import Dict
from typing import List
from typing import Optional
......@@ -20,17 +17,18 @@ for dir_name in os.listdir(pypath):
if os.path.isdir(dir_path):
sys.path.append(dir_path)
from paddlespeech.s2t.utils.error_rate import ErrorCalculator
from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
......@@ -39,65 +37,10 @@ from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredCo
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.layer_norm import LayerNorm
class LegacyRelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
    See: Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
"""
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
            max_len (int, optional): Maximum input length. Defaults to 5000.
"""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
return
pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
if self.reverse:
position = paddle.arange(
paddle.shape(x)[1] - 1, -1, -1.0,
dtype=paddle.float32).unsqueeze(1)
else:
position = paddle.arange(
0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe
def forward(self, x: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute positional encoding.
Args:
x (paddle.Tensor): Input tensor (batch, time, `*`).
Returns:
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[:, :paddle.shape(x)[1]]
return self.dropout(x), self.dropout(pos_emb)
from yacs.config import CfgNode
# MLM -> Masked Language Model
class mySequential(nn.Sequential):
def forward(self, *inputs):
for module in self._sub_layers.values():
......@@ -108,12 +51,8 @@ class mySequential(nn.Sequential):
return inputs
class NewMaskInputLayer(nn.Layer):
__constants__ = ['out_features']
out_features: int
def __init__(self, out_features: int, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
class MaskInputLayer(nn.Layer):
def __init__(self, out_features: int) -> None:
super().__init__()
self.mask_feature = paddle.create_parameter(
shape=(1, 1, out_features),
......@@ -121,109 +60,14 @@ class NewMaskInputLayer(nn.Layer):
default_initializer=paddle.nn.initializer.Assign(
paddle.normal(shape=(1, 1, out_features))))
def forward(self, input: paddle.Tensor, masked_pos=None) -> paddle.Tensor:
def forward(self, input: paddle.Tensor,
masked_pos: paddle.Tensor=None) -> paddle.Tensor:
masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
return masked_input
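The masking layer above replaces masked frames with a learned vector via two complementary `masked_fill` calls; a toy demonstration of the same fill pattern, with a zero vector standing in for the learned `mask_feature`:

```python
import paddle

x = paddle.arange(12, dtype='float32').reshape([1, 4, 3])    # (B, T, D)
masked_pos = paddle.to_tensor([[False, True, False, True]])  # mask frames 1 and 3
mask = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), x)
mask_feature = paddle.zeros([1, 1, 3])   # stand-in for the learned parameter
out = paddle.where(mask, paddle.expand_as(mask_feature, x), x)
print(out[0])  # rows 1 and 3 are zeroed, rows 0 and 2 unchanged
```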
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
self.pos_bias_v = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x(Tensor): Input tensor (batch, head, time1, time2).
Returns:
            Tensor: Output tensor.
"""
b, h, t1, t2 = paddle.shape(x)
zero_pad = paddle.zeros((b, h, t1, 1))
x_padded = paddle.concat([zero_pad, x], axis=-1)
x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
# only keep the positions from 0 to time2
x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query(Tensor): Query tensor (#batch, time1, size).
key(Tensor): Key tensor (#batch, time2, size).
value(Tensor): Value tensor (#batch, time2, size).
pos_emb(Tensor): Positional embedding tensor (#batch, time1, size).
mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
Returns:
Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
# (batch, time1, head, d_k)
q = paddle.transpose(q, [0, 2, 1, 3])
n_batch_pos = paddle.shape(pos_emb)[0]
p = paddle.reshape(
self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
# (batch, head, time1, d_k)
p = paddle.transpose(p, [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3])
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = paddle.matmul(q_with_bias_u,
paddle.transpose(k, [0, 1, 3, 2]))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = paddle.matmul(q_with_bias_v,
paddle.transpose(p, [0, 1, 3, 2]))
matrix_bd = self.rel_shift(matrix_bd)
# (batch, head, time1, time2)
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
class MLMEncoder(nn.Layer):
"""Conformer encoder module.
......@@ -253,47 +97,42 @@ class MLMEncoder(nn.Layer):
        cnn_module_kernel (int): Kernel size of convolution module.
padding_idx (int): Padding idx for input_layer=embed.
stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer.
indices start from 1.
if not None, intermediate outputs are returned (which changes return type
signature.)
"""
def __init__(self,
idim,
vocab_size=0,
idim: int,
vocab_size: int=0,
pre_speech_layer: int=0,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="conv2d",
normalize_before=True,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=1,
macaron_style=False,
pos_enc_layer_type="abs_pos",
attention_dim: int=256,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0,
input_layer: str="conv2d",
normalize_before: bool=True,
concat_after: bool=False,
positionwise_layer_type: str="linear",
positionwise_conv_kernel_size: int=1,
macaron_style: bool=False,
pos_enc_layer_type: str="abs_pos",
pos_enc_class=None,
selfattention_layer_type="selfattn",
activation_type="swish",
use_cnn_module=False,
zero_triu=False,
cnn_module_kernel=31,
padding_idx=-1,
stochastic_depth_rate=0.0,
intermediate_layers=None,
text_masking=False):
selfattention_layer_type: str="selfattn",
activation_type: str="swish",
use_cnn_module: bool=False,
zero_triu: bool=False,
cnn_module_kernel: int=31,
padding_idx: int=-1,
stochastic_depth_rate: float=0.0,
text_masking: bool=False):
"""Construct an Encoder object."""
super().__init__()
self._output_size = attention_dim
self.text_masking = text_masking
if self.text_masking:
self.text_masking_layer = NewMaskInputLayer(attention_dim)
self.text_masking_layer = MaskInputLayer(attention_dim)
activation = get_activation(activation_type)
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
......@@ -330,7 +169,7 @@ class MLMEncoder(nn.Layer):
elif input_layer == "mlm":
self.segment_emb = None
self.speech_embed = mySequential(
NewMaskInputLayer(idim),
MaskInputLayer(idim),
nn.Linear(idim, attention_dim),
nn.LayerNorm(attention_dim),
nn.ReLU(),
......@@ -343,7 +182,7 @@ class MLMEncoder(nn.Layer):
self.segment_emb = nn.Embedding(
500, attention_dim, padding_idx=padding_idx)
self.speech_embed = mySequential(
NewMaskInputLayer(idim),
MaskInputLayer(idim),
nn.Linear(idim, attention_dim),
nn.LayerNorm(attention_dim),
nn.ReLU(),
......@@ -365,7 +204,6 @@ class MLMEncoder(nn.Layer):
# self-attention module definition
if selfattention_layer_type == "selfattn":
logging.info("encoder self-attention layer type = self-attention")
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, attention_dim,
attention_dropout_rate, )
......@@ -375,8 +213,6 @@ class MLMEncoder(nn.Layer):
encoder_selfattn_layer_args = (attention_heads, attention_dim,
attention_dropout_rate, )
elif selfattention_layer_type == "rel_selfattn":
logging.info(
"encoder self-attention layer type = relative self-attention")
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, attention_dim,
......@@ -436,49 +272,38 @@ class MLMEncoder(nn.Layer):
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
self.intermediate_layers = intermediate_layers
def forward(self,
speech_pad,
text_pad,
masked_pos,
speech_mask=None,
text_mask=None,
speech_seg_pos=None,
text_seg_pos=None):
speech: paddle.Tensor,
text: paddle.Tensor,
masked_pos: paddle.Tensor,
speech_mask: paddle.Tensor=None,
text_mask: paddle.Tensor=None,
speech_seg_pos: paddle.Tensor=None,
text_seg_pos: paddle.Tensor=None):
"""Encode input sequence.
"""
if masked_pos is not None:
speech_pad = self.speech_embed(speech_pad, masked_pos)
speech = self.speech_embed(speech, masked_pos)
else:
speech_pad = self.speech_embed(speech_pad)
# pure speech input
if -2 in np.array(text_pad):
text_pad = text_pad + 3
text_mask = paddle.unsqueeze(bool(text_pad), 1)
text_seg_pos = paddle.zeros_like(text_pad)
text_pad = self.text_embed(text_pad)
text_pad = (text_pad[0] + self.segment_emb(text_seg_pos),
text_pad[1])
text_seg_pos = None
elif text_pad is not None:
text_pad = self.text_embed(text_pad)
speech = self.speech_embed(speech)
if text is not None:
text = self.text_embed(text)
if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
speech_seg_emb = self.segment_emb(speech_seg_pos)
text_seg_emb = self.segment_emb(text_seg_pos)
text_pad = (text_pad[0] + text_seg_emb, text_pad[1])
speech_pad = (speech_pad[0] + speech_seg_emb, speech_pad[1])
text = (text[0] + text_seg_emb, text[1])
speech = (speech[0] + speech_seg_emb, speech[1])
if self.pre_speech_encoders:
speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask)
speech, _ = self.pre_speech_encoders(speech, speech_mask)
if text_pad is not None:
xs = paddle.concat([speech_pad[0], text_pad[0]], axis=1)
xs_pos_emb = paddle.concat([speech_pad[1], text_pad[1]], axis=1)
if text is not None:
xs = paddle.concat([speech[0], text[0]], axis=1)
xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
masks = paddle.concat([speech_mask, text_mask], axis=-1)
else:
xs = speech_pad[0]
xs_pos_emb = speech_pad[1]
xs = speech[0]
xs_pos_emb = speech[1]
masks = speech_mask
xs, masks = self.encoders((xs, xs_pos_emb), masks)
......@@ -492,7 +317,7 @@ class MLMEncoder(nn.Layer):
class MLMDecoder(MLMEncoder):
def forward(self, xs, masks, masked_pos=None, segment_emb=None):
def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
"""Encode input sequence.
Args:
......@@ -504,51 +329,19 @@ class MLMDecoder(MLMEncoder):
paddle.Tensor: Mask tensor (#batch, time).
"""
if not self.training:
masked_pos = None
xs = self.embed(xs)
if segment_emb:
xs = (xs[0] + segment_emb, xs[1])
if self.intermediate_layers is None:
xs, masks = self.encoders(xs, masks)
else:
intermediate_outputs = []
for layer_idx, encoder_layer in enumerate(self.encoders):
xs, masks = encoder_layer(xs, masks)
if (self.intermediate_layers is not None and
layer_idx + 1 in self.intermediate_layers):
encoder_output = xs
# intermediate branches also require normalization.
if self.normalize_before:
encoder_output = self.after_norm(encoder_output)
intermediate_outputs.append(encoder_output)
xs, masks = self.encoders(xs, masks)
if isinstance(xs, tuple):
xs = xs[0]
if self.normalize_before:
xs = self.after_norm(xs)
if self.intermediate_layers is not None:
return xs, masks, intermediate_outputs
return xs, masks
def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
    # pad the time axis so that the padded length is a multiple of
    # attention_window, as required by longformer-style attention
    remainder = max_len % attention_window
    if remainder != 0:
        max_tlen += (attention_window - remainder)
        n_batch = paddle.shape(text)[0]
        text_pad = paddle.zeros(
            shape=(n_batch, max_tlen, *paddle.shape(text[0])[1:]),
            dtype=text.dtype)
        for i in range(n_batch):
            text_pad[i, :paddle.shape(text[i])[0]] = text[i]
    else:
        text_pad = text[:, :max_tlen]
    return text_pad, max_tlen
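A quick check of the rounding rule above, with toy numbers: for `attention_window=10` and a current length of 23, the remainder is 3, so 7 frames of padding bring the total to a multiple of the window.

```python
attention_window, max_len, max_tlen = 10, 23, 23
remainder = max_len % attention_window           # 3
max_tlen += attention_window - remainder         # 23 + 7 = 30
assert max_tlen % attention_window == 0
```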
class MLMModel(nn.Layer):
# encoder and decoder is nn.Layer, not str
class MLM(nn.Layer):
def __init__(self,
token_list: Union[Tuple[str, ...], List[str]],
odim: int,
......@@ -557,44 +350,15 @@ class MLMModel(nn.Layer):
postnet_layers: int=0,
postnet_chans: int=0,
postnet_filts: int=0,
ignore_id: int=-1,
lsm_weight: float=0.0,
length_normalized_loss: bool=False,
report_cer: bool=True,
report_wer: bool=True,
sym_space: str="<space>",
sym_blank: str="<blank>",
masking_schema: str="span",
mean_phn_span: int=3,
mlm_prob: float=0.25,
dynamic_mlm_prob=False,
decoder_seg_pos=False,
text_masking=False):
text_masking: bool=False):
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.odim = odim
self.ignore_id = ignore_id
self.token_list = token_list.copy()
self.encoder = encoder
self.decoder = decoder
self.vocab_size = encoder.text_embed[0]._num_embeddings
if report_cer or report_wer:
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer)
else:
self.error_calculator = None
self.mlm_weight = 1.0
self.mlm_prob = mlm_prob
self.mlm_layer = 12
self.finetune_wo_mlm = True
self.max_span = 50
self.min_span = 4
self.mean_phn_span = mean_phn_span
self.masking_schema = masking_schema
if self.decoder is None or not (hasattr(self.decoder,
'output_layer') and
self.decoder.output_layer is not None):
......@@ -606,15 +370,9 @@ class MLMModel(nn.Layer):
self.encoder.text_embed[0]._embedding_dim,
self.vocab_size,
weight_attr=self.encoder.text_embed[0]._weight_attr)
self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
else:
self.text_sfc = None
self.text_mlm_loss = None
self.decoder_seg_pos = decoder_seg_pos
if lsm_weight > 50:
self.l1_loss_func = nn.MSELoss()
else:
self.l1_loss_func = nn.L1Loss(reduction='none')
self.postnet = (None if postnet_layers == 0 else Postnet(
idim=self.encoder._output_size,
odim=odim,
......@@ -624,119 +382,77 @@ class MLMModel(nn.Layer):
use_batch_norm=True,
dropout_rate=0.5, ))
def collect_feats(self,
speech,
speech_lens,
text,
text_lens,
masked_pos,
speech_mask,
text_mask,
speech_seg_pos,
text_seg_pos,
y_masks=None) -> Dict[str, paddle.Tensor]:
return {"feats": speech, "feats_lens": speech_lens}
def forward(self, batch, speech_seg_pos, y_masks=None):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder = batch['speech_pad']
if self.decoder is not None:
ys_in = self._add_first_frame_and_remove_last_frame(
batch['speech_pad'])
encoder_out, h_masks = self.encoder(**batch)
if self.decoder is not None:
zs, _ = self.decoder(ys_in, y_masks, encoder_out,
bool(h_masks),
self.encoder.segment_emb(speech_seg_pos))
speech_hidden_states = zs
else:
speech_hidden_states = encoder_out[:, :paddle.shape(batch[
'speech_pad'])[1], :]
if self.sfc is not None:
before_outs = paddle.reshape(
self.sfc(speech_hidden_states),
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
else:
before_outs = speech_hidden_states
if self.postnet is not None:
after_outs = before_outs + paddle.transpose(
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
(0, 2, 1))
else:
after_outs = None
return before_outs, after_outs, speech_pad_placeholder, batch[
'masked_pos']
def inference(
self,
speech,
text,
masked_pos,
speech_mask,
text_mask,
speech_seg_pos,
text_seg_pos,
span_bdy,
y_masks=None,
speech_lens=None,
text_lens=None,
feats: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
sids: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
threshold: float=0.5,
minlenratio: float=0.0,
maxlenratio: float=10.0,
speech: paddle.Tensor,
text: paddle.Tensor,
masked_pos: paddle.Tensor,
speech_mask: paddle.Tensor,
text_mask: paddle.Tensor,
speech_seg_pos: paddle.Tensor,
text_seg_pos: paddle.Tensor,
span_bdy: List[int],
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
'''
Args:
speech (paddle.Tensor): input speech (B, Tmax, D).
text (paddle.Tensor): input text (B, Tmax2).
masked_pos (paddle.Tensor): masked position of input speech (B, Tmax)
speech_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
span_bdy (List[int]): masked mel boundary of input speech (2,)
use_teacher_forcing (bool): whether to use teacher forcing
Returns:
List[Tensor]:
eg:
[Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
'''
batch = dict(
speech_pad=speech,
text_pad=text,
masked_pos=masked_pos,
speech_mask=speech_mask,
text_mask=text_mask,
speech_seg_pos=speech_seg_pos,
text_seg_pos=text_seg_pos, )
# # inference with teacher forcing
# hs, h_masks = self.encoder(**batch)
outs = [batch['speech_pad'][:, :span_bdy[0]]]
outs = [speech[:, :span_bdy[0]]]
z_cache = None
if use_teacher_forcing:
before, zs, _, _ = self.forward(
batch, speech_seg_pos, y_masks=y_masks)
before_outs, zs, *_ = self.forward(
speech=speech,
text=text,
masked_pos=masked_pos,
speech_mask=speech_mask,
text_mask=text_mask,
speech_seg_pos=speech_seg_pos,
text_seg_pos=text_seg_pos)
if zs is None:
zs = before
zs = before_outs
outs += [zs[0][span_bdy[0]:span_bdy[1]]]
outs += [batch['speech_pad'][:, span_bdy[1]:]]
return dict(feat_gen=outs)
outs += [speech[:, span_bdy[1]:]]
return outs
return None
def _add_first_frame_and_remove_last_frame(
self, ys: paddle.Tensor) -> paddle.Tensor:
ys_in = paddle.concat(
[
paddle.zeros(
shape=(paddle.shape(ys)[0], 1, paddle.shape(ys)[2]),
dtype=ys.dtype), ys[:, :-1]
],
axis=1)
return ys_in
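`_add_first_frame_and_remove_last_frame` is the standard autoregressive shift (prepend one all-zero frame, drop the final frame) so the decoder predicts frame t from frames before t; a toy check of the shift:

```python
import paddle

ys = paddle.arange(6, dtype='float32').reshape([1, 3, 2])
ys_in = paddle.concat(
    [paddle.zeros([1, 1, 2], dtype=ys.dtype), ys[:, :-1]], axis=1)
print(ys_in[0].tolist())  # [[0.0, 0.0], [0.0, 1.0], [2.0, 3.0]]
```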
class MLMEncAsDecoderModel(MLMModel):
def forward(self, batch, speech_seg_pos, y_masks=None):
class MLMEncAsDecoder(MLM):
def forward(self,
speech: paddle.Tensor,
text: paddle.Tensor,
masked_pos: paddle.Tensor,
speech_mask: paddle.Tensor,
text_mask: paddle.Tensor,
speech_seg_pos: paddle.Tensor,
text_seg_pos: paddle.Tensor):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder = batch['speech_pad']
encoder_out, h_masks = self.encoder(**batch) # segment_emb
encoder_out, h_masks = self.encoder(
speech=speech,
text=text,
masked_pos=masked_pos,
speech_mask=speech_mask,
text_mask=text_mask,
speech_seg_pos=speech_seg_pos,
text_seg_pos=text_seg_pos)
if self.decoder is not None:
zs, _ = self.decoder(encoder_out, h_masks)
else:
zs = encoder_out
speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
if self.sfc is not None:
before_outs = paddle.reshape(
self.sfc(speech_hidden_states),
......@@ -749,53 +465,35 @@ class MLMEncAsDecoderModel(MLMModel):
[0, 2, 1])
else:
after_outs = None
return before_outs, after_outs, speech_pad_placeholder, batch[
'masked_pos']
class MLMDualMaksingModel(MLMModel):
def _calc_mlm_loss(self,
before_outs: paddle.Tensor,
after_outs: paddle.Tensor,
text_outs: paddle.Tensor,
batch):
xs_pad = batch['speech_pad']
text_pad = batch['text_pad']
masked_pos = batch['masked_pos']
text_masked_pos = batch['text_masked_pos']
mlm_loss_pos = masked_pos > 0
loss = paddle.sum(
self.l1_loss_func(
paddle.reshape(before_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
if after_outs is not None:
loss += paddle.sum(
self.l1_loss_func(
paddle.reshape(after_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
loss_mlm = paddle.sum((loss * paddle.reshape(
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
loss_text = paddle.sum((self.text_mlm_loss(
paddle.reshape(text_outs, (-1, self.vocab_size)),
paddle.reshape(text_pad, (-1))) * paddle.reshape(
text_masked_pos, (-1)))) / paddle.sum((text_masked_pos) + 1e-10)
return loss_mlm, loss_text
def forward(self, batch, speech_seg_pos, y_masks=None):
return before_outs, after_outs, None
class MLMDualMaksing(MLM):
def forward(self,
speech: paddle.Tensor,
text: paddle.Tensor,
masked_pos: paddle.Tensor,
speech_mask: paddle.Tensor,
text_mask: paddle.Tensor,
speech_seg_pos: paddle.Tensor,
text_seg_pos: paddle.Tensor):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, h_masks = self.encoder(**batch) # segment_emb
encoder_out, h_masks = self.encoder(
speech=speech,
text=text,
masked_pos=masked_pos,
speech_mask=speech_mask,
text_mask=text_mask,
speech_seg_pos=speech_seg_pos,
text_seg_pos=text_seg_pos)
if self.decoder is not None:
zs, _ = self.decoder(encoder_out, h_masks)
else:
zs = encoder_out
speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :]
speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        if self.text_sfc:
            text_hidden_states = zs[:, paddle.shape(batch['speech_pad'])[
                1]:, :]
            text_hidden_states = zs[:, paddle.shape(speech)[1]:, :]
            text_outs = paddle.reshape(
                self.text_sfc(text_hidden_states),
                (paddle.shape(text_hidden_states)[0], -1, self.vocab_size))
......@@ -811,27 +509,25 @@ class MLMDualMaksingModel(MLMModel):
[0, 2, 1])
else:
after_outs = None
return before_outs, after_outs, text_outs, None #, speech_pad_placeholder, batch['masked_pos'],batch['text_masked_pos']
return before_outs, after_outs, text_outs
def build_model_from_file(config_file, model_file):
state_dict = paddle.load(model_file)
model_class = MLMDualMaksingModel if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
else MLMEncAsDecoderModel
model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
else MLMEncAsDecoder
    # build the model
args = yaml.safe_load(Path(config_file).open("r", encoding="utf-8"))
args = argparse.Namespace(**args)
model = build_model(args, model_class)
with open(config_file) as f:
conf = CfgNode(yaml.safe_load(f))
model = build_model(conf, model_class)
model.set_state_dict(state_dict)
return model, args
return model, conf
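A hedged loading example for `build_model_from_file` (the paths are placeholders for a downloaded ERNIE-SAT checkpoint and its YAML config under the `pretrained_model` folder mentioned earlier):

```python
model, conf = build_model_from_file(
    config_file='pretrained_model/paddle_checkpoint_en/config.yaml',
    model_file='pretrained_model/paddle_checkpoint_en/model.pdparams')
model.eval()
print(type(model).__name__)  # MLMEncAsDecoder or MLMDualMaksing
```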
def build_model(args: argparse.Namespace,
model_class=MLMEncAsDecoderModel) -> MLMModel:
# select encoder and decoder here
def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
......@@ -842,9 +538,8 @@ def build_model(args: argparse.Namespace,
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size}")
vocab_size = len(token_list)
odim = 80
pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding
......@@ -857,17 +552,8 @@ def build_model(args: argparse.Namespace,
if conformer_rel_pos_type == "legacy":
if conformer_pos_enc_layer_type == "rel_pos":
conformer_pos_enc_layer_type = "legacy_rel_pos"
logging.warning(
"Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
"due to the compatibility. If you want to use the new one, "
"please use conformer_pos_enc_layer_type = 'latest'.")
if conformer_self_attn_layer_type == "rel_selfattn":
conformer_self_attn_layer_type = "legacy_rel_selfattn"
logging.warning(
"Fallback to "
"conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
"due to the compatibility. If you want to use the new one, "
"please use conformer_pos_enc_layer_type = 'latest'.")
elif conformer_rel_pos_type == "latest":
assert conformer_pos_enc_layer_type != "legacy_rel_pos"
assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
......
import paddle
from paddle import nn
class MLMLoss(nn.Layer):
def __init__(self,
lsm_weight: float=0.1,
ignore_id: int=-1,
text_masking: bool=False):
super().__init__()
        if text_masking:
            # reduction='none' keeps per-token losses so they can be
            # masked with text_masked_pos in forward()
            self.text_mlm_loss = nn.CrossEntropyLoss(
                ignore_index=ignore_id, reduction='none')
if lsm_weight > 50:
self.l1_loss_func = nn.MSELoss()
else:
self.l1_loss_func = nn.L1Loss(reduction='none')
self.text_masking = text_masking
def forward(self,
speech: paddle.Tensor,
before_outs: paddle.Tensor,
after_outs: paddle.Tensor,
masked_pos: paddle.Tensor,
text: paddle.Tensor=None,
text_outs: paddle.Tensor=None,
text_masked_pos: paddle.Tensor=None):
        # this module stores neither odim nor vocab_size,
        # so infer them from the input tensors
        xs_pad = speech
        odim = xs_pad.shape[-1]
        mlm_loss_pos = paddle.cast(
            paddle.reshape(masked_pos, [-1]) > 0, xs_pad.dtype)
        loss = paddle.sum(
            self.l1_loss_func(
                paddle.reshape(before_outs, (-1, odim)),
                paddle.reshape(xs_pad, (-1, odim))),
            axis=-1)
        if after_outs is not None:
            loss += paddle.sum(
                self.l1_loss_func(
                    paddle.reshape(after_outs, (-1, odim)),
                    paddle.reshape(xs_pad, (-1, odim))),
                axis=-1)
        loss_mlm = paddle.sum(loss * mlm_loss_pos) / (
            paddle.sum(mlm_loss_pos) + 1e-10)
        if self.text_masking:
            vocab_size = text_outs.shape[-1]
            text_masked_pos_flat = paddle.cast(
                paddle.reshape(text_masked_pos, [-1]), loss.dtype)
            loss_text = paddle.sum(
                self.text_mlm_loss(
                    paddle.reshape(text_outs, (-1, vocab_size)),
                    paddle.reshape(text, [-1])) * text_masked_pos_flat) / (
                        paddle.sum(text_masked_pos_flat) + 1e-10)
            return loss_mlm, loss_text
        return loss_mlm
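A minimal smoke test of `MLMLoss` on random tensors, assuming the shape-inferring `forward` above (all shapes illustrative):

```python
import paddle

loss_fn = MLMLoss(text_masking=False)
speech = paddle.randn([2, 50, 80])
before_outs = paddle.randn([2, 50, 80])
after_outs = paddle.randn([2, 50, 80])
masked_pos = paddle.cast(paddle.randint(0, 2, [2, 50]), 'float32')
loss_mlm = loss_fn(speech, before_outs, after_outs, masked_pos)
print(float(loss_mlm))
```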
......@@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
self.pos_bias_v = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x(Tensor): Input tensor (batch, head, time1, time2).
Returns:
            Tensor: Output tensor.
"""
b, h, t1, t2 = paddle.shape(x)
zero_pad = paddle.zeros((b, h, t1, 1))
x_padded = paddle.concat([zero_pad, x], axis=-1)
x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
# only keep the positions from 0 to time2
x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query(Tensor): Query tensor (#batch, time1, size).
key(Tensor): Key tensor (#batch, time2, size).
value(Tensor): Value tensor (#batch, time2, size).
pos_emb(Tensor): Positional embedding tensor (#batch, time1, size).
mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
Returns:
Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
# (batch, time1, head, d_k)
q = paddle.transpose(q, [0, 2, 1, 3])
n_batch_pos = paddle.shape(pos_emb)[0]
p = paddle.reshape(
self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
# (batch, head, time1, d_k)
p = paddle.transpose(p, [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3])
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = paddle.matmul(q_with_bias_u,
paddle.transpose(k, [0, 1, 3, 2]))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = paddle.matmul(q_with_bias_v,
paddle.transpose(p, [0, 1, 3, 2]))
matrix_bd = self.rel_shift(matrix_bd)
# (batch, head, time1, time2)
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
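To see what `rel_shift` actually does, here is a toy run of the zero-pad-and-reshape trick on a 3x3 score matrix; each row ends up shifted left one step further than the row below it, which is how the relative-position terms get aligned (the wrapped entries in the upper-left are artifacts that downstream masking, or `zero_triu`, takes care of):

```python
import paddle

b, h, t1, t2 = 1, 1, 3, 3
x = paddle.arange(t1 * t2, dtype='float32').reshape([b, h, t1, t2])
zero_pad = paddle.zeros((b, h, t1, 1))
x_padded = paddle.concat([zero_pad, x], axis=-1)
x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
shifted = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])
print(shifted[0, 0].tolist())
# [[2.0, 0.0, 3.0],
#  [4.0, 5.0, 0.0],
#  [6.0, 7.0, 8.0]]
```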
......@@ -185,3 +185,63 @@ class RelPositionalEncoding(nn.Layer):
pe_size = paddle.shape(self.pe)
pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ]
return self.dropout(x), self.dropout(pos_emb)
class LegacyRelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
    See: Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
"""
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
            max_len (int, optional): Maximum input length. Defaults to 5000.
"""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
return
pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
if self.reverse:
position = paddle.arange(
paddle.shape(x)[1] - 1, -1, -1.0,
dtype=paddle.float32).unsqueeze(1)
else:
position = paddle.arange(
0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe
def forward(self, x: paddle.Tensor):
"""Compute positional encoding.
Args:
x (paddle.Tensor): Input tensor (batch, time, `*`).
Returns:
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[:, :paddle.shape(x)[1]]
return self.dropout(x), self.dropout(pos_emb)
......@@ -5,7 +5,7 @@ from typing import List
from typing import Union
def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
"""Read a text file having 2 column as dict object.
Examples:
......@@ -13,7 +13,7 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
key1 /some/path/a.wav
key2 /some/path/b.wav
>>> read_2column_text('wav.scp')
>>> read_2col_text('wav.scp')
{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
"""
......
......@@ -65,12 +65,6 @@ def parse_args():
help="mean and standard deviation used to normalize spectrogram when training voc."
)
# other
parser.add_argument(
'--lang',
type=str,
default='en',
help='Choose model language. zh or en')
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
# parser.add_argument("--test_metadata", type=str, help="test metadata.")
......
......@@ -32,7 +32,6 @@ model_alias = {
"paddlespeech.t2s.models.parallel_wavegan:PWGInference",
}
def is_chinese(ch):
if u'\u4e00' <= ch <= u'\u9fff':
return True
......@@ -55,12 +54,10 @@ def build_vocoder_from_file(
raise ValueError(f"{vocoder_file} is not supported format.")
def get_voc_out(mel, target_lang: str="chinese"):
def get_voc_out(mel):
# vocoder
args = parse_args()
assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..."
# print("current vocoder: ", args.voc)
with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f))
......@@ -167,19 +164,23 @@ def get_voc_inference(
return voc_inference
def evaluate_durations(phns: List[str],
target_lang: str="chinese",
fs: int=24000,
hop_length: int=300):
def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
args = parse_args()
if target_lang == 'english':
args.lang = 'en'
args.am = "fastspeech2_ljspeech"
args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_lang == 'chinese':
args.lang = 'zh'
args.am = "fastspeech2_csmsc"
args.am_config="download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args.phones_dict ="download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
# args = parser.parse_args(args=[])
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
......@@ -187,8 +188,6 @@ def evaluate_durations(phns: List[str],
else:
print("ngpu should >= 0 !")
assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..."
# Init body.
with open(args.am_config) as f:
am_config = CfgNode(yaml.safe_load(f))
......@@ -203,21 +202,19 @@ def evaluate_durations(phns: List[str],
speaker_dict=args.speaker_dict,
return_am=True)
torch_phns = phns
vocab_phones = {}
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
    # each line of phone_id_map.txt is "<phone> <id>"
    for phn, phn_id_str in phn_id:
        vocab_phones[phn] = int(phn_id_str)
vocab_size = len(vocab_phones)
phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]
phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
phone_ids = [vocab_phones[item] for item in phonemes]
phone_ids_new = phone_ids
phone_ids_new.append(vocab_size - 1)
phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
normalized_mel, d_outs, p_outs, e_outs = am.inference(
phone_ids_new, spk_id=None, spk_emb=None)
_, d_outs, _, _ = am.inference(phone_ids_new, spk_id=None, spk_emb=None)
pre_d_outs = d_outs
phoneme_durations_new = pre_d_outs * hop_length / fs
phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
......