Unverified commit b81832ce, authored by Kennycao123, committed by GitHub

Merge pull request #825 from yt605155624/format

[ernie sat]Format ernie sat
...@@ -113,8 +113,8 @@ prompt/dev
8. `--uid` id of the specific prompt speech
9. `--new_str` the input text (this release is limited to a few preset texts for now)
10. `--prefix` path of the text and phoneme files for the prompt audio
11. `--source_lang` the source language
12. `--target_lang` the target language
13. `--output_name` name of the synthesized audio file
14. `--task_name` task name, one of: speech editing, personalized speech synthesis, cross-lingual speech synthesis
15. `--use_pt_vocoder` whether to use the torch version of the vocoder for English; defaults to False, in which case the paddle vocoder is used for English
...
#!/usr/bin/env python
""" Usage:
    align.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
...@@ -9,12 +9,45 @@ import sys
from tqdm import tqdm

PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'


def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
words = []
line = line.strip()
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
with open(dictfile, 'r') as fid:
for line in fid:
ds.add(line.split()[0])
unk_words = set([])
with open(tmpbase + '.txt', 'w') as fwid:
for wrd in words:
if (wrd not in ds):
unk_words.add(wrd)
fwid.write(wrd + ' ')
fwid.write('\n')
return unk_words
def prep_txt_en(line: str, tmpbase, dictfile):
    words = []
...@@ -97,7 +130,7 @@ def prep_txt(line, tmpbase, dictfile):
    fw.close()


def prep_mlf(txt: str, tmpbase: str):
    with open(tmpbase + '.mlf', 'w') as fwid:
        fwid.write('#!MLF!#\n')
...@@ -110,7 +143,55 @@ def prep_mlf(txt, tmpbase):
        fwid.write('.\n')


def _get_user():
return os.path.expanduser('~').split("/")[-1]
def alignment(wav_path: str, text: str):
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
except:
print('sox error!')
return None
#prepare clean_transcript file
try:
prep_txt_en(text, tmpbase, MODEL_DIR_EN + '/dict')
except:
print('prep_txt error!')
return None
#prepare mlf file
try:
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except:
print('prep_mlf error!')
return None
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR_EN + '/16000/macros -H ' + MODEL_DIR_EN
+ '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
'.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
print('HVite error!')
return None
    with open(tmpbase + '.txt', 'r') as fid:
        words = fid.readline().strip().split()
    words = txt.strip().split()
...@@ -119,59 +200,47 @@ def gen_res(tmpbase, outfile1, outfile2):
    with open(tmpbase + '.aligned', 'r') as fid:
        lines = fid.readlines()
    i = 2
    times2 = []
    word2phns = {}
    current_word = ''
    index = 0
    while (i < len(lines)):
        splited_line = lines[i].strip().split()
        if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
            phn = splited_line[2]
            pst = (int(splited_line[0]) / 1000 + 125) / 10000
            pen = (int(splited_line[1]) / 1000 + 125) / 10000
            times2.append([phn, pst, pen])
            # splited_line[-1] != 'sp'
            if len(splited_line) == 5:
                current_word = str(index) + '_' + splited_line[-1]
                word2phns[current_word] = phn
                index += 1
            elif len(splited_line) == 4:
                word2phns[current_word] += ' ' + phn
        i += 1
    return times2, word2phns
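

# A minimal usage sketch of the merged English aligner. It assumes sox and
# the HTK binaries/models under tools/ are installed; 'prompt.wav' and the
# phone strings below are hypothetical illustrations, and the _demo_* name
# is not part of the module.
def _demo_alignment():
    times2, word2phns = alignment('prompt.wav', 'hello world')
    # times2 entries are [phone, start_sec, end_sec],
    # e.g. ['HH', 0.0525, 0.0925]
    # word2phns maps '<word_index>_<WORD>' to its phone string,
    # e.g. {'0_sp': 'sp', '1_HELLO': 'HH AH0 L OW1', '2_WORLD': 'W ER1 L D'}
    return times2, word2phns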
def alignment_zh(wav_path, text_string):
    tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())

    #prepare wav and trs files
    try:
        os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
                  '.wav remix -')
    except:
        print('sox error!')
        return None

    #prepare clean_transcript file
    try:
        unk_words = prep_txt_zh(text_string, tmpbase, MODEL_DIR_ZH + '/dict')
        if unk_words:
            print('Error! Please add the following words to dictionary:')
            for unk in unk_words:
                print("非法words: ", unk)
    except:
        print('prep_txt error!')
        return None
...@@ -187,7 +256,7 @@ def alignment(wav_path, text_string):
    #prepare scp
    try:
        os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
                  '.wav' + ' ' + tmpbase + '.plp')
    except:
        print('HCopy error!')
...@@ -196,10 +265,11 @@ def alignment(wav_path, text_string):
    #run alignment
    try:
        os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
                  '.mlf -H ' + MODEL_DIR_ZH + '/16000/macros -H ' + MODEL_DIR_ZH
                  + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR_ZH
                  + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
                  '.plp 2>&1 > /dev/null')
    except:
        print('HVite error!')
        return None
...@@ -211,6 +281,7 @@ def alignment(wav_path, text_string):
    with open(tmpbase + '.aligned', 'r') as fid:
        lines = fid.readlines()
    i = 2
    times2 = []
    word2phns = {}
...
#!/usr/bin/env python
""" Usage:
align_mandarin.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
import sys
from tqdm import tqdm
MODEL_DIR = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'
def prep_txt(line, tmpbase, dictfile):
words = []
line = line.strip()
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
with open(dictfile, 'r') as fid:
for line in fid:
ds.add(line.split()[0])
unk_words = set([])
with open(tmpbase + '.txt', 'w') as fwid:
for wrd in words:
if (wrd not in ds):
unk_words.add(wrd)
fwid.write(wrd + ' ')
fwid.write('\n')
return unk_words
def prep_mlf(txt, tmpbase):
with open(tmpbase + '.mlf', 'w') as fwid:
fwid.write('#!MLF!#\n')
fwid.write('"' + tmpbase + '.lab"\n')
fwid.write('sp\n')
wrds = txt.split()
for wrd in wrds:
fwid.write(wrd.upper() + '\n')
fwid.write('sp\n')
fwid.write('.\n')
def gen_res(tmpbase, outfile1, outfile2):
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
words = txt.strip().split()
words.reverse()
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
times1 = []
times2 = []
while (i < len(lines)):
if (len(lines[i].split()) >= 4) and (
lines[i].split()[0] != lines[i].split()[1]):
phn = lines[i].split()[2]
pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
if (len(lines[i].split()) == 5):
if (lines[i].split()[0] != lines[i].split()[1]):
wrd = lines[i].split()[-1].strip()
st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
j = i + 1
while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
j += 1
en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
times1.append([wrd, st, en])
i += 1
with open(outfile1, 'w') as fwid:
for item in times1:
if (item[0] == 'sp'):
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' SIL\n')
else:
wrd = words.pop()
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + wrd + '\n')
if words:
print('not matched::' + alignfile)
sys.exit(1)
with open(outfile2, 'w') as fwid:
for item in times2:
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
def alignment_zh(wav_path, text_string):
tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
'.wav remix -')
except:
print('sox error!')
return None
#prepare clean_transcript file
try:
unk_words = prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
if unk_words:
print('Error! Please add the following words to dictionary:')
for unk in unk_words:
print("非法words: ", unk)
except:
print('prep_txt error!')
return None
#prepare mlf file
try:
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except:
print('prep_mlf error!')
return None
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
'/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR +
'/dict ' + MODEL_DIR + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
print('HVite error!')
return None
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
words = txt.strip().split()
words.reverse()
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
times2 = []
word2phns = {}
current_word = ''
index = 0
while (i < len(lines)):
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return times2, word2phns
...@@ -4,37 +4,180 @@ import numpy as np
import paddle


def phones_text_masking(xs_pad: paddle.Tensor,
                        src_mask: paddle.Tensor,
text_pad: paddle.Tensor,
text_mask: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
mlm_prob: float,
mean_phn_span: float,
span_bdy: paddle.Tensor=None):
bz, sent_len, _ = paddle.shape(xs_pad)
masked_pos = paddle.zeros((bz, sent_len))
_, text_len = paddle.shape(text_pad)
text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5)
text_masked_pos = paddle.zeros((bz, text_len))
y_masks = None
if mlm_prob == 1.0:
masked_pos += 1
# y_masks = tril_masks
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
mean_phn_span).nonzero()
masked_pos[:, masked_phn_idxs] = 1
else:
for idx in range(bz):
if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1
else:
length = align_start_lens[idx]
if length < 2:
continue
masked_phn_idxs = random_spans_noise_mask(
length, mlm_prob, mean_phn_span).nonzero()
unmasked_phn_idxs = list(
set(range(length)) - set(masked_phn_idxs[0].tolist()))
np.random.shuffle(unmasked_phn_idxs)
masked_text_idxs = unmasked_phn_idxs[:text_mask_num_lower]
text_masked_pos[idx][masked_text_idxs] = 1
masked_start = align_start[idx][masked_phn_idxs].tolist()
masked_end = align_end[idx][masked_phn_idxs].tolist()
for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
    non_eos_text_mask = paddle.reshape(text_mask, paddle.shape(text_pad)[:2])
text_masked_pos = text_masked_pos * non_eos_text_mask
masked_pos = paddle.cast(masked_pos, 'bool')
text_masked_pos = paddle.cast(text_masked_pos, 'bool')
return masked_pos, text_masked_pos, y_masks
def get_seg_pos_reduce_duration(
speech_pad: paddle.Tensor,
text_pad: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
sega_emb: bool,
masked_pos: paddle.Tensor,
feats_lens: paddle.Tensor, ):
bz, speech_len, _ = paddle.shape(speech_pad)
text_seg_pos = paddle.zeros(paddle.shape(text_pad))
speech_seg_pos = paddle.zeros((bz, speech_len), dtype=text_pad.dtype)
reordered_idx = paddle.zeros((bz, speech_len), dtype=align_start_lens.dtype)
durations = paddle.ones((bz, speech_len), dtype=align_start_lens.dtype)
max_reduced_length = 0
if not sega_emb:
return speech_pad, masked_pos, speech_seg_pos, text_seg_pos, durations
for idx in range(bz):
first_idx = []
last_idx = []
align_length = align_start_lens[idx]
for j in range(align_length):
s, e = align_start[idx][j], align_end[idx][j]
if j == 0:
if paddle.sum(masked_pos[idx][0:s]) == 0:
first_idx.extend(range(0, s))
else:
first_idx.extend([0])
last_idx.extend(range(1, s))
if paddle.sum(masked_pos[idx][s:e]) == 0:
first_idx.extend(range(s, e))
else:
first_idx.extend([s])
last_idx.extend(range(s + 1, e))
durations[idx][s] = e - s
speech_seg_pos[idx][s:e] = j + 1
text_seg_pos[idx][j] = j + 1
max_reduced_length = max(
len(first_idx) + feats_lens[idx] - e, max_reduced_length)
first_idx.extend(range(e, speech_len))
reordered_idx[idx] = paddle.to_tensor(
(first_idx + last_idx), dtype=align_start_lens.dtype)
feats_lens[idx] = len(first_idx)
reordered_idx = reordered_idx[:, :max_reduced_length]
return reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens
def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
"""This function is copy of `random_spans_helper
<https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens.
The number of noise tokens and the number of noise spans and non-noise spans
are determined deterministically as follows:
num_noise_tokens = round(length * noise_density)
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
Spans alternate between non-noise and noise, beginning with non-noise.
Subject to the above restrictions, all masks are equally likely.
    Args:
        length: an int32 scalar (length of the incoming token sequence)
        noise_density: a float - approximate density of output mask
        mean_noise_span_length: a number
    Returns:
        a boolean tensor with shape [length]
    """
    orig_length = length
num_noise_tokens = int(np.round(length * mlm_prob))
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
num_noise_spans = int(np.round(num_noise_tokens / mean_phn_span))
# avoid degeneracy by ensuring positive number of noise spans
num_noise_spans = max(num_noise_spans, 1)
num_nonnoise_tokens = length - num_noise_tokens
# pick the lengths of the noise spans and the non-noise spans
def _random_seg(num_items, num_segs):
"""Partition a sequence of items randomly into non-empty segments.
Args:
num_items: an integer scalar > 0
num_segs: an integer scalar in [1, num_items]
Returns:
a Tensor with shape [num_segs] containing positive integers that add
up to num_items
"""
mask_idxs = np.arange(num_items - 1) < (num_segs - 1)
np.random.shuffle(mask_idxs)
first_in_seg = np.pad(mask_idxs, [[1, 0]])
segment_id = np.cumsum(first_in_seg)
# count length of sub segments assuming that list is sorted
_, segment_length = np.unique(segment_id, return_counts=True)
return segment_length
noise_span_lens = _random_seg(num_noise_tokens, num_noise_spans)
nonnoise_span_lens = _random_seg(num_nonnoise_tokens, num_noise_spans)
interleaved_span_lens = np.reshape(
np.stack([nonnoise_span_lens, noise_span_lens], axis=1),
[num_noise_spans * 2])
span_starts = np.cumsum(interleaved_span_lens)[:-1]
span_start_indicator = np.zeros((length, ), dtype=np.int8)
span_start_indicator[span_starts] = True
span_num = np.cumsum(span_start_indicator)
is_noise = np.equal(span_num % 2, 1)
return is_noise[:orig_length]
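

# A small sanity check of the span statistics described in the docstring
# above (a sketch: the exact mask pattern depends on np.random's state, and
# the _demo_* name is illustrative, not part of the module).
def _demo_random_spans_noise_mask():
    # length=20, mlm_prob=0.3, mean_phn_span=3 gives
    # num_noise_tokens = round(20 * 0.3) = 6 and
    # num_noise_spans  = round(6 / 3)    = 2
    mask = random_spans_noise_mask(length=20, mlm_prob=0.3, mean_phn_span=3)
    assert mask.sum() == 6  # always 6 noise tokens for these arguments
    print(mask.astype(int))  # e.g. [0 0 0 1 1 1 0 0 0 0 0 1 1 1 0 0 0 0 0 0]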
def pad_to_longformer_att_window(text: paddle.Tensor,
max_len: int,
max_tlen: int,
attention_window: int=0):
    round = max_len % attention_window
    if round != 0:
        max_tlen += (attention_window - round)
...@@ -48,286 +191,67 @@ def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
    return text_pad, max_tlen
def phones_masking(xs_pad: paddle.Tensor,
                   src_mask: paddle.Tensor,
                   align_start: paddle.Tensor,
                   align_end: paddle.Tensor,
                   align_start_lens: paddle.Tensor,
                   mlm_prob: float,
                   mean_phn_span: int,
                   span_bdy: paddle.Tensor=None):
    bz, sent_len, _ = paddle.shape(xs_pad)
    masked_pos = paddle.zeros((bz, sent_len))
    y_masks = None
    if mlm_prob == 1.0:
        masked_pos += 1
    elif mean_phn_span == 0:
        # only speech
        length = sent_len
        mean_phn_span = min(length * mlm_prob // 3, 50)
        masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
                                                  mean_phn_span).nonzero()
        masked_pos[:, masked_phn_idxs] = 1
    else:
        for idx in range(bz):
            if span_bdy is not None:
                for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
                    masked_pos[idx, s:e] = 1
            else:
                length = align_start_lens[idx]
                if length < 2:
                    continue
                masked_phn_idxs = random_spans_noise_mask(
                    length, mlm_prob, mean_phn_span).nonzero()
                masked_start = align_start[idx][masked_phn_idxs].tolist()
                masked_end = align_end[idx][masked_phn_idxs].tolist()
                for s, e in zip(masked_start, masked_end):
                    masked_pos[idx, s:e] = 1
    non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
    masked_pos = masked_pos * non_eos_mask
    masked_pos = paddle.cast(masked_pos, 'bool')

    return masked_pos, y_masks
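

# A toy call of phones_masking (a sketch with hypothetical shapes): batch of
# 1, 4 frames, feature dim 80, no padding; mlm_prob=1.0 masks every frame,
# so the align_* arguments are unused on this path.
def _demo_phones_masking():
    xs_pad = paddle.zeros((1, 4, 80))
    src_mask = paddle.ones((1, 1, 4))
    masked_pos, y_masks = phones_masking(
        xs_pad=xs_pad,
        src_mask=src_mask,
        align_start=None,
        align_end=None,
        align_start_lens=None,
        mlm_prob=1.0,
        mean_phn_span=0)
    print(masked_pos)  # [[True, True, True, True]]; y_masks is None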
def get_seg_pos(speech_pad: paddle.Tensor,
                text_pad: paddle.Tensor,
                align_start: paddle.Tensor,
                align_end: paddle.Tensor,
                align_start_lens: paddle.Tensor,
                sega_emb: bool):
    bz, speech_len, _ = paddle.shape(speech_pad)
    _, text_len = paddle.shape(text_pad)

    text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
    speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')
    if not sega_emb:
        return speech_seg_pos, text_seg_pos
    for idx in range(bz):
        align_length = align_start_lens[idx]
        for j in range(align_length):
            s, e = align_start[idx][j], align_end[idx][j]
            speech_seg_pos[idx, s:e] = j + 1
            text_seg_pos[idx, j] = j + 1

    return speech_seg_pos, text_seg_pos
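

# A toy call of get_seg_pos (a sketch): one utterance, 6 frames, 2 phones
# aligned to frames [0, 2) and [2, 5); the trailing frame stays 0 (unaligned).
def _demo_get_seg_pos():
    speech_pad = paddle.zeros((1, 6, 80))
    text_pad = paddle.zeros((1, 2), dtype='int64')
    align_start = paddle.to_tensor([[0, 2]])
    align_end = paddle.to_tensor([[2, 5]])
    align_start_lens = paddle.to_tensor([2])
    speech_seg_pos, text_seg_pos = get_seg_pos(
        speech_pad, text_pad, align_start, align_end, align_start_lens,
        sega_emb=True)
    print(speech_seg_pos)  # [[1, 1, 2, 2, 2, 0]], 1-based phone index per frame
    print(text_seg_pos)  # [[1, 2]]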
#!/usr/bin/env python3
import argparse
import os
import random
from pathlib import Path
from typing import Collection
from typing import Dict
...@@ -18,17 +14,17 @@ import numpy as np
import paddle
import soundfile as sf
import torch
from paddle import nn
from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model

from align import alignment
from align import alignment_zh
from dataset import get_seg_pos
from dataset import get_seg_pos_reduce_duration
from dataset import pad_to_longformer_att_window
from dataset import phones_masking
from dataset import phones_text_masking
from model_paddle import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2column_text
...@@ -37,8 +33,9 @@ from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
from utils import sentence2phns
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import pad_list
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask

random.seed(0)
np.random.seed(0)
...@@ -47,81 +44,72 @@ MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
def plot_mel_and_vocode_wav(uid: str,
                            wav_path: str,
                            prefix: str="./prompt/dev/",
                            source_lang: str='english',
                            target_lang: str='english',
                            model_name: str="conformer",
                            full_origin_str: str="",
                            old_str: str="",
                            new_str: str="",
                            duration_preditor_path: str=None,
                            use_pt_vocoder: bool=False,
                            sid: str=None,
                            non_autoreg: bool=True):
    wav_org, input_feat, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
        uid=uid,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_preditor_path=duration_preditor_path,
        use_teacher_forcing=non_autoreg,
        sid=sid)

    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]

    if target_lang == 'english':
        if use_pt_vocoder:
            output_feat = output_feat.cpu().numpy()
            output_feat = torch.tensor(output_feat, dtype=torch.float)
            vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
            replaced_wav = vocoder(output_feat).cpu().numpy()
        else:
            replaced_wav = get_voc_out(output_feat, target_lang)
    elif target_lang == 'chinese':
        replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, target_lang)

    old_time_bdy = [hop_length * x for x in old_span_bdy]
    new_time_bdy = [hop_length * x for x in new_span_bdy]

    if target_lang == 'english':
        wav_org_replaced_paddle_voc = np.concatenate([
            wav_org[:old_time_bdy[0]],
            replaced_wav[new_time_bdy[0]:new_time_bdy[1]],
            wav_org[old_time_bdy[1]:]
        ])
        data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc}
    elif target_lang == 'chinese':
        wav_org_replaced_only_mask_fst2_voc = np.concatenate([
            wav_org[:old_time_bdy[0]], replaced_wav_only_mask_fst2_voc,
            wav_org[old_time_bdy[1]:]
        ])
        data_dict = {
            "origin": wav_org,
            "output": wav_org_replaced_only_mask_fst2_voc,
        }

    return data_dict, old_span_bdy
def get_unk_phns(word_str: str):
    tmpbase = '/tmp/tp.'
    f = open(tmpbase + 'temp.words', 'w')
    f.write(word_str)
...@@ -160,9 +148,8 @@ def get_unk_phns(word_str):
    return phns


def words2phns(line: str):
    dictfile = MODEL_DIR_EN + '/dict'
    line = line.strip()
    words = []
    for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
...@@ -200,9 +187,8 @@ def words2phns(line):
    return phns, wrd2phns
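

# A usage sketch of words2phns (illustrative output; the real phones come
# from the dictionary under tools/aligner/english, and the _demo_* name is
# not part of the module).
def _demo_words2phns():
    phns, wrd2phns = words2phns('hello world')
    # phns is the flat phone list; wrd2phns maps '<idx>_<WORD>' to its phones,
    # e.g. {'0_HELLO': 'HH AH0 L OW1', '1_WORLD': 'W ER1 L D'}
    return phns, wrd2phns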
def words2phns_zh(line: str):
    dictfile = MODEL_DIR_ZH + '/dict'
    line = line.strip()
    words = []
    for pun in [
...@@ -242,7 +228,7 @@ def words2phns_zh(line):
    return phns, wrd2phns


def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
    vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
    vocoder_file = download_pretrained_model(vocoder_tag)
    vocoder_config = Path(vocoder_file).parent / "config.yml"
...@@ -250,7 +236,7 @@ def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"):
    return vocoder


def load_model(model_name: str):
    config_path = './pretrained_model/{}/config.yaml'.format(model_name)
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
    mlm_model, args = build_model_from_file(
...@@ -258,7 +244,7 @@ def load_model(model_name):
    return mlm_model, args


def read_data(uid: str, prefix: str):
    mfa_text = read_2column_text(prefix + '/text')[uid]
    mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid]
    if 'mnt' not in mfa_wav_path:
...@@ -266,7 +252,7 @@ def read_data(uid, prefix):
    return mfa_text, mfa_wav_path


def get_align_data(uid: str, prefix: str):
    mfa_path = prefix + "mfa_"
    mfa_text = read_2column_text(mfa_path + 'text')[uid]
    mfa_start = load_num_sequence_text(
...@@ -277,43 +263,45 @@ def get_align_data(uid, prefix):
    return mfa_text, mfa_start, mfa_end, mfa_wav_path
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
    align_start = paddle.to_tensor(mfa_start).unsqueeze(0)
    align_end = paddle.to_tensor(mfa_end).unsqueeze(0)
    align_start = paddle.floor(fs * align_start / hop_length).int()
    align_end = paddle.floor(fs * align_end / hop_length).int()
    if span_to_repl[0] >= len(mfa_start):
        span_bdy = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]]
    else:
        span_bdy = [
            align_start[0].tolist()[span_to_repl[0]],
            align_end[0].tolist()[span_to_repl[1] - 1]
        ]
    return span_bdy
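

# A quick numeric check of get_masked_mel_bdy (a sketch with made-up
# timings): two phones at 0.00-0.10 s and 0.10-0.25 s, fs=24000,
# hop_length=300 (80 frames per second), replacing only the second phone.
def _demo_get_masked_mel_bdy():
    span_bdy = get_masked_mel_bdy(
        mfa_start=[0.0, 0.10],
        mfa_end=[0.10, 0.25],
        fs=24000,
        hop_length=300,
        span_to_repl=[1, 2])
    print(span_bdy)  # [8, 20]: floor(0.10 * 80) = 8, floor(0.25 * 80) = 20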
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    dic = {}
    keys_to_del = []
    exist_idx = []
    sp_count = 0
    add_sp_count = 0
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            exist_idx.append(int(idx))
        else:
            keys_to_del.append(key)

    for key in keys_to_del:
        del word2phns[key]

    cur_id = 0
    for key in tp_word2phns.keys():
        if cur_id in exist_idx:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
...@@ -329,14 +317,17 @@ def recover_dict(word2phns, tp_word2phns):
    return dic
def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
    append_new_str = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []

    if source_lang == "english":
        times2, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        times2, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)
...@@ -348,14 +339,14 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        assert source_lang == "chinese" or source_lang == "english", "source_lang is wrong..."

    for item in times2:
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
        old_phns.append(item[0])

    if append_new_str and (source_lang != target_lang):
        is_cross_lingual_clone = True
    else:
        is_cross_lingual_clone = False
...@@ -364,18 +355,21 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
        new_str_origin = new_str[:len(old_str)]
        new_str_append = new_str[len(old_str):]

        if target_lang == "chinese":
            new_phns_origin, new_origin_word2phns = words2phns(new_str_origin)
            new_phns_append, temp_new_append_word2phns = words2phns_zh(
                new_str_append)
        elif target_lang == "english":
            # the original sentence
            new_phns_origin, new_origin_word2phns = words2phns_zh(
                new_str_origin)
            # the cloned sentence
            new_phns_append, temp_new_append_word2phns = words2phns(
                new_str_append)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not support for this language, please check it."

        new_phns = new_phns_origin + new_phns_append
...@@ -390,16 +384,17 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
            new_append_word2phns.items()))
    else:
        if source_lang == target_lang and target_lang == "english":
            new_phns, new_word2phns = words2phns(new_str)
        elif source_lang == target_lang and target_lang == "chinese":
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
            assert source_lang == target_lang, \
                "source language is not same with target language..."

    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0
    # find the left different index
...@@ -411,27 +406,27 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
        else:
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_idx += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = int(list(word2phns.keys())[-1].split('_')[0])
    new_word2phns_max_idx = int(list(new_word2phns.keys())[-1].split('_')[0])
    new_phns_mid = []
    if append_new_str:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
...@@ -439,33 +434,31 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx)
                                                   - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break

    new_phns = new_phns_left + new_phns_mid + new_phns_right
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
def duration_adjust_factor(original_dur: List[int],
                           pred_dur: List[int],
                           phns: List[str]):
    length = 0
    factor_list = []
    for ori, pred, phn in zip(original_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
...@@ -481,242 +474,224 @@ def duration_adjust_factor(original_dur, pred_dur, phns):
    return np.average(factor_list[length:-length])
def prepare_features_with_duration(uid, def prepare_features_with_duration(uid: str,
prefix, prefix: str,
clone_uid, wav_path: str,
clone_prefix, mlm_model: nn.Layer,
source_language, source_lang: str="English",
target_language, target_lang: str="English",
mlm_model, old_str: str="",
old_str, new_str: str="",
new_str, duration_preditor_path: str=None,
wav_path, sid: str=None,
duration_preditor_path, mask_reconstruct: bool=False,
sid=None, duration_adjust: bool=True,
mask_reconstruct=False, start_end_sp: bool=False,
duration_adjust=True,
start_end_sp=False,
train_args=None): train_args=None):
wav_org, rate = librosa.load( wav_org, rate = librosa.load(
wav_path, sr=train_args.feats_extract_conf['fs']) wav_path, sr=train_args.feats_extract_conf['fs'])
fs = train_args.feats_extract_conf['fs'] fs = train_args.feats_extract_conf['fs']
hop_length = train_args.feats_extract_conf['hop_length'] hop_length = train_args.feats_extract_conf['hop_length']
mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans( mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
wav_path, old_str, new_str, source_language, target_language) wav_path=wav_path,
old_str=old_str,
new_str=new_str,
source_lang=source_lang,
target_lang=target_lang)
if start_end_sp: if start_end_sp:
if new_phns[-1] != 'sp': if new_phns[-1] != 'sp':
new_phns = new_phns + ['sp'] new_phns = new_phns + ['sp']
if target_language == "english": if target_lang == "english":
old_durations = evaluate_durations( old_durations = evaluate_durations(old_phns, target_lang=target_lang)
old_phns, target_language=target_language)
elif target_language == "chinese": elif target_lang == "chinese":
if source_language == "english": if source_lang == "english":
old_durations = evaluate_durations( old_durations = evaluate_durations(
old_phns, target_language=source_language) old_phns, target_lang=source_lang)
elif source_language == "chinese": elif source_lang == "chinese":
old_durations = evaluate_durations( old_durations = evaluate_durations(
old_phns, target_language=source_language) old_phns, target_lang=source_lang)
else: else:
assert target_language == "chinese" or target_language == "english", "calculate duration_predict is not support for this language..." assert target_lang == "chinese" or target_lang == "english", "calculate duration_predict is not support for this language..."
original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)] original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)]
if '[MASK]' in new_str: if '[MASK]' in new_str:
new_phns = old_phns new_phns = old_phns
span_tobe_added = span_tobe_replaced span_to_add = span_to_repl
d_factor_left = duration_adjust_factor( d_factor_left = duration_adjust_factor(
original_old_durations[:span_tobe_replaced[0]], original_old_durations[:span_to_repl[0]],
            old_durations[:span_to_repl[0]], old_phns[:span_to_repl[0]])
        d_factor_right = duration_adjust_factor(
            original_old_durations[span_to_repl[1]:],
            old_durations[span_to_repl[1]:], old_phns[span_to_repl[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durations_adjusted = [d_factor * i for i in old_durations]
    else:
        if duration_adjust:
            d_factor = duration_adjust_factor(original_old_durations,
                                              old_durations, old_phns)
            d_factor = d_factor * 1.25
        else:
            d_factor = 1

        if target_lang == "english":
            new_durations = evaluate_durations(
                new_phns, target_lang=target_lang)
        elif target_lang == "chinese":
            new_durations = evaluate_durations(
                new_phns, target_lang=target_lang)

        new_durations_adjusted = [d_factor * i for i in new_durations]

        if span_to_repl[0] < len(old_phns) and old_phns[span_to_repl[
                0]] == new_phns[span_to_add[0]]:
            new_durations_adjusted[span_to_add[0]] = original_old_durations[
                span_to_repl[0]]
        if span_to_repl[1] < len(old_phns) and span_to_add[1] < len(new_phns):
            if old_phns[span_to_repl[1]] == new_phns[span_to_add[1]]:
                new_durations_adjusted[span_to_add[
                    1]] = original_old_durations[span_to_repl[1]]

    new_span_duration_sum = sum(
        new_durations_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_duration_sum = sum(
        original_old_durations[span_to_repl[0]:span_to_repl[1]])
    duration_offset = new_span_duration_sum - old_span_duration_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    for i in new_durations_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [i + duration_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + duration_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
    else:
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
    new_blank_wav = np.zeros(
        (int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype)
    new_wav_org = np.concatenate(
        [wav_org[:left_idx], new_blank_wav, wav_org[right_idx:]])

    # 4. get old and new mel span to be mask
    # [92, 92]
    old_span_bdy = get_masked_mel_bdy(mfa_start, mfa_end, fs, hop_length,
                                      span_to_repl)
    # [92, 174]
    new_span_bdy = get_masked_mel_bdy(new_mfa_start, new_mfa_end, fs,
                                      hop_length, span_to_add)

    return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy
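
The duration bookkeeping above is dense, so here is a minimal, self-contained sketch of the same arithmetic with invented numbers: splice the regenerated span's durations into the old phone timeline, then shift the tail by the resulting offset. Every value below is illustrative, not taken from the model.

# Hypothetical per-phone durations in seconds.
old_durs = [0.10, 0.12, 0.08, 0.11]        # original utterance
new_durs = [0.10, 0.20, 0.15, 0.09, 0.11]  # after the text edit
span_to_repl = (1, 3)  # old phones 1..2 are replaced
span_to_add = (1, 4)   # by new phones 1..3

new_span = sum(new_durs[span_to_add[0]:span_to_add[1]])
old_span = sum(old_durs[span_to_repl[0]:span_to_repl[1]])
duration_offset = new_span - old_span  # how far the tail moves

# Rebuild start/end times: untouched prefix, regenerated span, shifted tail.
durs = (old_durs[:span_to_repl[0]] +
        new_durs[span_to_add[0]:span_to_add[1]] +
        old_durs[span_to_repl[1]:])
starts, ends, t = [], [], 0.0
for d in durs:
    starts.append(t)
    ends.append(t + d)
    t += d

print(round(duration_offset, 3))  # 0.24
print(round(ends[-1], 3))         # 0.65 = old total 0.41 + offset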

def prepare_features(uid: str,
                     mlm_model: nn.Layer,
                     processor,
                     wav_path: str,
                     prefix: str="./prompt/dev/",
                     source_lang: str="english",
                     target_lang: str="english",
                     old_str: str="",
                     new_str: str="",
                     duration_preditor_path: str=None,
                     sid: str=None,
                     duration_adjust: bool=True,
                     start_end_sp: bool=False,
                     mask_reconstruct: bool=False,
                     train_args=None):
    wav_org, phns_list, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prepare_features_with_duration(
        uid=uid,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
        duration_preditor_path=duration_preditor_path,
        sid=sid,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        train_args=train_args)
    speech = wav_org
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    token_to_id = {item: i for i, item in enumerate(train_args.token_list)}
    text = np.array(
        list(
            map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns_list)))
    span_bdy = np.array(new_span_bdy)

    batch = [('1', {
        "speech": speech,
        "align_start": align_start,
        "align_end": align_end,
        "text": text,
        "span_bdy": span_bdy
    })]

    return batch, old_span_bdy, new_span_bdy
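
One detail worth isolating from `prepare_features`: out-of-vocabulary phones are mapped to the `<unk>` id rather than raising. A tiny sketch with a made-up token list:

import numpy as np

token_list = ['<blank>', '<unk>', 'AH0', 'T', 'sp']  # hypothetical vocabulary
token_to_id = {tok: i for i, tok in enumerate(token_list)}

phns = ['T', 'AH0', 'ZZZ']  # 'ZZZ' is not in the vocabulary
text = np.array([token_to_id.get(p, token_to_id['<unk>']) for p in phns])
print(text)  # [3 2 1] -- the unknown phone falls back to <unk>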

def decode_with_model(uid: str,
                      mlm_model: nn.Layer,
                      processor,
                      collate_fn,
                      wav_path: str,
                      prefix: str="./prompt/dev/",
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      duration_preditor_path: str=None,
                      sid: str=None,
                      decoder: bool=False,
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
                      train_args=None):
    fs, hop_length = train_args.feats_extract_conf[
        'fs'], train_args.feats_extract_conf['hop_length']

    batch, old_span_bdy, new_span_bdy = prepare_features(
        uid=uid,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        processor=processor,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_preditor_path=duration_preditor_path,
        sid=sid,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        train_args=train_args)

    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    for k, v in feats.items():
        feats[k] = paddle.to_tensor(v)
    rtn = mlm_model.inference(
        **feats, span_bdy=new_span_bdy, use_teacher_forcing=use_teacher_forcing)
    output = rtn['feat_gen']
    if 0 in output[0].shape and 0 not in output[-1].shape:
        output_feat = paddle.concat(
@@ -731,12 +706,9 @@ def decode_with_model(uid,
            [output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)],
            axis=0).cpu()

    wav_org, _ = librosa.load(
        wav_path, sr=train_args.feats_extract_conf['fs'])
    return wav_org, None, output_feat, old_span_bdy, new_span_bdy, fs, hop_length

class MLMCollateFn:
@@ -800,33 +772,15 @@ def mlm_collate_fn(
        sega_emb: bool=False,
        duration_collect: bool=False,
        text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(not k.endswith("_lens")
               for k in data[0]), f"*_lens is reserved: {list(data[0])}"
    output = {}
    for key in data[0]:
        # Each models, which accepts these values finally, are responsible
        # to repaint the pad_value to the desired value for each tasks.
        if data[0][key].dtype.kind == "i":
@@ -846,37 +800,35 @@ def mlm_collate_fn(
        # lens: (Batch,)
        if key not in not_sequence:
            lens = paddle.to_tensor(
                [d[key].shape[0] for d in data], dtype=paddle.int64)
            output[key + "_lens"] = lens

    feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
    feats = paddle.to_tensor(feats)
    feats_lens = paddle.shape(feats)[0]
    feats = paddle.unsqueeze(feats, 0)

    if 'text' not in output:
        text = paddle.zeros(paddle.shape(feats_lens.unsqueeze(-1))) - 2
        text_lens = paddle.zeros(paddle.shape(feats_lens)) + 1
        max_tlen = 1
        align_start = paddle.zeros(paddle.shape(text))
        align_end = paddle.zeros(paddle.shape(text))
        align_start_lens = paddle.zeros(paddle.shape(feats_lens))
        sega_emb = False
        mean_phn_span = 0
        mlm_prob = 0.15
    else:
        text = output["text"]
        text_lens = output["text_lens"]
        align_start = output["align_start"]
        align_start_lens = output["align_start_lens"]
        align_end = output["align_end"]
        align_start = paddle.floor(feats_extract.sr * align_start /
                                   feats_extract.hop_length).int()
        align_end = paddle.floor(feats_extract.sr * align_end /
                                 feats_extract.hop_length).int()
        max_tlen = max(text_lens)
    max_slen = max(feats_lens)
    speech_pad = feats[:, :max_slen]
    if attention_window > 0 and pad_speech:
        speech_pad, max_slen = pad_to_longformer_att_window(
@@ -888,51 +840,49 @@ def mlm_collate_fn(
    else:
        text_pad = text
    text_mask = make_non_pad_mask(
        text_lens, text_pad, length_dim=1).unsqueeze(-2)
    if attention_window > 0:
        text_mask = text_mask * 2
    speech_mask = make_non_pad_mask(
        feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
    span_bdy = None
    if 'span_bdy' in output.keys():
        span_bdy = output['span_bdy']

    if text_masking:
        masked_pos, text_masked_pos, _ = phones_text_masking(
            speech_pad, speech_mask, text_pad, text_mask, align_start,
            align_end, align_start_lens, mlm_prob, mean_phn_span, span_bdy)
    else:
        text_masked_pos = paddle.zeros(paddle.shape(text_pad))
        masked_pos, _ = phones_masking(speech_pad, speech_mask, align_start,
                                       align_end, align_start_lens, mlm_prob,
                                       mean_phn_span, span_bdy)

    output_dict = {}
    if duration_collect and 'text' in output:
        reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens = get_seg_pos_reduce_duration(
            speech_pad, text_pad, align_start, align_end, align_start_lens,
            sega_emb, masked_pos, feats_lens)
        speech_mask = make_non_pad_mask(
            feats_lens, speech_pad[:, :reordered_idx.shape[1], 0],
            length_dim=1).unsqueeze(-2)
        output_dict['durations'] = durations
        output_dict['reordered_idx'] = reordered_idx
    else:
        speech_seg_pos, text_seg_pos = get_seg_pos(speech_pad, text_pad,
                                                   align_start, align_end,
                                                   align_start_lens, sega_emb)
    output_dict['speech'] = speech_pad
    output_dict['text'] = text_pad
    output_dict['masked_pos'] = masked_pos
    output_dict['text_masked_pos'] = text_masked_pos
    output_dict['speech_mask'] = speech_mask
    output_dict['text_mask'] = text_mask
    output_dict['speech_seg_pos'] = speech_seg_pos
    output_dict['text_seg_pos'] = text_seg_pos
    output_dict['speech_lens'] = output["speech_lens"]
    output_dict['text_lens'] = text_lens
    output = (uttids, output_dict)
    return output
@@ -940,13 +890,13 @@ def mlm_collate_fn(


def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
    # -> Callable[
    #     [Collection[Tuple[str, Dict[str, np.ndarray]]]],
    #     Tuple[List[str], Dict[str, Tensor]],
    # ]:
    # assert check_argument_types()
    # return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
    feats_extract_class = LogMelFBank

    if args.feats_extract_conf['win_length'] is None:
        args.feats_extract_conf['win_length'] = args.feats_extract_conf['n_fft']
    args_dic = {}
@@ -955,7 +905,6 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
            args_dic['sr'] = v
        else:
            args_dic[k] = v
    feats_extract = feats_extract_class(**args_dic)

    sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False
@@ -969,8 +918,7 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
    if epoch == -1:
        mlm_prob_factor = 1
    else:
        mlm_prob_factor = 0.8
    if 'duration_predictor_layers' in args.model_conf.keys(
    ) and args.model_conf['duration_predictor_layers'] > 0:
        duration_collect = True
@@ -989,42 +937,37 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
        duration_collect=duration_collect)

def get_mlm_output(uid: str,
                   wav_path: str,
                   prefix: str="./prompt/dev/",
                   model_name: str="conformer",
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   duration_preditor_path: str=None,
                   sid: str=None,
                   decoder: bool=False,
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
    mlm_model, train_args = load_model(model_name)
    mlm_model.eval()
    processor = None
    collate_fn = build_collate_fn(train_args, False)

    return decode_with_model(
        uid=uid,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        processor=processor,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_preditor_path=duration_preditor_path,
        sid=sid,
        decoder=decoder,
        use_teacher_forcing=use_teacher_forcing,
@@ -1033,23 +976,20 @@ def get_mlm_output(uid,
        train_args=train_args)

def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             use_pt_vocoder: bool=False,
             prefix: str="./prompt/dev/",
             model_name: str="conformer",
             old_str: str="",
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):
    duration_preditor_path = None
    spemd = None
    full_origin_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'edit':
        new_str = new_str
@@ -1065,19 +1005,17 @@ def evaluate(uid,
        old_str = full_origin_str

    results_dict, old_span = plot_mel_and_vocode_wav(
        uid=uid,
        prefix=prefix,
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        full_origin_str=full_origin_str,
        old_str=old_str,
        new_str=new_str,
        use_pt_vocoder=use_pt_vocoder,
        duration_preditor_path=duration_preditor_path,
        sid=spemd)
    return results_dict
@@ -1086,17 +1024,14 @@ if __name__ == "__main__":
    # parse config and args
    args = parse_args()

    data_dict = evaluate(
        uid=args.uid,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        use_pt_vocoder=args.use_pt_vocoder,
        prefix=args.prefix,
        model_name=args.model_name,
        new_str=args.new_str,
        task_name=args.task_name)
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
    print("finished...")
@@ -121,12 +121,10 @@ class NewMaskInputLayer(nn.Layer):
            default_initializer=paddle.nn.initializer.Assign(
                paddle.normal(shape=(1, 1, out_features))))

    def forward(self, input: paddle.Tensor, masked_pos=None) -> paddle.Tensor:
        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
        return masked_input
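
For intuition on `NewMaskInputLayer.forward`: frames flagged in `masked_pos` are zeroed and replaced by the broadcast learnable mask vector, which is a where-style select. A numpy sketch of the same idea (shapes and values are illustrative, not the module's API):

import numpy as np

T, D = 4, 3
x = np.arange(T * D, dtype=np.float32).reshape(T, D)    # fake frame features
mask_feature = np.full((1, D), -1.0, dtype=np.float32)  # stands in for the learned vector
masked_pos = np.array([False, True, True, False])       # frames to mask

# Keep unmasked frames, substitute the mask vector elsewhere.
masked_input = np.where(masked_pos[:, None], mask_feature, x)
print(masked_input[1])  # [-1. -1. -1.]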
@@ -443,37 +441,34 @@ class MLMEncoder(nn.Layer):

    def forward(self,
                speech_pad,
                text_pad,
                masked_pos,
                speech_mask=None,
                text_mask=None,
                speech_seg_pos=None,
                text_seg_pos=None):
        """Encode input sequence.

        """
        if masked_pos is not None:
            speech_pad = self.speech_embed(speech_pad, masked_pos)
        else:
            speech_pad = self.speech_embed(speech_pad)
        # pure speech input
        if -2 in np.array(text_pad):
            text_pad = text_pad + 3
            text_mask = paddle.unsqueeze(bool(text_pad), 1)
            text_seg_pos = paddle.zeros_like(text_pad)
            text_pad = self.text_embed(text_pad)
            text_pad = (text_pad[0] + self.segment_emb(text_seg_pos),
                        text_pad[1])
            text_seg_pos = None
        elif text_pad is not None:
            text_pad = self.text_embed(text_pad)

        if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
            speech_seg_emb = self.segment_emb(speech_seg_pos)
            text_seg_emb = self.segment_emb(text_seg_pos)
            text_pad = (text_pad[0] + text_seg_emb, text_pad[1])
            speech_pad = (speech_pad[0] + speech_seg_emb, speech_pad[1])

        if self.pre_speech_encoders:
            speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask)
@@ -493,11 +488,11 @@ class MLMEncoder(nn.Layer):

        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks


class MLMDecoder(MLMEncoder):

    def forward(self, xs, masks, masked_pos=None, segment_emb=None):
        """Encode input sequence.

        Args:
@@ -509,9 +504,8 @@ class MLMDecoder(MLMEncoder):
            paddle.Tensor: Mask tensor (#batch, time).

        """
        if not self.training:
            masked_pos = None
        xs = self.embed(xs)
        if segment_emb:
            xs = (xs[0] + segment_emb, xs[1])
@@ -632,18 +626,18 @@ class MLMModel(nn.Layer):

    def collect_feats(self,
                      speech,
                      speech_lens,
                      text,
                      text_lens,
                      masked_pos,
                      speech_mask,
                      text_mask,
                      speech_seg_pos,
                      text_seg_pos,
                      y_masks=None) -> Dict[str, paddle.Tensor]:
        return {"feats": speech, "feats_lens": speech_lens}

    def forward(self, batch, speech_seg_pos, y_masks=None):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        speech_pad_placeholder = batch['speech_pad']
@@ -654,7 +648,7 @@ class MLMModel(nn.Layer):
        if self.decoder is not None:
            zs, _ = self.decoder(ys_in, y_masks, encoder_out,
                                 bool(h_masks),
                                 self.encoder.segment_emb(speech_seg_pos))
            speech_hidden_states = zs
        else:
            speech_hidden_states = encoder_out[:, :paddle.shape(batch[
@@ -672,21 +666,21 @@ class MLMModel(nn.Layer):
        else:
            after_outs = None
        return before_outs, after_outs, speech_pad_placeholder, batch[
            'masked_pos']

    def inference(
            self,
            speech,
            text,
            masked_pos,
            speech_mask,
            text_mask,
            speech_seg_pos,
            text_seg_pos,
            span_bdy,
            y_masks=None,
            speech_lens=None,
            text_lens=None,
            feats: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            sids: Optional[paddle.Tensor]=None,
@@ -699,24 +693,24 @@ class MLMModel(nn.Layer):
        batch = dict(
            speech_pad=speech,
            text_pad=text,
            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_seg_pos=speech_seg_pos,
            text_seg_pos=text_seg_pos, )

        # # inference with teacher forcing
        # hs, h_masks = self.encoder(**batch)

        outs = [batch['speech_pad'][:, :span_bdy[0]]]
        z_cache = None
        if use_teacher_forcing:
            before, zs, _, _ = self.forward(
                batch, speech_seg_pos, y_masks=y_masks)
            if zs is None:
                zs = before

            outs += [zs[0][span_bdy[0]:span_bdy[1]]]
            outs += [batch['speech_pad'][:, span_bdy[1]:]]
            return dict(feat_gen=outs)
        return None
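
The `outs` list built in `inference` stitches the untouched prefix, the regenerated span, and the untouched suffix back into one mel sequence. The same splice in plain numpy, with toy shapes (purely illustrative):

import numpy as np

mel = np.random.rand(10, 80).astype(np.float32)  # (frames, mel bins)
gen = np.random.rand(3, 80).astype(np.float32)   # re-predicted frames for the span
span_bdy = (4, 7)  # frames 4..6 were masked and regenerated

stitched = np.concatenate([mel[:span_bdy[0]], gen, mel[span_bdy[1]:]], axis=0)
print(stitched.shape)  # (10, 80): prefix 4 + generated 3 + suffix 3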
@@ -733,7 +727,7 @@ class MLMModel(nn.Layer):


class MLMEncAsDecoderModel(MLMModel):

    def forward(self, batch, speech_seg_pos, y_masks=None):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        speech_pad_placeholder = batch['speech_pad']
@@ -756,7 +750,7 @@ class MLMEncAsDecoderModel(MLMModel):
        else:
            after_outs = None
        return before_outs, after_outs, speech_pad_placeholder, batch[
            'masked_pos']

class MLMDualMaksingModel(MLMModel):
@@ -767,9 +761,9 @@ class MLMDualMaksingModel(MLMModel):
                       batch):
        xs_pad = batch['speech_pad']
        text_pad = batch['text_pad']
        masked_pos = batch['masked_pos']
        text_masked_pos = batch['text_masked_pos']
        mlm_loss_pos = masked_pos > 0
        loss = paddle.sum(
            self.l1_loss_func(
                paddle.reshape(before_outs, (-1, self.odim)),
@@ -782,19 +776,17 @@ class MLMDualMaksingModel(MLMModel):
                    paddle.reshape(xs_pad, (-1, self.odim))),
            axis=-1)
        loss_mlm = paddle.sum((loss * paddle.reshape(
            mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
        loss_text = paddle.sum((self.text_mlm_loss(
            paddle.reshape(text_outs, (-1, self.vocab_size)),
            paddle.reshape(text_pad, (-1))) * paddle.reshape(
                text_masked_pos, (-1)))) / paddle.sum(
                    (text_masked_pos) + 1e-10)
        return loss_mlm, loss_text

    def forward(self, batch, speech_seg_pos, y_masks=None):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, h_masks = self.encoder(**batch)  # segment_emb
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
@@ -819,7 +811,7 @@ class MLMDualMaksingModel(MLMModel):
                    [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, text_outs, None  #, speech_pad_placeholder, batch['masked_pos'],batch['text_masked_pos']

def build_model_from_file(config_file, model_file):
@@ -38,7 +38,7 @@ def pad_list(xs, pad_value):
    """
    n_batch = len(xs)
    max_len = max(x.shape[0] for x in xs)
    pad = paddle.full([n_batch, max_len, *xs[0].shape[1:]],
                      pad_value,
                      dtype=xs[0].dtype)
    for i in range(n_batch):
        pad[i, :xs[i].shape[0]] = xs[i]
@@ -46,13 +46,18 @@ def pad_list(xs, pad_value):
    return pad
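
A quick usage sketch for `pad_list` (the shapes are arbitrary): two variable-length tensors become one `[batch, max_len, ...]` tensor, padded with the given value at the tail.

import paddle

xs = [paddle.ones([2, 3]), paddle.ones([4, 3])]
padded = pad_list(xs, 0.0)
print(padded.shape)   # [2, 4, 3]
print(padded[0, 2:])  # the shorter item is zero-padded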

def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (Tensor(int64)): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        Tensor(bool): Mask tensor containing indices of padded part bool.

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = paddle.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]])
        >>> xs = paddle.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]])

        With the reference tensor and dimension indicator.

        >>> xs = paddle.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]])
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]])

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    bs = paddle.shape(lengths)[0]
    if xs is None:
        maxlen = lengths.max()
    else:
        maxlen = paddle.shape(xs)[length_dim]

    seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)

        if length_dim < 0:
            length_dim = len(paddle.shape(xs)) + length_dim
        # ind = (:, None, ..., None, :, , None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None
            for i in range(len(paddle.shape(xs))))
        mask = paddle.expand(mask[ind], paddle.shape(xs))
    return mask
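
The core of `make_pad_mask` is a single broadcast comparison: a `[B, maxlen]` grid of positions compared against per-item lengths; everything after that only reshapes the result onto `xs`. In numpy terms:

import numpy as np

lengths = np.array([5, 3, 2])
seq_range = np.arange(lengths.max())           # [0 1 2 3 4]
mask = seq_range[None, :] >= lengths[:, None]  # (3, 5) bool
print(mask.astype(int))
# [[0 0 0 0 0]
#  [0 0 0 1 1]
#  [0 0 1 1 1]]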

def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
@@ -90,16 +170,78 @@ def make_non_pad_mask(lengths, xs=None, length_dim=-1):

    Returns:
        Tensor(bool): mask tensor containing indices of non-padded part bool.

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = paddle.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]])
        >>> xs = paddle.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]])

        With the reference tensor and dimension indicator.

        >>> xs = paddle.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]])
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]])

    """
    return paddle.logical_not(make_pad_mask(lengths, xs, length_dim))

def initialize(model: nn.Layer, init: str):
@@ -10,8 +10,8 @@ python inference.py \
    --uid=Prompt_003_new \
    --new_str='今天天气很好.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=chinese \
    --output_name=pred_clone.wav \
    --use_pt_vocoder=False \
    --voc=pwgan_aishell3 \
@@ -9,8 +9,8 @@ python inference.py \
    --uid=p299_096 \
    --new_str='I enjoy my life, do you?' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_gen.wav \
    --use_pt_vocoder=False \
    --voc=pwgan_aishell3 \
@@ -10,8 +10,8 @@ python inference.py \
    --uid=p243_new \
    --new_str='for that reason cover is impossible to be given.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_edit.wav \
    --use_pt_vocoder=False \
    --voc=pwgan_aishell3 \
@@ -80,10 +80,8 @@ def parse_args():
    parser.add_argument("--uid", type=str, help="uid")
    parser.add_argument("--new_str", type=str, help="new string")
    parser.add_argument("--prefix", type=str, help="prefix")
    parser.add_argument(
        "--source_lang", type=str, default="english", help="source language")
    parser.add_argument(
        "--target_lang", type=str, default="english", help="target language")
    parser.add_argument("--output_name", type=str, help="output name")
    parser.add_argument("--task_name", type=str, help="task name")
    parser.add_argument(
@@ -9,7 +9,7 @@ import torch
import yaml


class TorchPWGAN(torch.nn.Module):
    """Wrapper class to load the vocoder trained with parallel_wavegan repo."""

    def __init__(
import os
from typing import List
from typing import Optional

import numpy as np
import paddle
import yaml
@@ -5,11 +9,8 @@ from sedit_arg_parser import parse_args
from yacs.config import CfgNode

from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.modules.normalizer import ZScore
from tools.torch_pwgan import TorchPWGAN

model_alias = {
    # acoustic model
@@ -25,6 +26,10 @@ model_alias = {
    "paddlespeech.t2s.models.tacotron2:Tacotron2",
    "tacotron2_inference":
    "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
    "pwgan":
    "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
    "pwgan_inference":
    "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
}
@@ -43,60 +48,65 @@ def build_vocoder_from_file(
    # Build vocoder
    if str(vocoder_file).endswith(".pkl"):
        # If the extension is ".pkl", the model is trained with parallel_wavegan
        vocoder = TorchPWGAN(vocoder_file, vocoder_config_file)
        return vocoder.to(device)
    else:
        raise ValueError(f"{vocoder_file} is not supported format.")

def get_voc_out(mel, target_lang: str="chinese"):
    # vocoder
    args = parse_args()
    assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..."
    # print("current vocoder: ", args.voc)
    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))
    voc_inference = get_voc_inference(
        voc=args.voc,
        voc_config=voc_config,
        voc_ckpt=args.voc_ckpt,
        voc_stat=args.voc_stat)

    mel = paddle.to_tensor(mel)
    with paddle.no_grad():
        wav = voc_inference(mel)
    return np.squeeze(wav)
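
A hedged usage sketch for `get_voc_out`: any mel array with the vocoder's expected number of bins works; the shape below is invented, and the vocoder config/checkpoint must already be supplied through the CLI arguments that `parse_args` reads.

import numpy as np

mel = np.random.rand(100, 80).astype(np.float32)  # 100 frames, 80 bins, illustrative
wav = get_voc_out(mel, target_lang="chinese")
print(wav.shape)  # roughly n_frames * hop_length samples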

# dygraph
def get_am_inference(am: str='fastspeech2_csmsc',
                     am_config: CfgNode=None,
                     am_ckpt: Optional[os.PathLike]=None,
                     am_stat: Optional[os.PathLike]=None,
                     phones_dict: Optional[os.PathLike]=None,
                     tones_dict: Optional[os.PathLike]=None,
                     speaker_dict: Optional[os.PathLike]=None,
                     return_am: bool=False):
    with open(phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
    print("vocab_size:", vocab_size)

    tone_size = None
    if tones_dict is not None:
        with open(tones_dict, "r") as f:
            tone_id = [line.strip().split() for line in f.readlines()]
        tone_size = len(tone_id)
        print("tone_size:", tone_size)

    spk_num = None
    if speaker_dict is not None:
        with open(speaker_dict, 'rt') as f:
            spk_id = [line.strip().split() for line in f.readlines()]
        spk_num = len(spk_id)
        print("spk_num:", spk_num)

    odim = am_config.n_mels
    # model: {model_name}_{dataset}
    am_name = am[:am.rindex('_')]
    am_dataset = am[am.rindex('_') + 1:]

    am_class = dynamic_import(am_name, model_alias)
    am_inference_class = dynamic_import(am_name + '_inference', model_alias)
@@ -113,39 +123,61 @@ def get_am_inference(args, am_config):
    elif am_name == 'tacotron2':
        am = am_class(idim=vocab_size, odim=odim, **am_config["model"])

    am.set_state_dict(paddle.load(am_ckpt)["main_params"])
    am.eval()
    am_mu, am_std = np.load(am_stat)
    am_mu = paddle.to_tensor(am_mu)
    am_std = paddle.to_tensor(am_std)
    am_normalizer = ZScore(am_mu, am_std)
    am_inference = am_inference_class(am_normalizer, am)
    am_inference.eval()
    print("acoustic model done!")
    if return_am:
        return am_inference, am
    else:
        return am_inference

def get_voc_inference(
        voc: str='pwgan_csmsc',
        voc_config: Optional[os.PathLike]=None,
        voc_ckpt: Optional[os.PathLike]=None,
        voc_stat: Optional[os.PathLike]=None, ):
    # model: {model_name}_{dataset}
    voc_name = voc[:voc.rindex('_')]
    voc_class = dynamic_import(voc_name, model_alias)
    voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
    if voc_name != 'wavernn':
        voc = voc_class(**voc_config["generator_params"])
        voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"])
        voc.remove_weight_norm()
        voc.eval()
    else:
        voc = voc_class(**voc_config["model"])
        voc.set_state_dict(paddle.load(voc_ckpt)["main_params"])
        voc.eval()
    voc_mu, voc_std = np.load(voc_stat)
    voc_mu = paddle.to_tensor(voc_mu)
    voc_std = paddle.to_tensor(voc_std)
    voc_normalizer = ZScore(voc_mu, voc_std)
    voc_inference = voc_inference_class(voc_normalizer, voc)
    voc_inference.eval()
    print("voc done!")
    return voc_inference
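
Usage sketch for `get_voc_inference`; the checkpoint layout mirrors the PaddleSpeech download bundles, but the exact paths here are assumptions, so substitute whatever you actually downloaded.

import yaml
from yacs.config import CfgNode

# Assumed paths; adjust to your local download directory.
with open("download/pwg_aishell3_ckpt_0.5/default.yaml") as f:
    voc_config = CfgNode(yaml.safe_load(f))
voc_inference = get_voc_inference(
    voc="pwgan_aishell3",
    voc_config=voc_config,
    voc_ckpt="download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz",
    voc_stat="download/pwg_aishell3_ckpt_0.5/feats_stats.npy")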

def evaluate_durations(phns: List[str],
                       target_lang: str="chinese",
                       fs: int=24000,
                       hop_length: int=300):
    args = parse_args()

    if target_lang == 'english':
        args.lang = 'en'
        args.am = "fastspeech2_ljspeech"
        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"

    elif target_lang == 'chinese':
        args.lang = 'zh'
        args.am = "fastspeech2_csmsc"
        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"

    # args = parser.parse_args(args=[])
    if args.ngpu == 0:
@@ -155,23 +187,28 @@ def evaluate_durations(phns,
    else:
        print("ngpu should >= 0 !")

    assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..."

    # Init body.
    with open(args.am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))

    am_inference, am = get_am_inference(
        am=args.am,
        am_config=am_config,
        am_ckpt=args.am_ckpt,
        am_stat=args.am_stat,
        phones_dict=args.phones_dict,
        tones_dict=args.tones_dict,
        speaker_dict=args.speaker_dict,
        return_am=True)

    torch_phns = phns
    vocab_phones = {}
    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    for tone, id in phn_id:
        vocab_phones[tone] = int(id)
    vocab_size = len(vocab_phones)
    phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]
@@ -185,59 +222,3 @@ def evaluate_durations(phns,
    phoneme_durations_new = pre_d_outs * hop_length / fs
    phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
    return phoneme_durations_new
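
And a sketch of calling `evaluate_durations` directly; the phone sequence is made up, and phones missing from the vocabulary fall back to "sp":

# Hypothetical English phones (ARPABET-style); output is in seconds per phone.
phns = ['sp', 'HH', 'AH0', 'L', 'OW1', 'sp']
durs = evaluate_durations(phns, target_lang="english", fs=24000, hop_length=300)
print(durs[:3])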

def sentence2phns(sentence, target_language="en"):
    args = parse_args()
    if target_language == 'en':
        args.lang = 'en'
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
    elif target_language == 'zh':
        args.lang = 'zh'
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
    else:
        print("target_language should in {'zh', 'en'}!")

    frontend = get_frontend(args)
    merge_sentences = True
    get_tone_ids = False

    if target_language == 'zh':
        input_ids = frontend.get_input_ids(
            sentence,
            merge_sentences=merge_sentences,
            get_tone_ids=get_tone_ids,
            print_info=False)
        phone_ids = input_ids["phone_ids"]
        phonemes = frontend.get_phonemes(
            sentence, merge_sentences=merge_sentences, print_info=False)
        return phonemes[0], input_ids["phone_ids"][0]

    elif target_language == 'en':
        phonemes = frontend.phoneticize(sentence)
        input_ids = frontend.get_input_ids(
            sentence, merge_sentences=merge_sentences)
        phone_ids = input_ids["phone_ids"]
        phones_list = []
        vocab_phones = {}
        punc = ":,;。?!“”‘’':,;.?!"
        with open(args.phones_dict, 'rt') as f:
            phn_id = [line.strip().split() for line in f.readlines()]
        for phn, id in phn_id:
            vocab_phones[phn] = int(id)
        phones = phonemes[1:-1]
        phones = [phn for phn in phones if not phn.isspace()]
        # replace unk phone with sp
        phones = [
            phn if (phn in vocab_phones and phn not in punc) else "sp"
            for phn in phones
        ]
        phones_list.append(phones)
        return phones_list[0], input_ids["phone_ids"][0]
    else:
        print("lang should in {'zh', 'en'}!")