Commit 9224659c authored by 小湉湉

add docstring

Parent 76b654cb
@@ -39,9 +39,9 @@ In ERNIE-SAT we propose two innovations:
### 2. Pretrained models
The pretrained ERNIE-SAT models are shown below:
- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz)
Create a `pretrained_model` folder, download the ERNIE-SAT pretrained models above and extract them:
@@ -108,7 +108,7 @@ prompt/dev
3. `--voc` is the vocoder; its name must follow the {model_name}_{dataset} format
4. `--voc_config`, `--voc_checkpoint` and `--voc_stat` are the vocoder parameters, corresponding to the 3 files in the parallel wavegan pretrained model
5. `--lang` is the language of the model, either `zh` or `en`
6. `--ngpu` is the number of GPUs to use; if ngpu == 0, the CPU is used
7. `--model_name` is the model name
8. `--uid` is the id of the specific prompt speech
9. `--new_str` is the input text (this open-source release temporarily uses fixed texts)
@@ -125,4 +125,3 @@ sh run_sedit_en.sh # speech editing task (English)
sh run_gen_en.sh # personalized speech synthesis task (English)
sh run_clone_en_to_zh.sh # cross-lingual speech synthesis task (voice cloning from English to Chinese)
```
""" Usage:
      align.py wavfile trsfile outwordfile outphonefile
"""
import os
import sys

PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
@@ -15,6 +11,142 @@ HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'
def get_unk_phns(word_str: str):
tmpbase = '/tmp/tp.'
f = open(tmpbase + 'temp.words', 'w')
f.write(word_str)
f.close()
os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
'temp.phons')
f = open(tmpbase + 'temp.phons', 'r')
lines2 = f.readline().strip().split()
f.close()
phns = []
for phn in lines2:
phons = phn.replace('\n', '').replace(' ', '')
seq = []
j = 0
while (j < len(phons)):
if (phons[j] > 'Z'):
if (phons[j] == 'j'):
seq.append('JH')
elif (phons[j] == 'h'):
seq.append('HH')
else:
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
seq.append(p)
elif (p == 'AX'):
seq.append('AH0')
else:
seq.append(p + '1')
j += 2
phns.extend(seq)
return phns
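

# A minimal usage sketch of get_unk_phns(); it assumes the english2phoneme
# binary configured in PHONEME above is installed, and the exact phones
# depend on that tool. The word used here is only an example.
def _demo_get_unk_phns():
    # letter-to-sound for a word that is missing from the CMU-style dict
    phns = get_unk_phns('paddlespeech')
    # each element is an ARPABET-style phone such as 'P', 'AE1', 'D', ...
    print(phns)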
def words2phns(line: str):
'''
Args:
line (str): input text.
eg: for that reason cover is impossible to be given.
Returns:
List[str]: phones of input text.
eg:
['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
'G', 'IH1', 'V', 'AH0', 'N']
Dict(str, str): key - idx_word
value - phones
eg:
{'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],
'3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
'6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
'''
dictfile = MODEL_DIR_EN + '/dict'
line = line.strip()
words = []
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
phns.extend(get_unk_phns(wrd))
else:
wrd2phns[str(index) +
"_" + wrd.upper()] = word2phns_dict[wrd.upper()].split()
phns.extend(word2phns_dict[wrd.upper()].split())
return phns, wrd2phns
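

# A minimal usage sketch of words2phns() on the sentence from the docstring
# above; it assumes the CMU-style dict under MODEL_DIR_EN is present.
def _demo_words2phns():
    phns, wrd2phns = words2phns(
        'for that reason cover is impossible to be given')
    print(phns[:3])           # ['F', 'AO1', 'R']
    print(wrd2phns['0_FOR'])  # ['F', 'AO1', 'R']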
def words2phns_zh(line: str):
dictfile = MODEL_DIR_ZH + '/dict'
line = line.strip()
words = []
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
print("出现非法词错误,请输入正确的文本...")
else:
wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
phns.extend(word2phns_dict[wrd].split())
return phns, wrd2phns
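

# A minimal usage sketch of words2phns_zh(); it assumes the mandarin dict
# under MODEL_DIR_ZH is present. Words must already be space-separated, and a
# word missing from the dict only triggers the error message above.
def _demo_words2phns_zh():
    phns, wrd2phns = words2phns_zh('今天 天气 很 好')
    # phns is the flat phone list; wrd2phns maps keys such as '0_今天' to
    # the phones of each word
    print(phns, wrd2phns)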
def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
    words = []
@@ -82,7 +214,7 @@ def prep_txt_en(line: str, tmpbase, dictfile):
    try:
        os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
                  '_unk.phons')
    except Exception:
        print('english2phoneme error!')
        sys.exit(1)
@@ -148,19 +280,22 @@ def _get_user():


def alignment(wav_path: str, text: str):
    '''
    intervals: List[phn, start, end]
    '''
    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

    #prepare wav and trs files
    try:
        os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
    except Exception:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
        prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict')
    except Exception:
        print('prep_txt error!')
        return None
@@ -169,7 +304,7 @@ def alignment(wav_path: str, text: str):
        with open(tmpbase + '.txt', 'r') as fid:
            txt = fid.readline()
        prep_mlf(txt, tmpbase)
    except Exception:
        print('prep_mlf error!')
        return None
@@ -177,7 +312,7 @@ def alignment(wav_path: str, text: str):
    try:
        os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except Exception:
        print('HCopy error!')
        return None
@@ -188,7 +323,7 @@ def alignment(wav_path: str, text: str):
                  + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
                  '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
                  '.plp 2>&1 > /dev/null')
    except Exception:
        print('HVite error!')
        return None
@@ -200,7 +335,7 @@ def alignment(wav_path: str, text: str):
    with open(tmpbase + '.aligned', 'r') as fid:
        lines = fid.readlines()
    i = 2
    intervals = []
    word2phns = {}
    current_word = ''
    index = 0
@@ -210,7 +345,7 @@ def alignment(wav_path: str, text: str):
        phn = splited_line[2]
        pst = (int(splited_line[0]) / 1000 + 125) / 10000
        pen = (int(splited_line[1]) / 1000 + 125) / 10000
        intervals.append([phn, pst, pen])
        # splited_line[-1]!='sp'
        if len(splited_line) == 5:
            current_word = str(index) + '_' + splited_line[-1]
@@ -219,10 +354,10 @@ def alignment(wav_path: str, text: str):
        elif len(splited_line) == 4:
            word2phns[current_word] += ' ' + phn
        i += 1
    return intervals, word2phns
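

# A minimal usage sketch of alignment(); it assumes sox, the HTK tools and
# the English acoustic model configured above are installed, and that
# 'prompt.wav' is a hypothetical recording of the given text.
def _demo_alignment():
    result = alignment('prompt.wav',
                       'for that reason cover is impossible to be given')
    if result is not None:
        intervals, word2phns = result
        # intervals: [['F', 0.0125, 0.0825], ...] with times in seconds
        # word2phns: {'0_FOR': 'F AO1 R', ...}
        print(intervals[:3], list(word2phns.items())[:2])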
def alignment_zh(wav_path: str, text: str):
    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

    #prepare wav and trs files
@@ -230,18 +365,19 @@ def alignment_zh(wav_path, text_string):
        os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
                  '.wav remix -')
    except Exception:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
        unk_words = prep_txt_zh(
            line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict')
        if unk_words:
            print('Error! Please add the following words to dictionary:')
            for unk in unk_words:
                print("Illegal words: ", unk)
    except Exception:
        print('prep_txt error!')
        return None
@@ -250,7 +386,7 @@ def alignment_zh(wav_path, text_string):
        with open(tmpbase + '.txt', 'r') as fid:
            txt = fid.readline()
        prep_mlf(txt, tmpbase)
    except Exception:
        print('prep_mlf error!')
        return None
@@ -258,7 +394,7 @@ def alignment_zh(wav_path, text_string):
    try:
        os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except Exception:
        print('HCopy error!')
        return None
@@ -270,7 +406,7 @@ def alignment_zh(wav_path, text_string):
              + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
              '.plp 2>&1 > /dev/null')
    except Exception:
        print('HVite error!')
        return None
@@ -283,7 +419,7 @@ def alignment_zh(wav_path, text_string):
        lines = fid.readlines()
    i = 2
    intervals = []
    word2phns = {}
    current_word = ''
    index = 0
@@ -293,7 +429,7 @@ def alignment_zh(wav_path, text_string):
        phn = splited_line[2]
        pst = (int(splited_line[0]) / 1000 + 125) / 10000
        pen = (int(splited_line[1]) / 1000 + 125) / 10000
        intervals.append([phn, pst, pen])
        # splited_line[-1]!='sp'
        if len(splited_line) == 5:
            current_word = str(index) + '_' + splited_line[-1]
@@ -302,4 +438,4 @@ def alignment_zh(wav_path, text_string):
        elif len(splited_line) == 4:
            word2phns[current_word] += ' ' + phn
        i += 1
    return intervals, word2phns
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import paddle
from dataset import get_seg_pos
from dataset import phones_masking
from dataset import phones_text_masking
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list
class MLMCollateFn:
"""Functor class of common_collate_fn()"""
def __init__(self,
feats_extract,
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
attention_window: int=0,
pad_speech: bool=False,
seg_emb: bool=False,
text_masking: bool=False):
self.mlm_prob = mlm_prob
self.mean_phn_span = mean_phn_span
self.feats_extract = feats_extract
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
self.attention_window = attention_window
self.pad_speech = pad_speech
self.seg_emb = seg_emb
self.text_masking = text_masking
def __repr__(self):
return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
f"int_pad_value={self.float_pad_value})")
def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
return mlm_collate_fn(
data,
float_pad_value=self.float_pad_value,
int_pad_value=self.int_pad_value,
not_sequence=self.not_sequence,
mlm_prob=self.mlm_prob,
mean_phn_span=self.mean_phn_span,
feats_extract=self.feats_extract,
attention_window=self.attention_window,
pad_speech=self.pad_speech,
seg_emb=self.seg_emb,
text_masking=self.text_masking)
def mlm_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
feats_extract=None,
attention_window: int=0,
pad_speech: bool=False,
seg_emb: bool=False,
text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(not k.endswith("_lens")
for k in data[0]), f"*_lens is reserved: {list(data[0])}"
output = {}
for key in data[0]:
        # Each model that finally accepts these values is responsible for
        # repainting the pad_value to the desired value for each task.
if data[0][key].dtype.kind == "i":
pad_value = int_pad_value
else:
pad_value = float_pad_value
array_list = [d[key] for d in data]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list = [paddle.to_tensor(a) for a in array_list]
# tensor: (Batch, Length, ...)
tensor = pad_list(tensor_list, pad_value)
output[key] = tensor
# lens: (Batch,)
if key not in not_sequence:
lens = paddle.to_tensor(
[d[key].shape[0] for d in data], dtype=paddle.int64)
output[key + "_lens"] = lens
feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
feats = paddle.to_tensor(feats)
feats_lens = paddle.shape(feats)[0]
feats = paddle.unsqueeze(feats, 0)
text = output["text"]
text_lens = output["text_lens"]
align_start = output["align_start"]
align_start_lens = output["align_start_lens"]
align_end = output["align_end"]
max_tlen = max(text_lens)
max_slen = max(feats_lens)
speech_pad = feats[:, :max_slen]
text_pad = text
text_mask = make_non_pad_mask(
text_lens, text_pad, length_dim=1).unsqueeze(-2)
speech_mask = make_non_pad_mask(
feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
span_bdy = None
if 'span_bdy' in output.keys():
span_bdy = output['span_bdy']
    # dual_mask: for mixed Chinese-English, speech and text are masked at the same time
    # ERNIE-SAT masks both when doing cross-lingual tasks
if text_masking:
masked_pos, text_masked_pos = phones_text_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
text_pad=text_pad,
text_mask=text_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
    # for training on pure Chinese or pure English -> A3T does not mask the
    # phonemes, it only masks the speech
    # the main difference between A3T and ERNIE-SAT lies in how masking is done
else:
masked_pos = phones_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
text_masked_pos = paddle.zeros(paddle.shape(text_pad))
output_dict = {}
speech_seg_pos, text_seg_pos = get_seg_pos(
speech_pad=speech_pad,
text_pad=text_pad,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
seg_emb=seg_emb)
output_dict['speech'] = speech_pad
output_dict['text'] = text_pad
output_dict['masked_pos'] = masked_pos
output_dict['text_masked_pos'] = text_masked_pos
output_dict['speech_mask'] = speech_mask
output_dict['text_mask'] = text_mask
output_dict['speech_seg_pos'] = speech_seg_pos
output_dict['text_seg_pos'] = text_seg_pos
output = (uttids, output_dict)
return output
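

# A minimal sketch of the batching convention used above: variable-length
# tensors are padded with pad_list() and a companion "<key>_lens" tensor
# records the original lengths. Toy values only.
def _demo_pad_convention():
    a = paddle.to_tensor([1, 2, 3])
    b = paddle.to_tensor([4, 5])
    batch = pad_list([a, b], 0)      # shape (2, 3): [[1, 2, 3], [4, 5, 0]]
    lens = paddle.to_tensor([3, 2])  # stored under "text_lens" style keys
    print(batch, lens)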
def build_collate_fn(
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
n_mels: int=80,
fmin: int=80,
fmax: int=7600,
mlm_prob: float=0.8,
mean_phn_span: int=8,
train: bool=False,
seg_emb: bool=False,
epoch: int=-1, ):
feats_extract_class = LogMelFBank
feats_extract = feats_extract_class(
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mels=n_mels,
fmin=fmin,
fmax=fmax)
pad_speech = False
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
return MLMCollateFn(
feats_extract=feats_extract,
float_pad_value=0.0,
int_pad_value=0,
mlm_prob=mlm_prob * mlm_prob_factor,
mean_phn_span=mean_phn_span,
pad_speech=pad_speech,
seg_emb=seg_emb)
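

# A minimal sketch of build_collate_fn() on one synthetic utterance; all
# shapes and ids here are made up, a real sample comes from the dumped
# dataset.
def _demo_build_collate_fn():
    sr, hop = 24000, 300
    collate_fn = build_collate_fn(sr=sr, hop_length=hop, mlm_prob=0.8)
    n_frames = 100
    sample = {
        'speech': np.random.randn(n_frames * hop).astype('float32'),
        'text': np.array([1, 2, 3, 4], dtype='int64'),
        'align_start': np.array([0, 25, 50, 75], dtype='int64'),
        'align_end': np.array([25, 50, 75, 100], dtype='int64'),
    }
    uttids, batch = collate_fn([('utt_0', sample)])
    print(uttids, batch['speech'].shape, batch['masked_pos'].shape)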
@@ -4,6 +4,68 @@ import numpy as np
import paddle
# mask phones
def phones_masking(xs_pad: paddle.Tensor,
src_mask: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
mlm_prob: float=0.8,
mean_phn_span: int=8,
span_bdy: paddle.Tensor=None):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
        mlm_prob (float): probability of masking.
        mean_phn_span (int): mean length of the masked phone spans.
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
'''
bz, sent_len, _ = paddle.shape(xs_pad)
masked_pos = paddle.zeros((bz, sent_len))
if mlm_prob == 1.0:
masked_pos += 1
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_idxs = random_spans_noise_mask(
length=length, mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
masked_pos[:, masked_phn_idxs] = 1
else:
for idx in range(bz):
# for inference
if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1
# for training
else:
length = align_start_lens[idx]
if length < 2:
continue
masked_phn_idxs = random_spans_noise_mask(
length=length,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
masked_start = align_start[idx][masked_phn_idxs].tolist()
masked_end = align_end[idx][masked_phn_idxs].tolist()
for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
masked_pos = paddle.cast(masked_pos, 'bool')
return masked_pos
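

# A minimal sketch of phones_masking() on toy tensors: batch of 1, 10 frames
# and 2 aligned phones covering frames [0, 5) and [5, 10). Values are
# illustrative; in training they come from the collate function.
def _demo_phones_masking():
    xs_pad = paddle.randn((1, 10, 80))                # (B, Tmax, D)
    src_mask = paddle.ones((1, 1, 10), dtype='bool')  # (B, 1, Tmax)
    align_start = paddle.to_tensor([[0, 5]])
    align_end = paddle.to_tensor([[5, 10]])
    align_start_lens = paddle.to_tensor([2])
    masked_pos = phones_masking(
        xs_pad=xs_pad,
        src_mask=src_mask,
        align_start=align_start,
        align_end=align_end,
        align_start_lens=align_start_lens,
        mlm_prob=0.8,
        mean_phn_span=8)
    print(masked_pos)  # bool tensor of shape (B, Tmax)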
# mask speech and phones
def phones_text_masking(xs_pad: paddle.Tensor,
                        src_mask: paddle.Tensor,
                        text_pad: paddle.Tensor,
@@ -11,37 +73,56 @@ def phones_text_masking(xs_pad: paddle.Tensor,
                        align_start: paddle.Tensor,
                        align_end: paddle.Tensor,
                        align_start_lens: paddle.Tensor,
                        mlm_prob: float=0.8,
                        mean_phn_span: int=8,
                        span_bdy: paddle.Tensor=None):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
text_pad (paddle.Tensor): input text (B, Tmax2).
text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
        mlm_prob (float): probability of masking.
        mean_phn_span (int): mean length of the masked phone spans.
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
paddle.Tensor[bool]: masked position of input text (B, Tmax2).
'''
    bz, sent_len, _ = paddle.shape(xs_pad)
    masked_pos = paddle.zeros((bz, sent_len))
    _, text_len = paddle.shape(text_pad)
    text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5)
    text_masked_pos = paddle.zeros((bz, text_len))

    if mlm_prob == 1.0:
        masked_pos += 1
    elif mean_phn_span == 0:
        # only speech
        length = sent_len
        mean_phn_span = min(length * mlm_prob // 3, 50)
        masked_phn_idxs = random_spans_noise_mask(
            length=length, mlm_prob=mlm_prob,
            mean_phn_span=mean_phn_span).nonzero()
        masked_pos[:, masked_phn_idxs] = 1
    else:
        for idx in range(bz):
            # for inference
            if span_bdy is not None:
                for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
                    masked_pos[idx, s:e] = 1
            # for training
            else:
                length = align_start_lens[idx]
                if length < 2:
                    continue
                masked_phn_idxs = random_spans_noise_mask(
                    length=length,
                    mlm_prob=mlm_prob,
                    mean_phn_span=mean_phn_span).nonzero()
                unmasked_phn_idxs = list(
                    set(range(length)) - set(masked_phn_idxs[0].tolist()))
                np.random.shuffle(unmasked_phn_idxs)
@@ -58,60 +139,76 @@ def phones_text_masking(xs_pad: paddle.Tensor,
    masked_pos = paddle.cast(masked_pos, 'bool')
    text_masked_pos = paddle.cast(text_masked_pos, 'bool')
    return masked_pos, text_masked_pos
def get_seg_pos(speech_pad: paddle.Tensor,
                text_pad: paddle.Tensor,
                align_start: paddle.Tensor,
                align_end: paddle.Tensor,
                align_start_lens: paddle.Tensor,
                seg_emb: bool=False):
    '''
    Args:
        speech_pad (paddle.Tensor): input speech (B, Tmax, D).
        text_pad (paddle.Tensor): input text (B, Tmax2).
        align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
        align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
        align_start_lens (paddle.Tensor): length of align_start (B, ).
        seg_emb (bool): whether to use segment embedding.
    Returns:
        paddle.Tensor[int]: n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
        eg:
Tensor(shape=[1, 328], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 4 , 4 ,
5 , 5 , 5 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 7 , 7 , 7 , 7 , 7 , 7 , 7 ,
7 , 8 , 8 , 8 , 8 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 10, 10, 10, 10, 10,
10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13,
13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15,
15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17,
17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20,
20, 20, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 23, 23,
23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25,
25, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29,
29, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 32,
32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 35, 35,
35, 35, 35, 35, 35, 35, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38,
38, 38, 0 , 0 ]])
paddle.Tensor[int]: n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
eg:
Tensor(shape=[1, 38], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38]])
'''
    bz, speech_len, _ = paddle.shape(speech_pad)
    _, text_len = paddle.shape(text_pad)

    text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
    speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')

    if not seg_emb:
        return speech_seg_pos, text_seg_pos
    for idx in range(bz):
        align_length = align_start_lens[idx]
        for j in range(align_length):
            s, e = align_start[idx][j], align_end[idx][j]
            speech_seg_pos[idx, s:e] = j + 1
            text_seg_pos[idx, j] = j + 1

    return speech_seg_pos, text_seg_pos
# randomly select the range of speech and text to mask during training
def random_spans_noise_mask(length: int,
                            mlm_prob: float=0.8,
                            mean_phn_span: float=8):
    """This function is a copy of `random_spans_helper
    <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

    Noise mask consisting of random spans of noise tokens.
@@ -126,7 +223,7 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
        noise_density: a float - approximate density of output mask
        mean_noise_span_length: a number
    Returns:
        np.ndarray: a boolean tensor with shape [length]
    """

    orig_length = length
@@ -171,87 +268,3 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
    is_noise = np.equal(span_num % 2, 1)
    return is_noise[:orig_length]
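

# A minimal sketch of random_spans_noise_mask(): for a sequence of length 20
# and mlm_prob 0.8, roughly 80% of the positions come back True, grouped in
# contiguous spans.
def _demo_random_spans_noise_mask():
    mask = random_spans_noise_mask(length=20, mlm_prob=0.8, mean_phn_span=4)
    print(mask.astype(int))  # e.g. [1 1 1 0 1 ...], an np.ndarray of bools
    print(mask.mean())       # close to 0.8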
#!/usr/bin/env python3
import os
import random
from pathlib import Path
from typing import Dict
from typing import List

import librosa
import numpy as np
@@ -15,60 +11,42 @@ import paddle
import soundfile as sf
import torch
from paddle import nn
from sedit_arg_parser import parse_args
from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese

from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from collect_fn import build_collate_fn
from mlm import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2col_text
# from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
random.seed(0)
np.random.seed(0)
def plot_mel_and_vocode_wav(wav_path: str,
                            source_lang: str='english',
                            target_lang: str='english',
                            model_name: str="paddle_checkpoint_en",
                            old_str: str="",
                            new_str: str="",
                            use_pt_vocoder: bool=False,
                            non_autoreg: bool=True):
    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=non_autoreg)

    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
@@ -79,10 +57,10 @@ def plot_mel_and_vocode_wav(uid: str,
            vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
            replaced_wav = vocoder(output_feat).cpu().numpy()
        else:
            replaced_wav = get_voc_out(output_feat)

    elif target_lang == 'chinese':
        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat)

    old_time_bdy = [hop_length * x for x in old_span_bdy]
    new_time_bdy = [hop_length * x for x in new_span_bdy]
@@ -109,125 +87,6 @@ def plot_mel_and_vocode_wav(uid: str,
    return data_dict, old_span_bdy
def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
    vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
    vocoder_file = download_pretrained_model(vocoder_tag)
@@ -236,50 +95,52 @@ def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
    return vocoder


def load_model(model_name: str="paddle_checkpoint_en"):
    config_path = './pretrained_model/{}/config.yaml'.format(model_name)
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
    mlm_model, conf = build_model_from_file(
        config_file=config_path, model_file=model_path)
    return mlm_model, conf
def read_data(uid: str, prefix: os.PathLike):
    # get the text corresponding to uid
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # get the audio path corresponding to uid
    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
    if not os.path.isabs(mfa_wav_path):
        mfa_wav_path = prefix + mfa_wav_path
    return mfa_text, mfa_wav_path


def get_align_data(uid: str, prefix: os.PathLike):
    mfa_path = prefix + "mfa_"
    mfa_text = read_2col_text(mfa_path + 'text')[uid]
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
    return mfa_text, mfa_start, mfa_end, mfa_wav_path
# get the range of mel frames to be masked
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    align_start = np.floor(fs * align_start / hop_length).astype('int')
    align_end = np.floor(fs * align_end / hop_length).astype('int')
    if span_to_repl[0] >= len(mfa_start):
        span_bdy = [align_end[-1], align_end[-1]]
    else:
        span_bdy = [
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
        ]
    return span_bdy, align_start, align_end
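

# A minimal sketch of get_masked_mel_bdy() with toy alignments: two phones at
# 0-0.1 s and 0.1-0.2 s with fs=24000 and hop_length=300, so each phone spans
# 8 mel frames.
def _demo_get_masked_mel_bdy():
    span_bdy, starts, ends = get_masked_mel_bdy(
        mfa_start=[0.0, 0.1],
        mfa_end=[0.1, 0.2],
        fs=24000,
        hop_length=300,
        span_to_repl=[1, 2])  # replace the 2nd phone
    print(span_bdy)           # [8, 16], i.e. frames 8..16 get masked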
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
@@ -317,18 +178,22 @@ def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    return dic
def get_max_idx(dic):
return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]
def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source
    if source_lang == "english":
        intervals, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        intervals, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)

        for key, value in tp_word2phns.items():
@@ -337,51 +202,46 @@ def get_phns_and_spans(wav_path: str,
            tp_word2phns[key] = cur_val

        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."

    for item in intervals:
        old_phns.append(item[0])
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
    else:
        cross_lingual_clone = False

    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]

        if target_lang == "chinese":
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)
        elif target_lang == "english":
            # original sentence
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # cloned sentence
            phns_append, append_word2phns_tmp = words2phns(str_append)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not supported for this language, please check it."

        new_phns = phns_origin + phns_append

        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
            idx, wrd = key.split('_')
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)
    else:
        if source_lang == target_lang and target_lang == "english":
@@ -417,16 +277,17 @@ def get_phns_and_spans(wav_path: str,
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
    new_phns_mid = []
    if is_append:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
@@ -451,47 +312,57 @@ def get_phns_and_spans(wav_path: str,
                    len(old_phns))
                break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add:  [17, 30] "is impossible to"
    '''
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
# the durations obtained from MFA and the durations predicted by fs2's
# duration_predictor may differ; compute a scaling factor between the
# predicted and the ground-truth durations
def duration_adjust_factor(orig_dur: List[int],
                           pred_dur: List[int],
                           phns: List[str]):
    length = 0
    factor_list = []
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
            continue
        else:
            factor_list.append(orig / pred)
    factor_list = np.array(factor_list)
    factor_list.sort()
    if len(factor_list) < 5:
        return 1
    length = 2
    avg = np.average(factor_list[length:-length])
    return avg
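

# A minimal sketch of duration_adjust_factor(): a trimmed mean of per-phone
# ratios between ground-truth and predicted durations; toy values only.
def _demo_duration_adjust_factor():
    orig = [0.10, 0.20, 0.15, 0.30, 0.12, 0.25, 0.18]
    pred = [0.08, 0.22, 0.10, 0.28, 0.10, 0.20, 0.20]
    phns = ['AY1', 'S', 'P', 'IY1', 'CH', 'EH1', 'T']
    factor = duration_adjust_factor(orig_dur=orig, pred_dur=pred, phns=phns)
    print(factor)  # > 1 means the real speech is slower than predicted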
def prepare_features_with_duration(uid: str, def prep_feats_with_dur(wav_path: str,
prefix: str,
wav_path: str,
mlm_model: nn.Layer, mlm_model: nn.Layer,
source_lang: str="English", source_lang: str="English",
target_lang: str="English", target_lang: str="English",
old_str: str="", old_str: str="",
new_str: str="", new_str: str="",
duration_preditor_path: str=None,
sid: str=None,
mask_reconstruct: bool=False, mask_reconstruct: bool=False,
duration_adjust: bool=True, duration_adjust: bool=True,
start_end_sp: bool=False, start_end_sp: bool=False,
train_args=None): fs: int=24000,
wav_org, rate = librosa.load( hop_length: int=300):
wav_path, sr=train_args.feats_extract_conf['fs']) '''
fs = train_args.feats_extract_conf['fs'] Returns:
hop_length = train_args.feats_extract_conf['hop_length'] np.ndarray: new wav, replace the part to be edited in original wav with 0
List[str]: new phones
List[float]: mfa start of new wav
List[float]: mfa end of new wav
List[int]: masked mel boundary of original wav
List[int]: masked mel boundary of new wav
'''
wav_org, _ = librosa.load(wav_path, sr=fs)
mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans( mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
wav_path=wav_path, wav_path=wav_path,
...@@ -503,144 +374,130 @@ def prepare_features_with_duration(uid: str, ...@@ -503,144 +374,130 @@ def prepare_features_with_duration(uid: str,
if start_end_sp: if start_end_sp:
if new_phns[-1] != 'sp': if new_phns[-1] != 'sp':
new_phns = new_phns + ['sp'] new_phns = new_phns + ['sp']
# 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替
if target_lang == "english": if target_lang == "english" or target_lang == "chinese":
old_durations = evaluate_durations(old_phns, target_lang=target_lang) old_durs = evaluate_durations(old_phns, target_lang=source_lang)
elif target_lang == "chinese":
if source_lang == "english":
old_durations = evaluate_durations(
old_phns, target_lang=source_lang)
elif source_lang == "chinese":
old_durations = evaluate_durations(
old_phns, target_lang=source_lang)
else: else:
    assert target_lang == "chinese" or target_lang == "english", \
        "calculate duration_predict is not supported for this language..."

    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]

    if '[MASK]' in new_str:
        new_phns = old_phns
        span_to_add = span_to_repl
        d_factor_left = duration_adjust_factor(
            orig_dur=orig_old_durs[:span_to_repl[0]],
            pred_dur=old_durs[:span_to_repl[0]],
            phns=old_phns[:span_to_repl[0]])
        d_factor_right = duration_adjust_factor(
            orig_dur=orig_old_durs[span_to_repl[1]:],
            pred_dur=old_durs[span_to_repl[1]:],
            phns=old_phns[span_to_repl[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durs_adjusted = [d_factor * i for i in old_durs]
    else:
        if duration_adjust:
            d_factor = duration_adjust_factor(
                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
            d_factor = d_factor * 1.25
        else:
            d_factor = 1

        if target_lang == "english" or target_lang == "chinese":
            new_durs = evaluate_durations(new_phns, target_lang=target_lang)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "calculate duration_predict is not supported for this language..."

        new_durs_adjusted = [d_factor * i for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    # append at the end of the original utterance
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
    # replace inside the original utterance
    else:
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # in the original audio, the span to be edited is replaced with silence
    # whose length is decided by the duration_predictor of fs2
    new_wav = np.concatenate(
        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])

    # 4. get old and new mel spans to be masked
    # [92, 92]
    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
        mfa_start=mfa_start,
        mfa_end=mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_repl)
    # [92, 174]
    # new_mfa_start, new_mfa_end: time-level boundaries -> frame level
    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_add)
    # old_span_bdy and new_span_bdy are frame-level ranges
    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy
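# A minimal sketch of how the time-level boundaries built above become
# frame-level mel boundaries (assumes fs=24000 and hop_length=300, the
# defaults of this file; the phone durations are made up):
import numpy as np

_fs = 24000
_hop_length = 300
_durs = [0.12, 0.08, 0.20]                   # hypothetical durations in seconds
_starts, _ends = [], []
for _d in _durs:
    _starts.append(_ends[-1] if _ends else 0)
    _ends.append(_starts[-1] + _d)
# a boundary at t seconds lands near frame t * fs / hop_length
_frame_ends = [int(np.round(_t * _fs / _hop_length)) for _t in _ends]
print(_frame_ends)  # [10, 16, 32]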
def prep_feats(mlm_model: nn.Layer,
               wav_path: str,
               source_lang: str="english",
               target_lang: str="english",
               old_str: str="",
               new_str: str="",
               duration_adjust: bool=True,
               start_end_sp: bool=False,
               mask_reconstruct: bool=False,
               fs: int=24000,
               hop_length: int=300,
               token_list: List[str]=[]):
    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        fs=fs,
        hop_length=hop_length)

    token_to_id = {item: i for i, item in enumerate(token_list)}
    text = np.array(
        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
    span_bdy = np.array(new_span_bdy)

    batch = [('1', {
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "text": text,
        "span_bdy": span_bdy
    })]

...@@ -648,375 +505,135 @@ def prep_feats(mlm_model: nn.Layer,
    return batch, old_span_bdy, new_span_bdy
def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
                      fs: int=24000,
                      hop_length: int=300,
                      token_list: List[str]=[]):
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)

    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    if 0 in output[0].shape and 0 not in output[-1].shape:
        output_feat = paddle.concat(
            output[1:-1] + [output[-1].squeeze()], axis=0)
    elif 0 not in output[0].shape and 0 in output[-1].shape:
        output_feat = paddle.concat(
            [output[0].squeeze()] + output[1:-1], axis=0)
    elif 0 in output[0].shape and 0 in output[-1].shape:
        output_feat = paddle.concat(output[1:-1], axis=0)
    else:
        output_feat = paddle.concat(
            [output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)],
            axis=0)

    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
def get_mlm_output(wav_path: str,
                   model_name: str="paddle_checkpoint_en",
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
    mlm_model, train_conf = load_model(model_name)
    mlm_model.eval()

    collate_fn = build_collate_fn(
        sr=train_conf.feats_extract_conf['fs'],
        n_fft=train_conf.feats_extract_conf['n_fft'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        win_length=train_conf.feats_extract_conf['win_length'],
        n_mels=train_conf.feats_extract_conf['n_mels'],
        fmin=train_conf.feats_extract_conf['fmin'],
        fmax=train_conf.feats_extract_conf['fmax'],
        mlm_prob=train_conf['mlm_prob'],
        mean_phn_span=train_conf['mean_phn_span'],
        train=False,
        seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm')

    return decode_with_model(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=train_conf.feats_extract_conf['fs'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        token_list=train_conf.token_list)
def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             use_pt_vocoder: bool=False,
             prefix: os.PathLike="./prompt/dev/",
             model_name: str="paddle_checkpoint_en",
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):

    # get the original text and the path of the original wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'edit':
        new_str = new_str
    elif task_name == 'synthesize':
        new_str = old_str + new_str
    else:
        new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])

    print('new_str is ', new_str)

    results_dict, old_span = plot_mel_and_vocode_wav(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_pt_vocoder=use_pt_vocoder)
    return results_dict
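# A minimal sketch of driving evaluate() for the speech-editing task; the
# uid and new_str values below are hypothetical stand-ins for the ones the
# run_*.sh scripts pass:
if __name__ == '__main__':
    _results = evaluate(
        uid='p243_new',
        source_lang='english',
        target_lang='english',
        model_name='paddle_checkpoint_en',
        new_str='for that reason cover should not be given.',
        task_name='edit')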
......
import argparse
import os
import sys
from typing import Dict
from typing import List
from typing import Optional
...@@ -20,17 +17,18 @@ for dir_name in os.listdir(pypath):
    if os.path.isdir(dir_path):
        sys.path.append(dir_path)

from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
...@@ -39,65 +37,10 @@ from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredCo
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.layer_norm import LayerNorm
from yacs.config import CfgNode
# MLM -> Mask Language Model
class mySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._sub_layers.values():
...@@ -108,12 +51,8 @@ class mySequential(nn.Sequential):
        return inputs


class MaskInputLayer(nn.Layer):
    def __init__(self, out_features: int) -> None:
        super().__init__()
        self.mask_feature = paddle.create_parameter(
            shape=(1, 1, out_features),
...@@ -121,109 +60,14 @@ class MaskInputLayer(nn.Layer):
            default_initializer=paddle.nn.initializer.Assign(
                paddle.normal(shape=(1, 1, out_features))))

    def forward(self, input: paddle.Tensor,
                masked_pos: paddle.Tensor=None) -> paddle.Tensor:
        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
        return masked_input
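# A minimal sketch of what MaskInputLayer does, on dummy shapes: frames
# where masked_pos is True are swapped for the learnable mask embedding,
# the rest pass through unchanged.
import paddle

_layer = MaskInputLayer(out_features=4)
_x = paddle.randn([1, 3, 4])                         # (B, T, D) dummy frames
_masked_pos = paddle.to_tensor([[False, True, False]])
_y = _layer(_x, _masked_pos)
# _y[:, 1, :] equals the broadcast mask_feature; rows 0 and 2 equal _x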
class MLMEncoder(nn.Layer):
    """Conformer encoder module.
...@@ -253,47 +97,42 @@ class MLMEncoder(nn.Layer):
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
    """

    def __init__(self,
                 idim: int,
                 vocab_size: int=0,
                 pre_speech_layer: int=0,
                 attention_dim: int=256,
                 attention_heads: int=4,
                 linear_units: int=2048,
                 num_blocks: int=6,
                 dropout_rate: float=0.1,
                 positional_dropout_rate: float=0.1,
                 attention_dropout_rate: float=0.0,
                 input_layer: str="conv2d",
                 normalize_before: bool=True,
                 concat_after: bool=False,
                 positionwise_layer_type: str="linear",
                 positionwise_conv_kernel_size: int=1,
                 macaron_style: bool=False,
                 pos_enc_layer_type: str="abs_pos",
                 pos_enc_class=None,
                 selfattention_layer_type: str="selfattn",
                 activation_type: str="swish",
                 use_cnn_module: bool=False,
                 zero_triu: bool=False,
                 cnn_module_kernel: int=31,
                 padding_idx: int=-1,
                 stochastic_depth_rate: float=0.0,
                 text_masking: bool=False):
        """Construct an Encoder object."""
        super().__init__()
        self._output_size = attention_dim
        self.text_masking = text_masking
        if self.text_masking:
            self.text_masking_layer = MaskInputLayer(attention_dim)
        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
...@@ -330,7 +169,7 @@ class MLMEncoder(nn.Layer):
        elif input_layer == "mlm":
            self.segment_emb = None
            self.speech_embed = mySequential(
                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
...@@ -343,7 +182,7 @@ class MLMEncoder(nn.Layer):
            self.segment_emb = nn.Embedding(
                500, attention_dim, padding_idx=padding_idx)
            self.speech_embed = mySequential(
                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
...@@ -365,7 +204,6 @@ class MLMEncoder(nn.Layer):
        # self-attention module definition
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, )
...@@ -375,8 +213,6 @@ class MLMEncoder(nn.Layer):
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
...@@ -436,49 +272,38 @@ class MLMEncoder(nn.Layer):
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor=None,
                text_mask: paddle.Tensor=None,
                speech_seg_pos: paddle.Tensor=None,
                text_seg_pos: paddle.Tensor=None):
        """Encode input sequence.
        """
        if masked_pos is not None:
            speech = self.speech_embed(speech, masked_pos)
        else:
            speech = self.speech_embed(speech)
        if text is not None:
            text = self.text_embed(text)
        if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
            speech_seg_emb = self.segment_emb(speech_seg_pos)
            text_seg_emb = self.segment_emb(text_seg_pos)
            text = (text[0] + text_seg_emb, text[1])
            speech = (speech[0] + speech_seg_emb, speech[1])
        if self.pre_speech_encoders:
            speech, _ = self.pre_speech_encoders(speech, speech_mask)

        if text is not None:
            xs = paddle.concat([speech[0], text[0]], axis=1)
            xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
            masks = paddle.concat([speech_mask, text_mask], axis=-1)
        else:
            xs = speech[0]
            xs_pos_emb = speech[1]
            masks = speech_mask

        xs, masks = self.encoders((xs, xs_pos_emb), masks)
...@@ -492,7 +317,7 @@ class MLMEncoder(nn.Layer):


class MLMDecoder(MLMEncoder):
    def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
        """Encode input sequence.

        Args:
...@@ -504,51 +329,19 @@ class MLMDecoder(MLMEncoder):
            paddle.Tensor: Mask tensor (#batch, time).
        """
        xs = self.embed(xs)
        xs, masks = self.encoders(xs, masks)

        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks
# encoder and decoder are nn.Layer instances here, not strings
class MLM(nn.Layer):
    def __init__(self,
                 token_list: Union[Tuple[str, ...], List[str]],
                 odim: int,
...@@ -557,44 +350,15 @@ class MLM(nn.Layer):
                 postnet_layers: int=0,
                 postnet_chans: int=0,
                 postnet_filts: int=0,
                 text_masking: bool=False):
        super().__init__()
        self.odim = odim
        self.token_list = token_list.copy()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = encoder.text_embed[0]._num_embeddings

        if self.decoder is None or not (hasattr(self.decoder,
                                                'output_layer') and
                                        self.decoder.output_layer is not None):
...@@ -606,15 +370,9 @@ class MLM(nn.Layer):
                self.encoder.text_embed[0]._embedding_dim,
                self.vocab_size,
                weight_attr=self.encoder.text_embed[0]._weight_attr)
        else:
            self.text_sfc = None

        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=self.encoder._output_size,
            odim=odim,
...@@ -624,119 +382,77 @@ class MLM(nn.Layer):
            use_batch_norm=True,
            dropout_rate=0.5, ))
    def inference(
            self,
            speech: paddle.Tensor,
            text: paddle.Tensor,
            masked_pos: paddle.Tensor,
            speech_mask: paddle.Tensor,
            text_mask: paddle.Tensor,
            speech_seg_pos: paddle.Tensor,
            text_seg_pos: paddle.Tensor,
            span_bdy: List[int],
            use_teacher_forcing: bool=False, ) -> List[paddle.Tensor]:
'''
Args:
speech (paddle.Tensor): input speech (B, Tmax, D).
text (paddle.Tensor): input text (B, Tmax2).
masked_pos (paddle.Tensor): masked position of input speech (B, Tmax)
speech_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
span_bdy (List[int]): masked mel boundary of input speech (2,)
use_teacher_forcing (bool): whether to use teacher forcing
Returns:
List[Tensor]:
eg:
[Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
'''
        outs = [speech[:, :span_bdy[0]]]
        z_cache = None
        if use_teacher_forcing:
            before_outs, zs, *_ = self.forward(
                speech=speech,
                text=text,
                masked_pos=masked_pos,
                speech_mask=speech_mask,
                text_mask=text_mask,
                speech_seg_pos=speech_seg_pos,
                text_seg_pos=text_seg_pos)
            if zs is None:
                zs = before_outs
            outs += [zs[0][span_bdy[0]:span_bdy[1]]]
            outs += [speech[:, span_bdy[1]:]]
            return outs
        return None
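# A minimal sketch, using the dummy shapes from the docstring above: the
# list returned by inference() is [prefix frames, generated frames, suffix
# frames]; callers such as decode_with_model() concatenate it back into a
# single mel spectrogram.
import paddle

_outs = [paddle.zeros([1, 181, 80]), paddle.zeros([80, 80]),
         paddle.zeros([1, 67, 80])]
_mel = paddle.concat(
    [_outs[0].squeeze(0), _outs[1], _outs[2].squeeze(0)], axis=0)
print(_mel.shape)  # [328, 80] == 181 + 80 + 67 frames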
class MLMEncAsDecoder(MLM):
    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor,
                text_mask: paddle.Tensor,
                speech_seg_pos: paddle.Tensor,
                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, h_masks = self.encoder(
            speech=speech,
            text=text,
            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_seg_pos=speech_seg_pos,
            text_seg_pos=text_seg_pos)
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
...@@ -749,53 +465,35 @@ class MLMEncAsDecoder(MLM):
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, None
class MLMDualMaksing(MLM):
    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor,
                text_mask: paddle.Tensor,
                speech_seg_pos: paddle.Tensor,
                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, h_masks = self.encoder(
            speech=speech,
            text=text,
            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_seg_pos=speech_seg_pos,
            text_seg_pos=text_seg_pos)
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        if self.text_sfc:
            text_hiddent_states = zs[:, paddle.shape(speech)[1]:, :]
            text_outs = paddle.reshape(
                self.text_sfc(text_hiddent_states),
                (paddle.shape(text_hiddent_states)[0], -1, self.vocab_size))
...@@ -811,27 +509,25 @@ class MLMDualMaksing(MLM):
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, text_outs
def build_model_from_file(config_file, model_file):
    state_dict = paddle.load(model_file)
    model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
        else MLMEncAsDecoder

    # build the model
    with open(config_file) as f:
        conf = CfgNode(yaml.safe_load(f))
    model = build_model(conf, model_class)
    model.set_state_dict(state_dict)
    return model, conf
# select encoder and decoder here
def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
...@@ -842,9 +538,8 @@ def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")

    vocab_size = len(token_list)
    odim = 80
    pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding
...@@ -857,17 +552,8 @@ def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
    if conformer_rel_pos_type == "legacy":
        if conformer_pos_enc_layer_type == "rel_pos":
            conformer_pos_enc_layer_type = "legacy_rel_pos"
        if conformer_self_attn_layer_type == "rel_selfattn":
            conformer_self_attn_layer_type = "legacy_rel_selfattn"
    elif conformer_rel_pos_type == "latest":
        assert conformer_pos_enc_layer_type != "legacy_rel_pos"
        assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
......
import paddle
from paddle import nn
class MLMLoss(nn.Layer):
    def __init__(self,
                 odim: int,
                 vocab_size: int=0,
                 lsm_weight: float=0.1,
                 ignore_id: int=-1,
                 text_masking: bool=False):
        super().__init__()
        if text_masking:
            self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
        if lsm_weight > 50:
            self.l1_loss_func = nn.MSELoss()
        else:
            self.l1_loss_func = nn.L1Loss(reduction='none')
        self.text_masking = text_masking
        # forward() below reads self.odim and self.vocab_size
        self.odim = odim
        self.vocab_size = vocab_size
def forward(self,
speech: paddle.Tensor,
before_outs: paddle.Tensor,
after_outs: paddle.Tensor,
masked_pos: paddle.Tensor,
text: paddle.Tensor=None,
text_outs: paddle.Tensor=None,
text_masked_pos: paddle.Tensor=None):
xs_pad = speech
mlm_loss_pos = masked_pos > 0
loss = paddle.sum(
self.l1_loss_func(
paddle.reshape(before_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
if after_outs is not None:
loss += paddle.sum(
self.l1_loss_func(
paddle.reshape(after_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
loss_mlm = paddle.sum((loss * paddle.reshape(
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
if self.text_masking:
loss_text = paddle.sum((self.text_mlm_loss(
paddle.reshape(text_outs, (-1, self.vocab_size)),
paddle.reshape(text, (-1))) * paddle.reshape(
text_masked_pos,
(-1)))) / paddle.sum((text_masked_pos) + 1e-10)
return loss_mlm, loss_text
return loss_mlm
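# A minimal sketch of exercising MLMLoss on dummy tensors; odim=80 matches
# the mel dimension used by build_model(), and the batch shapes are made up.
_criterion = MLMLoss(odim=80, text_masking=False)
_speech = paddle.randn([2, 100, 80])          # target mels (B, T, odim)
_before = paddle.randn([2, 100, 80])          # predictions before the postnet
_masked_pos = paddle.randint(0, 2, [2, 100])  # 1 where a frame was masked
_loss = _criterion(speech=_speech, before_outs=_before, after_outs=None,
                   masked_pos=_masked_pos)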
...@@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

        return self.forward_attention(v, scores, mask)
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
self.pos_bias_v = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x(Tensor): Input tensor (batch, head, time1, time2).
Returns:
            Tensor: Output tensor.
"""
b, h, t1, t2 = paddle.shape(x)
zero_pad = paddle.zeros((b, h, t1, 1))
x_padded = paddle.concat([zero_pad, x], axis=-1)
x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
# only keep the positions from 0 to time2
x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query(Tensor): Query tensor (#batch, time1, size).
key(Tensor): Key tensor (#batch, time2, size).
value(Tensor): Value tensor (#batch, time2, size).
pos_emb(Tensor): Positional embedding tensor (#batch, time1, size).
mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
Returns:
Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
# (batch, time1, head, d_k)
q = paddle.transpose(q, [0, 2, 1, 3])
n_batch_pos = paddle.shape(pos_emb)[0]
p = paddle.reshape(
self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
# (batch, head, time1, d_k)
p = paddle.transpose(p, [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3])
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = paddle.matmul(q_with_bias_u,
paddle.transpose(k, [0, 1, 3, 2]))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = paddle.matmul(q_with_bias_v,
paddle.transpose(p, [0, 1, 3, 2]))
matrix_bd = self.rel_shift(matrix_bd)
# (batch, head, time1, time2)
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
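For readers unfamiliar with the relative-attention interface, a shape-only usage sketch follows. The instantiation and call match the signatures above; the tensor values and dimensions are illustrative assumptions, not part of this diff:

```python
import paddle

# hypothetical dimensions for illustration
n_batch, time1, n_feat, n_head = 2, 50, 256, 4
attn = LegacyRelPositionMultiHeadedAttention(
    n_head=n_head, n_feat=n_feat, dropout_rate=0.1)

x = paddle.randn([n_batch, time1, n_feat])   # self-attention: q = k = v
pos_emb = paddle.randn([1, time1, n_feat])   # legacy encoding: same length as x
mask = paddle.ones([n_batch, 1, time1], dtype='bool')

out = attn(x, x, x, pos_emb, mask)           # -> (n_batch, time1, n_feat)
```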
@@ -185,3 +185,63 @@ class RelPositionalEncoding(nn.Layer):
        pe_size = paddle.shape(self.pe)
        pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ]
        return self.dropout(x), self.dropout(pos_emb)
class LegacyRelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
        """
        Args:
            d_model (int): Embedding dimension.
            dropout_rate (float): Dropout rate.
            max_len (int, optional): Maximum input length. Defaults to 5000.
        """
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
                return
        pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
        if self.reverse:
            position = paddle.arange(
                paddle.shape(x)[1] - 1, -1, -1.0,
                dtype=paddle.float32).unsqueeze(1)
        else:
            position = paddle.arange(
                0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
            -(math.log(10000.0) / self.d_model))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe

    def forward(self, x: paddle.Tensor):
        """Compute positional encoding.

        Args:
            x (paddle.Tensor): Input tensor (batch, time, `*`).
        Returns:
            paddle.Tensor: Encoded tensor (batch, time, `*`).
            paddle.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, :paddle.shape(x)[1]]
        return self.dropout(x), self.dropout(pos_emb)
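A small sketch of the behavioral contract this class preserves: the legacy module returns a positional embedding with the same time length as the input, whereas the non-legacy `RelPositionalEncoding` above slices a symmetric 2*T-1 window out of its table. Values below are illustrative only:

```python
import paddle

pos_enc = LegacyRelPositionalEncoding(d_model=256, dropout_rate=0.1)
x = paddle.randn([2, 50, 256])

x_scaled, pos_emb = pos_enc(x)
print(x_scaled.shape)  # [2, 50, 256], input scaled by sqrt(d_model)
print(pos_emb.shape)   # [1, 50, 256], same time length as the input
```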
@@ -5,7 +5,7 @@ from typing import List
 from typing import Union

-def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
+def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
     """Read a text file having 2 columns as a dict object.

     Examples:
@@ -13,7 +13,7 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
         key1 /some/path/a.wav
         key2 /some/path/b.wav
-        >>> read_2column_text('wav.scp')
+        >>> read_2col_text('wav.scp')
         {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
     """
     ...
@@ -65,12 +65,6 @@ def parse_args():
         help="mean and standard deviation used to normalize spectrogram when training voc."
     )
     # other
-    parser.add_argument(
-        '--lang',
-        type=str,
-        default='en',
-        help='Choose model language. zh or en')
     parser.add_argument(
         "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
     # parser.add_argument("--test_metadata", type=str, help="test metadata.")
     ...
@@ -32,7 +32,6 @@ model_alias = {
     "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
 }

 def is_chinese(ch):
     if u'\u4e00' <= ch <= u'\u9fff':
         return True
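The hunk context cuts `is_chinese` off, but its test is a plain code-point range check over the CJK Unified Ideographs block. A quick illustration of what that range does and does not match (illustrative, not part of the diff):

```python
for ch in ['好', 'a', '。']:
    print(ch, u'\u4e00' <= ch <= u'\u9fff')
# 好 True   (U+597D, inside the CJK Unified Ideographs block)
# a False
# 。 False  (U+3002, CJK punctuation lies outside U+4E00..U+9FFF)
```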
@@ -55,12 +54,10 @@ def build_vocoder_from_file(
         raise ValueError(f"{vocoder_file} is not supported format.")

-def get_voc_out(mel, target_lang: str="chinese"):
+def get_voc_out(mel):
     # vocoder
     args = parse_args()
-    assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..."
     # print("current vocoder: ", args.voc)
     with open(args.voc_config) as f:
         voc_config = CfgNode(yaml.safe_load(f))
@@ -167,19 +164,23 @@ def get_voc_inference(
     return voc_inference

-def evaluate_durations(phns: List[str],
-                       target_lang: str="chinese",
-                       fs: int=24000,
-                       hop_length: int=300):
+def evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300):
     args = parse_args()
     if target_lang == 'english':
-        args.lang = 'en'
+        args.am = "fastspeech2_ljspeech"
+        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
+        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
+        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
+        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
     elif target_lang == 'chinese':
-        args.lang = 'zh'
+        args.am = "fastspeech2_csmsc"
+        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
+        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
+        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
+        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
+    # args = parser.parse_args(args=[])
     if args.ngpu == 0:
         paddle.set_device("cpu")
     elif args.ngpu > 0:
@@ -187,8 +188,6 @@ def evaluate_durations(phns: List[str],
     else:
         print("ngpu should >= 0 !")

-    assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..."
-
     # Init body.
     with open(args.am_config) as f:
         am_config = CfgNode(yaml.safe_load(f))
@@ -203,21 +202,19 @@ def evaluate_durations(phns: List[str],
         speaker_dict=args.speaker_dict,
         return_am=True)

-    torch_phns = phns
     vocab_phones = {}
     with open(args.phones_dict, "r") as f:
         phn_id = [line.strip().split() for line in f.readlines()]
     for tone, id in phn_id:
         vocab_phones[tone] = int(id)
     vocab_size = len(vocab_phones)
-    phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]
+    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
     phone_ids = [vocab_phones[item] for item in phonemes]
     phone_ids_new = phone_ids
     phone_ids_new.append(vocab_size - 1)
     phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64))
-    normalized_mel, d_outs, p_outs, e_outs = am.inference(
-        phone_ids_new, spk_id=None, spk_emb=None)
+    _, d_outs, _, _ = am.inference(phone_ids_new, spk_id=None, spk_emb=None)
     pre_d_outs = d_outs
     phoneme_durations_new = pre_d_outs * hop_length / fs
     phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
     ...
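The tail of `evaluate_durations` converts FastSpeech2's per-phoneme duration predictions from frames to seconds: `hop_length / fs` is the frame shift, i.e. 300 / 24000 = 0.0125 s at 24 kHz. A worked example with hypothetical predictions:

```python
# hypothetical per-phoneme frame counts predicted by the acoustic model
d_outs = [8, 12, 20]
fs, hop_length = 24000, 300

durations_sec = [d * hop_length / fs for d in d_outs]
print(durations_sec)  # [0.1, 0.15, 0.25]
```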