Commit 76b654cb authored by 小湉湉

format ernie sat

Parent 6bcb213c
@@ -113,8 +113,8 @@ prompt/dev
8. `--uid` id of the specific prompt speech
9. `--new_str` the input text (temporarily fixed to specific texts in this release)
10. `--prefix` path of the transcript and phoneme files that go with the prompt audio
-11. `--source_language`, the source language
-12. `--target_language`, the target language
+11. `--source_lang`, the source language
+12. `--target_lang`, the target language
13. `--output_name`, name of the synthesized audio file
14. `--task_name`, task name, one of: speech editing, personalized speech synthesis, cross-lingual speech synthesis
15. `--use_pt_vocoder`, whether to use the torch vocoder for English; defaults to False, in which case the paddle vocoder is used for English
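
For context (outside this diff), the options above map onto the run scripts edited later in this commit. A hypothetical end-to-end invocation as a Python sketch; the checkpoints are assumed to be downloaded, and the `--task_name` value is an assumption, the other values mirror the speech-editing run script below:

    import subprocess

    # Hypothetical invocation of inference.py with the renamed flags.
    subprocess.run(
        [
            "python", "inference.py",
            "--task_name=edit",                 # assumed value
            "--uid=p243_new",
            "--new_str=for that reason cover is impossible to be given.",
            "--prefix=./prompt/dev/",
            "--source_lang=english",
            "--target_lang=english",
            "--output_name=pred_edit.wav",
            "--use_pt_vocoder=False",
            "--voc=pwgan_aishell3",
        ],
        check=True)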
......
#!/usr/bin/env python
""" Usage:
-      align_english.py wavfile trsfile outwordfile outphonefile
+      align.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
@@ -9,12 +9,45 @@ import sys
from tqdm import tqdm

PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
-MODEL_DIR = 'tools/aligner/english'
+MODEL_DIR_EN = 'tools/aligner/english'
+MODEL_DIR_ZH = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'


-def prep_txt(line, tmpbase, dictfile):
+def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
+    words = []
+    line = line.strip()
+    for pun in [
+            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
+            u'。', u':', u';', u'!', u'?', u'(', u')'
+    ]:
+        line = line.replace(pun, ' ')
+    for wrd in line.split():
+        if (wrd[-1] == '-'):
+            wrd = wrd[:-1]
+        if (wrd[0] == "'"):
+            wrd = wrd[1:]
+        if wrd:
+            words.append(wrd)
+    ds = set([])
+    with open(dictfile, 'r') as fid:
+        for line in fid:
+            ds.add(line.split()[0])
+    unk_words = set([])
+    with open(tmpbase + '.txt', 'w') as fwid:
+        for wrd in words:
+            if (wrd not in ds):
+                unk_words.add(wrd)
+            fwid.write(wrd + ' ')
+        fwid.write('\n')
+    return unk_words


+def prep_txt_en(line: str, tmpbase, dictfile):
    words = []
@@ -97,7 +130,7 @@ def prep_txt(line, tmpbase, dictfile):
    fw.close()


-def prep_mlf(txt, tmpbase):
+def prep_mlf(txt: str, tmpbase: str):
    with open(tmpbase + '.mlf', 'w') as fwid:
        fwid.write('#!MLF!#\n')
@@ -110,7 +143,55 @@ def prep_mlf(txt, tmpbase):
        fwid.write('.\n')


-def gen_res(tmpbase, outfile1, outfile2):
+def _get_user():
+    return os.path.expanduser('~').split("/")[-1]
+
+
+def alignment(wav_path: str, text: str):
+    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
+
+    #prepare wav and trs files
+    try:
+        os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
+    except:
+        print('sox error!')
+        return None
+
+    #prepare clean_transcript file
+    try:
+        prep_txt_en(text, tmpbase, MODEL_DIR_EN + '/dict')
+    except:
+        print('prep_txt error!')
+        return None
+
+    #prepare mlf file
+    try:
+        with open(tmpbase + '.txt', 'r') as fid:
+            txt = fid.readline()
+        prep_mlf(txt, tmpbase)
+    except:
+        print('prep_mlf error!')
+        return None
+
+    #prepare scp
+    try:
+        os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
+                  '.wav' + ' ' + tmpbase + '.plp')
+    except:
+        print('HCopy error!')
+        return None
+
+    #run alignment
+    try:
+        os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
+                  '.mlf -H ' + MODEL_DIR_EN + '/16000/macros -H ' + MODEL_DIR_EN
+                  + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
+                  '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
+                  '.plp 2>&1 > /dev/null')
+    except:
+        print('HVite error!')
+        return None

    with open(tmpbase + '.txt', 'r') as fid:
        words = fid.readline().strip().split()
    words = txt.strip().split()
@@ -119,59 +200,47 @@ def gen_res(tmpbase, outfile1, outfile2):
    with open(tmpbase + '.aligned', 'r') as fid:
        lines = fid.readlines()
    i = 2
-    times1 = []
    times2 = []
+    word2phns = {}
+    current_word = ''
+    index = 0
    while (i < len(lines)):
-        if (len(lines[i].split()) >= 4) and (
-                lines[i].split()[0] != lines[i].split()[1]):
-            phn = lines[i].split()[2]
-            pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
-            pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
+        splited_line = lines[i].strip().split()
+        if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
+            phn = splited_line[2]
+            pst = (int(splited_line[0]) / 1000 + 125) / 10000
+            pen = (int(splited_line[1]) / 1000 + 125) / 10000
            times2.append([phn, pst, pen])
-        if (len(lines[i].split()) == 5):
-            if (lines[i].split()[0] != lines[i].split()[1]):
-                wrd = lines[i].split()[-1].strip()
-                st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
-                j = i + 1
-                while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
-                    j += 1
-                en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
-                times1.append([wrd, st, en])
+            # splited_line[-1]!='sp'
+            if len(splited_line) == 5:
+                current_word = str(index) + '_' + splited_line[-1]
+                word2phns[current_word] = phn
+                index += 1
+            elif len(splited_line) == 4:
+                word2phns[current_word] += ' ' + phn
        i += 1
+    return times2, word2phns

-    with open(outfile1, 'w') as fwid:
-        for item in times1:
-            if (item[0] == 'sp'):
-                fwid.write(str(item[1]) + ' ' + str(item[2]) + ' SIL\n')
-            else:
-                wrd = words.pop()
-                fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + wrd + '\n')
-    if words:
-        print('not matched::' + alignfile)
-        sys.exit(1)
-    with open(outfile2, 'w') as fwid:
-        for item in times2:
-            fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')


-def _get_user():
-    return os.path.expanduser('~').split("/")[-1]


-def alignment(wav_path, text_string):
+def alignment_zh(wav_path, text_string):
    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

    #prepare wav and trs files
    try:
-        os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
+        os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
+                  '.wav remix -')
    except:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
-        prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
+        unk_words = prep_txt_zh(text_string, tmpbase, MODEL_DIR_ZH + '/dict')
+        if unk_words:
+            print('Error! Please add the following words to dictionary:')
+            for unk in unk_words:
+                print("非法words: ", unk)
    except:
        print('prep_txt error!')
        return None
@@ -187,7 +256,7 @@ def alignment(wav_path, text_string):
    #prepare scp
    try:
-        os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
+        os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except:
        print('HCopy error!')
@@ -196,10 +265,11 @@ def alignment(wav_path, text_string):
    #run alignment
    try:
-        os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
-                  '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
-                  '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
-                  '.dict ' + MODEL_DIR + '/monophones ' + tmpbase +
-                  '.plp 2>&1 > /dev/null')
+        os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
+                  '.mlf -H ' + MODEL_DIR_ZH + '/16000/macros -H ' + MODEL_DIR_ZH
+                  + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR_ZH
+                  + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
+                  '.plp 2>&1 > /dev/null')
    except:
        print('HVite error!')
        return None
@@ -211,6 +281,7 @@ def alignment(wav_path, text_string):
    with open(tmpbase + '.aligned', 'r') as fid:
        lines = fid.readlines()
+    i = 2
    times2 = []
    word2phns = {}
......
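For orientation (not part of the commit): after this merge both aligners live in one module and return the same structure. A hypothetical usage sketch, assuming sox, HTK and the aligner models under tools/ are installed; the module name follows the updated usage string, and the concrete phones and times are illustrative:

    # Hypothetical usage; requires sox and HTK, so it will not run standalone.
    from align import alignment, alignment_zh

    times2, word2phns = alignment('demo_en.wav', 'hello world')
    # times2:    [[phone, start_sec, end_sec], ...], e.g. ['HH', 0.0125, 0.0625]
    # word2phns: {'0_HELLO': 'HH AH L OW', '1_WORLD': 'W ER L D'}

    times2_zh, word2phns_zh = alignment_zh('demo_zh.wav', '今天 天气 很 好')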
-#!/usr/bin/env python
-""" Usage:
-      align_mandarin.py wavfile trsfile outwordfile putphonefile
-"""
-import multiprocessing as mp
-import os
-import sys
-
-from tqdm import tqdm
-
-MODEL_DIR = 'tools/aligner/mandarin'
-HVITE = 'tools/htk/HTKTools/HVite'
-HCOPY = 'tools/htk/HTKTools/HCopy'
-
-
-def prep_txt(line, tmpbase, dictfile):
-    words = []
-    line = line.strip()
-    for pun in [
-            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
-            u'。', u':', u';', u'!', u'?', u'(', u')'
-    ]:
-        line = line.replace(pun, ' ')
-    for wrd in line.split():
-        if (wrd[-1] == '-'):
-            wrd = wrd[:-1]
-        if (wrd[0] == "'"):
-            wrd = wrd[1:]
-        if wrd:
-            words.append(wrd)
-    ds = set([])
-    with open(dictfile, 'r') as fid:
-        for line in fid:
-            ds.add(line.split()[0])
-    unk_words = set([])
-    with open(tmpbase + '.txt', 'w') as fwid:
-        for wrd in words:
-            if (wrd not in ds):
-                unk_words.add(wrd)
-            fwid.write(wrd + ' ')
-        fwid.write('\n')
-    return unk_words
-
-
-def prep_mlf(txt, tmpbase):
-    with open(tmpbase + '.mlf', 'w') as fwid:
-        fwid.write('#!MLF!#\n')
-        fwid.write('"' + tmpbase + '.lab"\n')
-        fwid.write('sp\n')
-        wrds = txt.split()
-        for wrd in wrds:
-            fwid.write(wrd.upper() + '\n')
-            fwid.write('sp\n')
-        fwid.write('.\n')
-
-
-def gen_res(tmpbase, outfile1, outfile2):
-    with open(tmpbase + '.txt', 'r') as fid:
-        words = fid.readline().strip().split()
-    words = txt.strip().split()
-    words.reverse()
-    with open(tmpbase + '.aligned', 'r') as fid:
-        lines = fid.readlines()
-    i = 2
-    times1 = []
-    times2 = []
-    while (i < len(lines)):
-        if (len(lines[i].split()) >= 4) and (
-                lines[i].split()[0] != lines[i].split()[1]):
-            phn = lines[i].split()[2]
-            pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
-            pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
-            times2.append([phn, pst, pen])
-        if (len(lines[i].split()) == 5):
-            if (lines[i].split()[0] != lines[i].split()[1]):
-                wrd = lines[i].split()[-1].strip()
-                st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
-                j = i + 1
-                while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
-                    j += 1
-                en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
-                times1.append([wrd, st, en])
-        i += 1
-
-    with open(outfile1, 'w') as fwid:
-        for item in times1:
-            if (item[0] == 'sp'):
-                fwid.write(str(item[1]) + ' ' + str(item[2]) + ' SIL\n')
-            else:
-                wrd = words.pop()
-                fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + wrd + '\n')
-    if words:
-        print('not matched::' + alignfile)
-        sys.exit(1)
-    with open(outfile2, 'w') as fwid:
-        for item in times2:
-            fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
-
-
-def alignment_zh(wav_path, text_string):
-    tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid())
-
-    #prepare wav and trs files
-    try:
-        os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
-                  '.wav remix -')
-    except:
-        print('sox error!')
-        return None
-
-    #prepare clean_transcript file
-    try:
-        unk_words = prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
-        if unk_words:
-            print('Error! Please add the following words to dictionary:')
-            for unk in unk_words:
-                print("非法words: ", unk)
-    except:
-        print('prep_txt error!')
-        return None
-
-    #prepare mlf file
-    try:
-        with open(tmpbase + '.txt', 'r') as fid:
-            txt = fid.readline()
-        prep_mlf(txt, tmpbase)
-    except:
-        print('prep_mlf error!')
-        return None
-
-    #prepare scp
-    try:
-        os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
-                  '.wav' + ' ' + tmpbase + '.plp')
-    except:
-        print('HCopy error!')
-        return None
-
-    #run alignment
-    try:
-        os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
-                  '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
-                  '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR +
-                  '/dict ' + MODEL_DIR + '/monophones ' + tmpbase +
-                  '.plp 2>&1 > /dev/null')
-    except:
-        print('HVite error!')
-        return None
-
-    with open(tmpbase + '.txt', 'r') as fid:
-        words = fid.readline().strip().split()
-    words = txt.strip().split()
-    words.reverse()
-
-    with open(tmpbase + '.aligned', 'r') as fid:
-        lines = fid.readlines()
-    i = 2
-    times2 = []
-    word2phns = {}
-    current_word = ''
-    index = 0
-    while (i < len(lines)):
-        splited_line = lines[i].strip().split()
-        if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
-            phn = splited_line[2]
-            pst = (int(splited_line[0]) / 1000 + 125) / 10000
-            pen = (int(splited_line[1]) / 1000 + 125) / 10000
-            times2.append([phn, pst, pen])
-            # splited_line[-1]!='sp'
-            if len(splited_line) == 5:
-                current_word = str(index) + '_' + splited_line[-1]
-                word2phns[current_word] = phn
-                index += 1
-            elif len(splited_line) == 4:
-                word2phns[current_word] += ' ' + phn
-        i += 1
-    return times2, word2phns
(Two more file diffs are collapsed in this view and not shown.)
@@ -121,12 +121,10 @@ class NewMaskInputLayer(nn.Layer):
            default_initializer=paddle.nn.initializer.Assign(
                paddle.normal(shape=(1, 1, out_features))))

-    def forward(self, input: paddle.Tensor,
-                masked_position=None) -> paddle.Tensor:
-        masked_position = paddle.expand_as(
-            paddle.unsqueeze(masked_position, -1), input)
-        masked_input = masked_fill(input, masked_position, 0) + masked_fill(
-            paddle.expand_as(self.mask_feature, input), ~masked_position, 0)
+    def forward(self, input: paddle.Tensor, masked_pos=None) -> paddle.Tensor:
+        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
+        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
+            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
        return masked_input
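For context (not part of the commit): the two masked_fill calls above compose a select-by-mask, positions flagged in masked_pos take the learned mask_feature, all other positions keep the input. A minimal equivalent sketch with assumed shapes:

    import paddle

    input = paddle.rand([2, 3, 4])           # (batch, time, feat)
    mask_feature = paddle.rand([1, 1, 4])    # learned mask embedding
    masked_pos = paddle.to_tensor([[1, 0, 0], [0, 1, 1]], dtype='bool')

    masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
    # masked_fill(input, masked_pos, 0) + masked_fill(expanded_feature, ~masked_pos, 0)
    # is equivalent to a single elementwise select:
    out = paddle.where(masked_pos, paddle.expand_as(mask_feature, input), input)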
@@ -443,37 +441,34 @@ class MLMEncoder(nn.Layer):

    def forward(self,
                speech_pad,
                text_pad,
-                masked_position,
+                masked_pos,
                speech_mask=None,
                text_mask=None,
-                speech_segment_pos=None,
-                text_segment_pos=None):
+                speech_seg_pos=None,
+                text_seg_pos=None):
        """Encode input sequence.
        """
-        if masked_position is not None:
-            speech_pad = self.speech_embed(speech_pad, masked_position)
+        if masked_pos is not None:
+            speech_pad = self.speech_embed(speech_pad, masked_pos)
        else:
            speech_pad = self.speech_embed(speech_pad)
        # pure speech input
        if -2 in np.array(text_pad):
            text_pad = text_pad + 3
            text_mask = paddle.unsqueeze(bool(text_pad), 1)
-            text_segment_pos = paddle.zeros_like(text_pad)
+            text_seg_pos = paddle.zeros_like(text_pad)
            text_pad = self.text_embed(text_pad)
-            text_pad = (text_pad[0] + self.segment_emb(text_segment_pos),
+            text_pad = (text_pad[0] + self.segment_emb(text_seg_pos),
                        text_pad[1])
-            text_segment_pos = None
+            text_seg_pos = None
        elif text_pad is not None:
            text_pad = self.text_embed(text_pad)

-        segment_emb = None
-        if speech_segment_pos is not None and text_segment_pos is not None and self.segment_emb:
-            speech_segment_emb = self.segment_emb(speech_segment_pos)
-            text_segment_emb = self.segment_emb(text_segment_pos)
-            text_pad = (text_pad[0] + text_segment_emb, text_pad[1])
-            speech_pad = (speech_pad[0] + speech_segment_emb, speech_pad[1])
-            segment_emb = paddle.concat(
-                [speech_segment_emb, text_segment_emb], axis=1)
+        if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
+            speech_seg_emb = self.segment_emb(speech_seg_pos)
+            text_seg_emb = self.segment_emb(text_seg_pos)
+            text_pad = (text_pad[0] + text_seg_emb, text_pad[1])
+            speech_pad = (speech_pad[0] + speech_seg_emb, speech_pad[1])
        if self.pre_speech_encoders:
            speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask)
@@ -493,11 +488,11 @@ class MLMEncoder(nn.Layer):
        if self.normalize_before:
            xs = self.after_norm(xs)

-        return xs, masks  #, segment_emb
+        return xs, masks

class MLMDecoder(MLMEncoder):

-    def forward(self, xs, masks, masked_position=None, segment_emb=None):
+    def forward(self, xs, masks, masked_pos=None, segment_emb=None):
        """Encode input sequence.

        Args:
@@ -509,9 +504,8 @@ class MLMDecoder(MLMEncoder):
            paddle.Tensor: Mask tensor (#batch, time).
        """
-        emb, mlm_position = None, None
        if not self.training:
-            masked_position = None
+            masked_pos = None
        xs = self.embed(xs)
        if segment_emb:
            xs = (xs[0] + segment_emb, xs[1])
@@ -632,18 +626,18 @@ class MLMModel(nn.Layer):

    def collect_feats(self,
                      speech,
-                      speech_lengths,
+                      speech_lens,
                      text,
-                      text_lengths,
-                      masked_position,
+                      text_lens,
+                      masked_pos,
                      speech_mask,
                      text_mask,
-                      speech_segment_pos,
-                      text_segment_pos,
+                      speech_seg_pos,
+                      text_seg_pos,
                      y_masks=None) -> Dict[str, paddle.Tensor]:
-        return {"feats": speech, "feats_lengths": speech_lengths}
+        return {"feats": speech, "feats_lens": speech_lens}

-    def forward(self, batch, speech_segment_pos, y_masks=None):
+    def forward(self, batch, speech_seg_pos, y_masks=None):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        speech_pad_placeholder = batch['speech_pad']
@@ -654,7 +648,7 @@ class MLMModel(nn.Layer):
        if self.decoder is not None:
            zs, _ = self.decoder(ys_in, y_masks, encoder_out,
                                 bool(h_masks),
-                                 self.encoder.segment_emb(speech_segment_pos))
+                                 self.encoder.segment_emb(speech_seg_pos))
            speech_hidden_states = zs
        else:
            speech_hidden_states = encoder_out[:, :paddle.shape(batch[
@@ -672,21 +666,21 @@ class MLMModel(nn.Layer):
        else:
            after_outs = None
        return before_outs, after_outs, speech_pad_placeholder, batch[
-            'masked_position']
+            'masked_pos']

    def inference(
            self,
            speech,
            text,
-            masked_position,
+            masked_pos,
            speech_mask,
            text_mask,
-            speech_segment_pos,
-            text_segment_pos,
-            span_boundary,
+            speech_seg_pos,
+            text_seg_pos,
+            span_bdy,
            y_masks=None,
-            speech_lengths=None,
-            text_lengths=None,
+            speech_lens=None,
+            text_lens=None,
            feats: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            sids: Optional[paddle.Tensor]=None,
@@ -699,24 +693,24 @@ class MLMModel(nn.Layer):
        batch = dict(
            speech_pad=speech,
            text_pad=text,
-            masked_position=masked_position,
+            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
-            speech_segment_pos=speech_segment_pos,
-            text_segment_pos=text_segment_pos, )
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos, )

        # # inference with teacher forcing
        # hs, h_masks = self.encoder(**batch)

-        outs = [batch['speech_pad'][:, :span_boundary[0]]]
+        outs = [batch['speech_pad'][:, :span_bdy[0]]]
        z_cache = None
        if use_teacher_forcing:
            before, zs, _, _ = self.forward(
-                batch, speech_segment_pos, y_masks=y_masks)
+                batch, speech_seg_pos, y_masks=y_masks)
            if zs is None:
                zs = before
-            outs += [zs[0][span_boundary[0]:span_boundary[1]]]
-            outs += [batch['speech_pad'][:, span_boundary[1]:]]
+            outs += [zs[0][span_bdy[0]:span_bdy[1]]]
+            outs += [batch['speech_pad'][:, span_bdy[1]:]]
            return dict(feat_gen=outs)
        return None
@@ -733,7 +727,7 @@ class MLMModel(nn.Layer):

class MLMEncAsDecoderModel(MLMModel):

-    def forward(self, batch, speech_segment_pos, y_masks=None):
+    def forward(self, batch, speech_seg_pos, y_masks=None):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        speech_pad_placeholder = batch['speech_pad']
@@ -756,7 +750,7 @@ class MLMEncAsDecoderModel(MLMModel):
        else:
            after_outs = None
        return before_outs, after_outs, speech_pad_placeholder, batch[
-            'masked_position']
+            'masked_pos']

class MLMDualMaksingModel(MLMModel):
@@ -767,9 +761,9 @@ class MLMDualMaksingModel(MLMModel):
                           batch):
        xs_pad = batch['speech_pad']
        text_pad = batch['text_pad']
-        masked_position = batch['masked_position']
-        text_masked_position = batch['text_masked_position']
-        mlm_loss_position = masked_position > 0
+        masked_pos = batch['masked_pos']
+        text_masked_pos = batch['text_masked_pos']
+        mlm_loss_pos = masked_pos > 0
        loss = paddle.sum(
            self.l1_loss_func(
                paddle.reshape(before_outs, (-1, self.odim)),
@@ -782,19 +776,17 @@ class MLMDualMaksingModel(MLMModel):
                    paddle.reshape(xs_pad, (-1, self.odim))),
            axis=-1)
        loss_mlm = paddle.sum((loss * paddle.reshape(
-            mlm_loss_position, [-1]))) / paddle.sum((mlm_loss_position) + 1e-10)
+            mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
        loss_text = paddle.sum((self.text_mlm_loss(
            paddle.reshape(text_outs, (-1, self.vocab_size)),
            paddle.reshape(text_pad, (-1))) * paddle.reshape(
-                text_masked_position,
-                (-1)))) / paddle.sum((text_masked_position) + 1e-10)
+                text_masked_pos, (-1)))) / paddle.sum((text_masked_pos) + 1e-10)
        return loss_mlm, loss_text

-    def forward(self, batch, speech_segment_pos, y_masks=None):
+    def forward(self, batch, speech_seg_pos, y_masks=None):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
-        speech_pad_placeholder = batch['speech_pad']
        encoder_out, h_masks = self.encoder(**batch)  # segment_emb
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
@@ -819,7 +811,7 @@ class MLMDualMaksingModel(MLMModel):
                [0, 2, 1])
        else:
            after_outs = None
-        return before_outs, after_outs, text_outs, None  #, speech_pad_placeholder, batch['masked_position'],batch['text_masked_position']
+        return before_outs, after_outs, text_outs, None  #, speech_pad_placeholder, batch['masked_pos'],batch['text_masked_pos']


def build_model_from_file(config_file, model_file):
......
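One detail worth noting in the renamed loss code above: loss_mlm and loss_text are masked means, a per-position loss weighted by the 0/1 mask and normalized by the mask sum (the 1e-10 keeps the division safe when nothing is masked). A small self-contained sketch with assumed values:

    import paddle

    loss = paddle.to_tensor([[0.5, 0.2, 0.9],
                             [0.1, 0.4, 0.3]])      # per-position loss
    mlm_loss_pos = paddle.to_tensor([[1., 1., 0.],
                                     [1., 0., 0.]])  # 1.0 = masked position

    loss_mlm = paddle.sum(loss * mlm_loss_pos) / paddle.sum(mlm_loss_pos + 1e-10)
    # -> (0.5 + 0.2 + 0.1) / 3, roughly 0.2667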
@@ -38,7 +38,7 @@ def pad_list(xs, pad_value):
    """
    n_batch = len(xs)
    max_len = max(x.shape[0] for x in xs)
-    pad = paddle.full([n_batch, max_len, *xs[0].shape[1:]], pad_value)
+    pad = paddle.full([n_batch, max_len, *xs[0].shape[1:]], pad_value, dtype=xs[0].dtype)
    for i in range(n_batch):
        pad[i, :xs[i].shape[0]] = xs[i]
@@ -46,13 +46,18 @@ def pad_list(xs, pad_value):
    return pad

-def make_pad_mask(lengths, length_dim=-1):
+def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (Tensor(int64)): Batch of lengths (B,).
+        xs (Tensor, optional): The reference tensor.
+            If set, masks will be the same shape as this tensor.
+        length_dim (int, optional): Dimension indicator of the above tensor.
+            See the example.

    Returns:
        Tensor(bool): Mask tensor containing indices of padded part.

    Examples:
@@ -61,23 +66,98 @@ def make_pad_mask(lengths, length_dim=-1):
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

+        With the reference tensor.
+
+        >>> xs = paddle.zeros((3, 2, 4))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0],
+                 [0, 0, 0, 0]],
+                [[0, 0, 0, 1],
+                 [0, 0, 0, 1]],
+                [[0, 0, 1, 1],
+                 [0, 0, 1, 1]]])
+        >>> xs = paddle.zeros((3, 2, 6))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1]],
+                [[0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1]],
+                [[0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1]]])
+
+        With the reference tensor and dimension indicator.
+
+        >>> xs = paddle.zeros((3, 6, 6))
+        >>> make_pad_mask(lengths, xs, 1)
+        tensor([[[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1]],
+                [[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1]],
+                [[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1]]])
+        >>> make_pad_mask(lengths, xs, 2)
+        tensor([[[0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1]],
+                [[0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1]],
+                [[0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1]]])
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    bs = paddle.shape(lengths)[0]
-    maxlen = lengths.max()
+    if xs is None:
+        maxlen = lengths.max()
+    else:
+        maxlen = paddle.shape(xs)[length_dim]

    seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
-    return mask
+    if xs is not None:
+        assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)
+        if length_dim < 0:
+            length_dim = len(paddle.shape(xs)) + length_dim
+        # ind = (:, None, ..., None, :, None, ..., None)
+        ind = tuple(
+            slice(None) if i in (0, length_dim) else None
+            for i in range(len(paddle.shape(xs))))
+        mask = paddle.expand(mask[ind], paddle.shape(xs))
+    return mask

-def make_non_pad_mask(lengths, length_dim=-1):
+def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
@@ -90,16 +170,78 @@ def make_non_pad_mask(lengths, length_dim=-1):
    Returns:
        Tensor(bool): mask tensor containing indices of non-padded part.

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

+        With the reference tensor.
+
+        >>> xs = paddle.zeros((3, 2, 4))
+        >>> make_non_pad_mask(lengths, xs)
+        tensor([[[1, 1, 1, 1],
+                 [1, 1, 1, 1]],
+                [[1, 1, 1, 0],
+                 [1, 1, 1, 0]],
+                [[1, 1, 0, 0],
+                 [1, 1, 0, 0]]])
+        >>> xs = paddle.zeros((3, 2, 6))
+        >>> make_non_pad_mask(lengths, xs)
+        tensor([[[1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0]],
+                [[1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0]],
+                [[1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0]]])
+
+        With the reference tensor and dimension indicator.
+
+        >>> xs = paddle.zeros((3, 6, 6))
+        >>> make_non_pad_mask(lengths, xs, 1)
+        tensor([[[1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [0, 0, 0, 0, 0, 0]],
+                [[1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0]],
+                [[1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0]]])
+        >>> make_non_pad_mask(lengths, xs, 2)
+        tensor([[[1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0]],
+                [[1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0]],
+                [[1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0]]])
    """
-    return paddle.logical_not(make_pad_mask(lengths, length_dim))
+    return paddle.logical_not(make_pad_mask(lengths, xs, length_dim))


def initialize(model: nn.Layer, init: str):
......
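For context on the new xs argument (a sketch, not from the commit): the mask is still computed from lengths as a (B, maxlen) tensor and is then broadcast into xs's shape through the ind tuple of slice(None)/None axes. A short example, assuming the helpers above are in scope:

    import paddle

    lengths = paddle.to_tensor([5, 3, 2])
    xs = paddle.zeros((3, 2, 6))

    pad = make_pad_mask(lengths, xs)          # (3, 2, 6), True on padded frames
    non_pad = make_non_pad_mask(lengths, xs)  # elementwise logical_not of pad

    # pad_list now inherits the dtype of its inputs instead of defaulting
    # to the float dtype:
    ys = [paddle.ones([4], dtype='int64'), paddle.ones([2], dtype='int64')]
    padded = pad_list(ys, 0)                  # shape [2, 4], dtype int64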
@@ -10,8 +10,8 @@ python inference.py \
    --uid=Prompt_003_new \
    --new_str='今天天气很好.' \
    --prefix='./prompt/dev/' \
-    --source_language=english \
-    --target_language=chinese \
+    --source_lang=english \
+    --target_lang=chinese \
    --output_name=pred_clone.wav \
    --use_pt_vocoder=False \
    --voc=pwgan_aishell3 \
......
@@ -9,8 +9,8 @@ python inference.py \
    --uid=p299_096 \
    --new_str='I enjoy my life, do you?' \
    --prefix='./prompt/dev/' \
-    --source_language=english \
-    --target_language=english \
+    --source_lang=english \
+    --target_lang=english \
    --output_name=pred_gen.wav \
    --use_pt_vocoder=False \
    --voc=pwgan_aishell3 \
......
@@ -10,8 +10,8 @@ python inference.py \
    --uid=p243_new \
    --new_str='for that reason cover is impossible to be given.' \
    --prefix='./prompt/dev/' \
-    --source_language=english \
-    --target_language=english \
+    --source_lang=english \
+    --target_lang=english \
    --output_name=pred_edit.wav \
    --use_pt_vocoder=False \
    --voc=pwgan_aishell3 \
......
@@ -80,10 +80,8 @@ def parse_args():
    parser.add_argument("--uid", type=str, help="uid")
    parser.add_argument("--new_str", type=str, help="new string")
    parser.add_argument("--prefix", type=str, help="prefix")
-    parser.add_argument("--clone_prefix", type=str, default=None, help="clone prefix")
-    parser.add_argument("--clone_uid", type=str, default=None, help="clone uid")
-    parser.add_argument("--source_language", type=str, help="source language")
-    parser.add_argument("--target_language", type=str, help="target language")
+    parser.add_argument("--source_lang", type=str, default="english", help="source language")
+    parser.add_argument("--target_lang", type=str, default="english", help="target language")
    parser.add_argument("--output_name", type=str, help="output name")
    parser.add_argument("--task_name", type=str, help="task name")
    parser.add_argument(
......
@@ -9,7 +9,7 @@ import torch
import yaml


-class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
+class TorchPWGAN(torch.nn.Module):
    """Wrapper class to load the vocoder trained with parallel_wavegan repo."""

    def __init__(
......
+import os
+from typing import List
+from typing import Optional
+
import numpy as np
import paddle
import yaml
@@ -5,11 +9,8 @@ from sedit_arg_parser import parse_args
from yacs.config import CfgNode

from paddlespeech.s2t.utils.dynamic_import import dynamic_import
-from paddlespeech.t2s.exps.syn_utils import get_frontend
-from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.modules.normalizer import ZScore
-from tools.parallel_wavegan_pretrained_vocoder import ParallelWaveGANPretrainedVocoder
+from tools.torch_pwgan import TorchPWGAN

-# new add

model_alias = {
    # acoustic model
@@ -25,6 +26,10 @@ model_alias = {
    "paddlespeech.t2s.models.tacotron2:Tacotron2",
    "tacotron2_inference":
    "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
+    "pwgan":
+    "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
+    "pwgan_inference":
+    "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
}
@@ -43,60 +48,65 @@ def build_vocoder_from_file(
    # Build vocoder
    if str(vocoder_file).endswith(".pkl"):
        # If the extension is ".pkl", the model is trained with parallel_wavegan
-        vocoder = ParallelWaveGANPretrainedVocoder(vocoder_file,
-                                                   vocoder_config_file)
+        vocoder = TorchPWGAN(vocoder_file, vocoder_config_file)
        return vocoder.to(device)
    else:
        raise ValueError(f"{vocoder_file} is not supported format.")


-def get_voc_out(mel, target_language="chinese"):
+def get_voc_out(mel, target_lang: str="chinese"):
    # vocoder
    args = parse_args()
-    assert target_language == "chinese" or target_language == "english", "In get_voc_out function, target_language is illegal..."
+    assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..."
    # print("current vocoder: ", args.voc)
    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))
-    # print(voc_config)
-
-    voc_inference = get_voc_inference(args, voc_config)
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)

-    mel = paddle.to_tensor(mel)
-    # print("masked_mel: ", mel.shape)
    with paddle.no_grad():
        wav = voc_inference(mel)
-    # print("shepe of wav (time x n_channels):%s"%wav.shape)
    return np.squeeze(wav)

# dygraph
-def get_am_inference(args, am_config):
-    with open(args.phones_dict, "r") as f:
+def get_am_inference(am: str='fastspeech2_csmsc',
+                     am_config: CfgNode=None,
+                     am_ckpt: Optional[os.PathLike]=None,
+                     am_stat: Optional[os.PathLike]=None,
+                     phones_dict: Optional[os.PathLike]=None,
+                     tones_dict: Optional[os.PathLike]=None,
+                     speaker_dict: Optional[os.PathLike]=None,
+                     return_am: bool=False):
+    with open(phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
-    # print("vocab_size:", vocab_size)
+    print("vocab_size:", vocab_size)

    tone_size = None
-    if 'tones_dict' in args and args.tones_dict:
-        with open(args.tones_dict, "r") as f:
+    if tones_dict is not None:
+        with open(tones_dict, "r") as f:
            tone_id = [line.strip().split() for line in f.readlines()]
        tone_size = len(tone_id)
        print("tone_size:", tone_size)

    spk_num = None
-    if 'speaker_dict' in args and args.speaker_dict:
-        with open(args.speaker_dict, 'rt') as f:
+    if speaker_dict is not None:
+        with open(speaker_dict, 'rt') as f:
            spk_id = [line.strip().split() for line in f.readlines()]
        spk_num = len(spk_id)
        print("spk_num:", spk_num)

    odim = am_config.n_mels
    # model: {model_name}_{dataset}
-    am_name = args.am[:args.am.rindex('_')]
-    am_dataset = args.am[args.am.rindex('_') + 1:]
+    am_name = am[:am.rindex('_')]
+    am_dataset = am[am.rindex('_') + 1:]

    am_class = dynamic_import(am_name, model_alias)
    am_inference_class = dynamic_import(am_name + '_inference', model_alias)
@@ -113,39 +123,61 @@ def get_am_inference(args, am_config):
    elif am_name == 'tacotron2':
        am = am_class(idim=vocab_size, odim=odim, **am_config["model"])

-    am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
+    am.set_state_dict(paddle.load(am_ckpt)["main_params"])
    am.eval()
-    am_mu, am_std = np.load(args.am_stat)
+    am_mu, am_std = np.load(am_stat)
    am_mu = paddle.to_tensor(am_mu)
    am_std = paddle.to_tensor(am_std)
    am_normalizer = ZScore(am_mu, am_std)
    am_inference = am_inference_class(am_normalizer, am)
    am_inference.eval()
    print("acoustic model done!")
-    return am, am_inference, am_name, am_dataset, phn_id
+    if return_am:
+        return am_inference, am
+    else:
+        return am_inference

-def evaluate_durations(phns,
-                       target_language="chinese",
-                       fs=24000,
-                       hop_length=300):
+def get_voc_inference(
+        voc: str='pwgan_csmsc',
+        voc_config: Optional[os.PathLike]=None,
+        voc_ckpt: Optional[os.PathLike]=None,
+        voc_stat: Optional[os.PathLike]=None, ):
+    # model: {model_name}_{dataset}
+    voc_name = voc[:voc.rindex('_')]
+    voc_class = dynamic_import(voc_name, model_alias)
+    voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
+    if voc_name != 'wavernn':
+        voc = voc_class(**voc_config["generator_params"])
+        voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"])
+        voc.remove_weight_norm()
+        voc.eval()
+    else:
+        voc = voc_class(**voc_config["model"])
+        voc.set_state_dict(paddle.load(voc_ckpt)["main_params"])
+        voc.eval()
+    voc_mu, voc_std = np.load(voc_stat)
+    voc_mu = paddle.to_tensor(voc_mu)
+    voc_std = paddle.to_tensor(voc_std)
+    voc_normalizer = ZScore(voc_mu, voc_std)
+    voc_inference = voc_inference_class(voc_normalizer, voc)
+    voc_inference.eval()
+    print("voc done!")
+    return voc_inference


+def evaluate_durations(phns: List[str],
+                       target_lang: str="chinese",
+                       fs: int=24000,
+                       hop_length: int=300):
    args = parse_args()

-    if target_language == 'english':
+    if target_lang == 'english':
        args.lang = 'en'
-        args.am = "fastspeech2_ljspeech"
-        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
-        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
-        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
-        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"

-    elif target_language == 'chinese':
+    elif target_lang == 'chinese':
        args.lang = 'zh'
-        args.am = "fastspeech2_csmsc"
-        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
-        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
-        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
-        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"

    # args = parser.parse_args(args=[])
    if args.ngpu == 0:
@@ -155,23 +187,28 @@ def evaluate_durations(phns,
    else:
        print("ngpu should >= 0 !")

-    assert target_language == "chinese" or target_language == "english", "In evaluate_durations function, target_language is illegal..."
+    assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..."

    # Init body.
    with open(args.am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))

-    # print("========Config========")
-    # print(am_config)
-    # print("---------------------")
-    # acoustic model
-    am, am_inference, am_name, am_dataset, phn_id = get_am_inference(args,
-                                                                     am_config)
+    am_inference, am = get_am_inference(
+        am=args.am,
+        am_config=am_config,
+        am_ckpt=args.am_ckpt,
+        am_stat=args.am_stat,
+        phones_dict=args.phones_dict,
+        tones_dict=args.tones_dict,
+        speaker_dict=args.speaker_dict,
+        return_am=True)

    torch_phns = phns
    vocab_phones = {}
+    with open(args.phones_dict, "r") as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
    for tone, id in phn_id:
        vocab_phones[tone] = int(id)
-    # print("vocab_phones: ", len(vocab_phones))
    vocab_size = len(vocab_phones)
    phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]
@@ -185,59 +222,3 @@ def evaluate_durations(phns,
    phoneme_durations_new = pre_d_outs * hop_length / fs
    phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
    return phoneme_durations_new
-
-
-def sentence2phns(sentence, target_language="en"):
-    args = parse_args()
-    if target_language == 'en':
-        args.lang = 'en'
-        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
-    elif target_language == 'zh':
-        args.lang = 'zh'
-        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
-    else:
-        print("target_language should in {'zh', 'en'}!")
-
-    frontend = get_frontend(args)
-    merge_sentences = True
-    get_tone_ids = False
-
-    if target_language == 'zh':
-        input_ids = frontend.get_input_ids(
-            sentence,
-            merge_sentences=merge_sentences,
-            get_tone_ids=get_tone_ids,
-            print_info=False)
-        phone_ids = input_ids["phone_ids"]
-        phonemes = frontend.get_phonemes(
-            sentence, merge_sentences=merge_sentences, print_info=False)
-        return phonemes[0], input_ids["phone_ids"][0]
-
-    elif target_language == 'en':
-        phonemes = frontend.phoneticize(sentence)
-        input_ids = frontend.get_input_ids(
-            sentence, merge_sentences=merge_sentences)
-        phone_ids = input_ids["phone_ids"]
-
-        phones_list = []
-        vocab_phones = {}
-        punc = ":,;。?!“”‘’':,;.?!"
-        with open(args.phones_dict, 'rt') as f:
-            phn_id = [line.strip().split() for line in f.readlines()]
-        for phn, id in phn_id:
-            vocab_phones[phn] = int(id)
-
-        phones = phonemes[1:-1]
-        phones = [phn for phn in phones if not phn.isspace()]
-        # replace unk phone with sp
-        phones = [
-            phn if (phn in vocab_phones and phn not in punc) else "sp"
-            for phn in phones
-        ]
-        phones_list.append(phones)
-        return phones_list[0], input_ids["phone_ids"][0]
-    else:
-        print("lang should in {'zh', 'en'}!")