Unverified commit b81832ce authored by Kennycao123, committed by GitHub

Merge pull request #825 from yt605155624/format

[ernie sat] Format ernie sat
......@@ -113,8 +113,8 @@ prompt/dev
8. ` --uid` id of the specific prompt utterance
9. ` --new_str` the input text (this release temporarily uses a fixed text)
10. ` --prefix` path of the text and phoneme files for the prompt audio
11. ` --source_language` , source language
12. ` --target_language` , target language
11. ` --source_lang` , source language
12. ` --target_lang` , target language
13. ` --output_name` , name of the synthesized audio
14. ` --task_name` , task name, one of: speech editing, personalized speech synthesis, cross-lingual speech synthesis
15. ` --use_pt_vocoder`, whether to use the torch version of the vocoder for English; defaults to False, meaning the paddle vocoder is used for English (a sample invocation follows this list)
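
A hypothetical invocation combining the flags above (the script name and all values are illustrative, not taken from this PR):

```python
import subprocess

# Sketch only: substitute the actual recipe's entry script and flag values.
subprocess.run([
    "python", "inference.py",
    "--uid", "p299_096",          # prompt utterance id (illustrative)
    "--prefix", "./prompt/dev/",  # text/phoneme files for the prompt audio
    "--source_lang", "english",
    "--target_lang", "english",
    "--new_str", "I love you very much.",
    "--task_name", "synthesize",  # hypothetical value for personalized synthesis
    "--output_name", "pred.wav",
], check=True)
```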
......
#!/usr/bin/env python
""" Usage:
align_english.py wavfile trsfile outwordfile outphonefile
align.py wavfile trsfile outwordfile outphonefile
"""
import multiprocessing as mp
import os
......@@ -9,12 +9,45 @@ import sys
from tqdm import tqdm
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR = 'tools/aligner/english'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'
def prep_txt(line, tmpbase, dictfile):
def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
words = []
line = line.strip()
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
with open(dictfile, 'r') as fid:
for line in fid:
ds.add(line.split()[0])
unk_words = set([])
with open(tmpbase + '.txt', 'w') as fwid:
for wrd in words:
if (wrd not in ds):
unk_words.add(wrd)
fwid.write(wrd + ' ')
fwid.write('\n')
return unk_words
def prep_txt_en(line: str, tmpbase, dictfile):
words = []
......@@ -97,7 +130,7 @@ def prep_txt(line, tmpbase, dictfile):
fw.close()
def prep_mlf(txt, tmpbase):
def prep_mlf(txt: str, tmpbase: str):
with open(tmpbase + '.mlf', 'w') as fwid:
fwid.write('#!MLF!#\n')
......@@ -110,7 +143,55 @@ def prep_mlf(txt, tmpbase):
fwid.write('.\n')
def gen_res(tmpbase, outfile1, outfile2):
def _get_user():
return os.path.expanduser('~').split("/")[-1]
def alignment(wav_path: str, text: str):
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
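# Pipeline: resample to 16 kHz with sox -> clean the transcript -> write HTK
# MLF labels -> extract PLP features with HCopy -> force-align with HVite ->
# parse the time-aligned output.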
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
except:
print('sox error!')
return None
#prepare clean_transcript file
try:
prep_txt_en(text, tmpbase, MODEL_DIR_EN + '/dict')
except:
print('prep_txt error!')
return None
#prepare mlf file
try:
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except:
print('prep_mlf error!')
return None
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR_EN + '/16000/macros -H ' + MODEL_DIR_EN
+ '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
'.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
print('HVite error!')
return None
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
words = txt.strip().split()
......@@ -119,59 +200,47 @@ def gen_res(tmpbase, outfile1, outfile2):
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
times1 = []
times2 = []
word2phns = {}
current_word = ''
index = 0
while (i < len(lines)):
if (len(lines[i].split()) >= 4) and (
lines[i].split()[0] != lines[i].split()[1]):
phn = lines[i].split()[2]
pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
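# HVite reports times in 100 ns units: t / 1e7 yields seconds, and the +125
# term (i.e. +12.5 ms) appears to compensate for half of a 25 ms analysis window.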
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
if (len(lines[i].split()) == 5):
if (lines[i].split()[0] != lines[i].split()[1]):
wrd = lines[i].split()[-1].strip()
st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
j = i + 1
while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
j += 1
en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
times1.append([wrd, st, en])
# splited_line[-1]!='sp'
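# word2phns maps "index_word" -> space-joined phones (e.g. "0_hello" -> "hh ah l ow",
# illustrative); the running index keeps repeated words distinct.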
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
with open(outfile1, 'w') as fwid:
for item in times1:
if (item[0] == 'sp'):
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' SIL\n')
else:
wrd = words.pop()
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + wrd + '\n')
if words:
print('not matched: ' + tmpbase + '.aligned')
sys.exit(1)
with open(outfile2, 'w') as fwid:
for item in times2:
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
def _get_user():
return os.path.expanduser('~').split("/")[-1]
return times2, word2phns
def alignment(wav_path, text_string):
def alignment_zh(wav_path, text_string):
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
'.wav remix -')
except:
print('sox error!')
return None
#prepare clean_transcript file
try:
prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
unk_words = prep_txt_zh(text_string, tmpbase, MODEL_DIR_ZH + '/dict')
if unk_words:
print('Error! Please add the following words to the dictionary:')
for unk in unk_words:
print("invalid word: ", unk)
except:
print('prep_txt error!')
return None
......@@ -187,7 +256,7 @@ def alignment(wav_path, text_string):
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
print('HCopy error!')
......@@ -196,10 +265,11 @@ def alignment(wav_path, text_string):
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
'/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
'.dict ' + MODEL_DIR + '/monophones ' + tmpbase +
'.mlf -H ' + MODEL_DIR_ZH + '/16000/macros -H ' + MODEL_DIR_ZH
+ '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR_ZH
+ '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
print('HVite error!')
return None
......@@ -211,6 +281,7 @@ def alignment(wav_path, text_string):
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
times2 = []
word2phns = {}
......
#!/usr/bin/env python
""" Usage:
align_mandarin.py wavfile trsfile outwordfile putphonefile
"""
import multiprocessing as mp
import os
import sys
from tqdm import tqdm
MODEL_DIR = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'
def prep_txt(line, tmpbase, dictfile):
words = []
line = line.strip()
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
with open(dictfile, 'r') as fid:
for line in fid:
ds.add(line.split()[0])
unk_words = set([])
with open(tmpbase + '.txt', 'w') as fwid:
for wrd in words:
if (wrd not in ds):
unk_words.add(wrd)
fwid.write(wrd + ' ')
fwid.write('\n')
return unk_words
def prep_mlf(txt, tmpbase):
with open(tmpbase + '.mlf', 'w') as fwid:
fwid.write('#!MLF!#\n')
fwid.write('"' + tmpbase + '.lab"\n')
fwid.write('sp\n')
wrds = txt.split()
for wrd in wrds:
fwid.write(wrd.upper() + '\n')
fwid.write('sp\n')
fwid.write('.\n')
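# Resulting MLF label file for txt = "ni hao" (sketch):
#   #!MLF!#
#   "/tmp/<user>_<pid>.lab"
#   sp
#   NI
#   sp
#   HAO
#   sp
#   .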
def gen_res(tmpbase, outfile1, outfile2):
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
words = txt.strip().split()
words.reverse()
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
times1 = []
times2 = []
while (i < len(lines)):
if (len(lines[i].split()) >= 4) and (
lines[i].split()[0] != lines[i].split()[1]):
phn = lines[i].split()[2]
pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000
pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
if (len(lines[i].split()) == 5):
if (lines[i].split()[0] != lines[i].split()[1]):
wrd = lines[i].split()[-1].strip()
st = (int(lines[i].split()[0]) / 1000 + 125) / 10000
j = i + 1
while (lines[j] != '.\n') and (len(lines[j].split()) != 5):
j += 1
en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000
times1.append([wrd, st, en])
i += 1
with open(outfile1, 'w') as fwid:
for item in times1:
if (item[0] == 'sp'):
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' SIL\n')
else:
wrd = words.pop()
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + wrd + '\n')
if words:
print('not matched: ' + tmpbase + '.aligned')
sys.exit(1)
with open(outfile2, 'w') as fwid:
for item in times2:
fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n')
def alignment_zh(wav_path, text_string):
tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
'.wav remix -')
except:
print('sox error!')
return None
#prepare clean_transcript file
try:
unk_words = prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
if unk_words:
print('Error! Please add the following words to the dictionary:')
for unk in unk_words:
print("invalid word: ", unk)
except:
print('prep_txt error!')
return None
#prepare mlf file
try:
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except:
print('prep_mlf error!')
return None
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR +
'/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR +
'/dict ' + MODEL_DIR + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except:
print('HVite error!')
return None
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
words = txt.strip().split()
words.reverse()
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
times2 = []
word2phns = {}
current_word = ''
index = 0
while (i < len(lines)):
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return times2, word2phns
......@@ -4,37 +4,180 @@ import numpy as np
import paddle
def pad_list(xs, pad_value):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch = len(xs)
max_len = max(paddle.shape(x)[0] for x in xs)
pad = paddle.full((n_batch, max_len), pad_value, dtype=xs[0].dtype)
for i in range(n_batch):
pad[i, :paddle.shape(xs[i])[0]] = xs[i]
return pad
def phones_text_masking(xs_pad: paddle.Tensor,
src_mask: paddle.Tensor,
text_pad: paddle.Tensor,
text_mask: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
mlm_prob: float,
mean_phn_span: float,
span_bdy: paddle.Tensor=None):
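# Mask random spans of speech frames; for the text stream, additionally mask up
# to text_mask_num_lower text positions whose phones were NOT masked in speech.
# Returns boolean masks: masked_pos (bz, sent_len) over speech frames and
# text_masked_pos (bz, text_len) over text tokens.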
bz, sent_len, _ = paddle.shape(xs_pad)
masked_pos = paddle.zeros((bz, sent_len))
_, text_len = paddle.shape(text_pad)
text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5)
text_masked_pos = paddle.zeros((bz, text_len))
y_masks = None
if mlm_prob == 1.0:
masked_pos += 1
# y_masks = tril_masks
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
mean_phn_span).nonzero()
masked_pos[:, masked_phn_idxs] = 1
else:
for idx in range(bz):
if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1
else:
length = align_start_lens[idx]
if length < 2:
continue
masked_phn_idxs = random_spans_noise_mask(
length, mlm_prob, mean_phn_span).nonzero()
unmasked_phn_idxs = list(
set(range(length)) - set(masked_phn_idxs[0].tolist()))
np.random.shuffle(unmasked_phn_idxs)
masked_text_idxs = unmasked_phn_idxs[:text_mask_num_lower]
text_masked_pos[idx][masked_text_idxs] = 1
masked_start = align_start[idx][masked_phn_idxs].tolist()
masked_end = align_end[idx][masked_phn_idxs].tolist()
for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
non_eos_text_mask = paddle.reshape(text_mask, paddle.shape(text_pad)[:2])  # reshape against the text tensor, not xs_pad
text_masked_pos = text_masked_pos * non_eos_text_mask
masked_pos = paddle.cast(masked_pos, 'bool')
text_masked_pos = paddle.cast(text_masked_pos, 'bool')
return masked_pos, text_masked_pos, y_masks
def get_seg_pos_reduce_duration(
speech_pad: paddle.Tensor,
text_pad: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
sega_emb: bool,
masked_pos: paddle.Tensor,
feats_lens: paddle.Tensor, ):
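# Build 1-based phone-segment ids for speech frames and text tokens. Frames of
# masked phone spans are collapsed to their first frame (the span length is
# recorded in durations); unmasked spans keep all frames. reordered_idx lists
# the kept frames first, then the dropped ones, and feats_lens shrinks to the
# reduced length.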
bz, speech_len, _ = paddle.shape(speech_pad)
text_seg_pos = paddle.zeros(paddle.shape(text_pad))
speech_seg_pos = paddle.zeros((bz, speech_len), dtype=text_pad.dtype)
reordered_idx = paddle.zeros((bz, speech_len), dtype=align_start_lens.dtype)
durations = paddle.ones((bz, speech_len), dtype=align_start_lens.dtype)
max_reduced_length = 0
if not sega_emb:
return speech_pad, masked_pos, speech_seg_pos, text_seg_pos, durations
for idx in range(bz):
first_idx = []
last_idx = []
align_length = align_start_lens[idx]
for j in range(align_length):
s, e = align_start[idx][j], align_end[idx][j]
if j == 0:
if paddle.sum(masked_pos[idx][0:s]) == 0:
first_idx.extend(range(0, s))
else:
first_idx.extend([0])
last_idx.extend(range(1, s))
if paddle.sum(masked_pos[idx][s:e]) == 0:
first_idx.extend(range(s, e))
else:
first_idx.extend([s])
last_idx.extend(range(s + 1, e))
durations[idx][s] = e - s
speech_seg_pos[idx][s:e] = j + 1
text_seg_pos[idx][j] = j + 1
max_reduced_length = max(
len(first_idx) + feats_lens[idx] - e, max_reduced_length)
first_idx.extend(range(e, speech_len))
reordered_idx[idx] = paddle.to_tensor(
(first_idx + last_idx), dtype=align_start_lens.dtype)
feats_lens[idx] = len(first_idx)
reordered_idx = reordered_idx[:, :max_reduced_length]
return reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens
def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
"""This function is copy of `random_spans_helper
<https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens.
The number of noise tokens and the number of noise spans and non-noise spans
are determined deterministically as follows:
num_noise_tokens = round(length * noise_density)
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
Spans alternate between non-noise and noise, beginning with non-noise.
Subject to the above restrictions, all masks are equally likely.
Args:
length: an int32 scalar (length of the incoming token sequence)
noise_density: a float - approximate density of output mask
mean_noise_span_length: a number
Returns:
a boolean tensor with shape [length]
"""
orig_length = length
num_noise_tokens = int(np.round(length * mlm_prob))
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
num_noise_spans = int(np.round(num_noise_tokens / mean_phn_span))
# avoid degeneracy by ensuring positive number of noise spans
num_noise_spans = max(num_noise_spans, 1)
num_nonnoise_tokens = length - num_noise_tokens
# pick the lengths of the noise spans and the non-noise spans
def _random_seg(num_items, num_segs):
"""Partition a sequence of items randomly into non-empty segments.
Args:
num_items: an integer scalar > 0
num_segs: an integer scalar in [1, num_items]
Returns:
a Tensor with shape [num_segs] containing positive integers that add
up to num_items
"""
mask_idxs = np.arange(num_items - 1) < (num_segs - 1)
np.random.shuffle(mask_idxs)
first_in_seg = np.pad(mask_idxs, [[1, 0]])
segment_id = np.cumsum(first_in_seg)
# count length of sub segments assuming that list is sorted
_, segment_length = np.unique(segment_id, return_counts=True)
return segment_length
noise_span_lens = _random_seg(num_noise_tokens, num_noise_spans)
nonnoise_span_lens = _random_seg(num_nonnoise_tokens, num_noise_spans)
interleaved_span_lens = np.reshape(
np.stack([nonnoise_span_lens, noise_span_lens], axis=1),
[num_noise_spans * 2])
span_starts = np.cumsum(interleaved_span_lens)[:-1]
span_start_indicator = np.zeros((length, ), dtype=np.int8)
span_start_indicator[span_starts] = True
span_num = np.cumsum(span_start_indicator)
is_noise = np.equal(span_num % 2, 1)
return is_noise[:orig_length]
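# A minimal usage sketch (illustrative, not part of the PR): mask ~15% of 20
# positions in noise spans of mean length 2. The True positions vary per call;
# only the density and alternating span structure are guaranteed.
#   mask = random_spans_noise_mask(length=20, mlm_prob=0.15, mean_phn_span=2)
#   assert mask.shape == (20,) and mask.dtype == bool
#   assert mask.sum() == 3  # round(20 * 0.15) noise positions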
def pad_to_longformer_att_window(text: paddle.Tensor,
max_len: int,
max_tlen: int,
attention_window: int=0):
def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
round = max_len % attention_window
if round != 0:
max_tlen += (attention_window - round)
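# e.g. max_len=100, attention_window=64: round=36, so max_tlen grows by 28 and
# (assuming max_tlen started equal to max_len) becomes 128, a multiple of 64.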
......@@ -48,286 +191,67 @@ def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window):
return text_pad, max_tlen
def make_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
Tensor: Mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
"""
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if not isinstance(lengths, list):
lengths = list(lengths)
bs = int(len(lengths))
if xs is None:
maxlen = int(max(lengths))
else:
maxlen = paddle.shape(xs)[length_dim]
seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
seq_range_expand = paddle.expand(
paddle.unsqueeze(seq_range, 0), (bs, maxlen))
seq_length_expand = paddle.unsqueeze(paddle.to_tensor(lengths), -1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)
if length_dim < 0:
length_dim = len(paddle.shape(xs)) + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind = tuple(
slice(None) if i in (0, length_dim) else None
for i in range(len(paddle.shape(xs))))
mask = paddle.expand(mask[ind], paddle.shape(xs))
return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of non-padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
ByteTensor: mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
return ~make_pad_mask(lengths, xs, length_dim)
def phones_masking(xs_pad,
src_mask,
align_start,
align_end,
align_start_lengths,
mlm_prob,
mean_phn_span,
span_boundary=None):
def phones_masking(xs_pad: paddle.Tensor,
src_mask: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
mlm_prob: float,
mean_phn_span: int,
span_bdy: paddle.Tensor=None):
bz, sent_len, _ = paddle.shape(xs_pad)
mask_num_lower = math.ceil(sent_len * mlm_prob)
masked_position = np.zeros((bz, sent_len))
masked_pos = paddle.zeros((bz, sent_len))
y_masks = None
# y_masks = torch.ones(bz,sent_len,sent_len,device=xs_pad.device,dtype=xs_pad.dtype)
# tril_masks = torch.tril(y_masks)
if mlm_prob == 1.0:
masked_position += 1
# y_masks = tril_masks
masked_pos += 1
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_indices = random_spans_noise_mask(length, mlm_prob,
masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
mean_phn_span).nonzero()
masked_position[:, masked_phn_indices] = 1
masked_pos[:, masked_phn_idxs] = 1
else:
for idx in range(bz):
if span_boundary is not None:
for s, e in zip(span_boundary[idx][::2],
span_boundary[idx][1::2]):
masked_position[idx, s:e] = 1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1
else:
length = align_start_lengths[idx].item()
length = align_start_lens[idx]
if length < 2:
continue
masked_phn_indices = random_spans_noise_mask(
masked_phn_idxs = random_spans_noise_mask(
length, mlm_prob, mean_phn_span).nonzero()
masked_start = align_start[idx][masked_phn_indices].tolist()
masked_end = align_end[idx][masked_phn_indices].tolist()
masked_start = align_start[idx][masked_phn_idxs].tolist()
masked_end = align_end[idx][masked_phn_idxs].tolist()
for s, e in zip(masked_start, masked_end):
masked_position[idx, s:e] = 1
# y_masks[idx, :, s:e] = tril_masks[idx, :, s:e]
# y_masks[idx, e:, s:e ] = 0
non_eos_mask = np.array(
paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]).float().cpu())
masked_position = masked_position * non_eos_mask
# y_masks = src_mask & y_masks.bool()
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
masked_pos = paddle.cast(masked_pos, 'bool')
return paddle.cast(paddle.to_tensor(masked_position), paddle.bool), y_masks
return masked_pos, y_masks
def get_segment_pos(speech_pad, text_pad, align_start, align_end,
align_start_lengths, sega_emb):
bz, speech_len, _ = speech_pad.size()
_, text_len = text_pad.size()
def get_seg_pos(speech_pad: paddle.Tensor,
text_pad: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
sega_emb: bool):
bz, speech_len, _ = paddle.shape(speech_pad)
_, text_len = paddle.shape(text_pad)
# text_segment_pos = paddle.zeros_like(text_pad)
# speech_segment_pos = paddle.zeros((bz, speech_len),dtype=text_pad.dtype)
text_segment_pos = np.zeros((bz, text_len)).astype('int64')
speech_segment_pos = np.zeros((bz, speech_len)).astype('int64')
text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')
if not sega_emb:
text_segment_pos = paddle.to_tensor(text_segment_pos)
speech_segment_pos = paddle.to_tensor(speech_segment_pos)
return speech_segment_pos, text_segment_pos
return speech_seg_pos, text_seg_pos
for idx in range(bz):
align_length = align_start_lengths[idx].item()
align_length = align_start_lens[idx]
for j in range(align_length):
s, e = align_start[idx][j].item(), align_end[idx][j].item()
speech_segment_pos[idx][s:e] = j + 1
text_segment_pos[idx][j] = j + 1
text_segment_pos = paddle.to_tensor(text_segment_pos)
speech_segment_pos = paddle.to_tensor(speech_segment_pos)
s, e = align_start[idx][j], align_end[idx][j]
speech_seg_pos[idx, s:e] = j + 1
text_seg_pos[idx, j] = j + 1
return speech_segment_pos, text_segment_pos
return speech_seg_pos, text_seg_pos
#!/usr/bin/env python3
import argparse
import math
import os
import pickle
import random
import string
import sys
from pathlib import Path
from typing import Collection
from typing import Dict
......@@ -18,17 +14,17 @@ import numpy as np
import paddle
import soundfile as sf
import torch
from paddle import nn
from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model
from align_english import alignment
from align_mandarin import alignment_zh
from dataset import get_segment_pos
from dataset import make_non_pad_mask
from dataset import make_pad_mask
from dataset import pad_list
from align import alignment
from align import alignment_zh
from dataset import get_seg_pos
from dataset import get_seg_pos_reduce_duration
from dataset import pad_to_longformer_att_window
from dataset import phones_masking
from dataset import phones_text_masking
from model_paddle import build_model_from_file
from read_text import load_num_sequence_text
from read_text import read_2column_text
......@@ -37,8 +33,9 @@ from utils import build_vocoder_from_file
from utils import evaluate_durations
from utils import get_voc_out
from utils import is_chinese
from utils import sentence2phns
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import pad_list
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
random.seed(0)
np.random.seed(0)
......@@ -47,81 +44,72 @@ MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
def plot_mel_and_vocode_wav(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
model_name,
wav_path,
full_origin_str,
old_str,
new_str,
use_pt_vocoder,
duration_preditor_path,
sid=None,
non_autoreg=True):
wav_org, input_feat, output_feat, old_span_boundary, new_span_boundary, fs, hop_length = get_mlm_output(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
model_name,
wav_path,
old_str,
new_str,
duration_preditor_path,
def plot_mel_and_vocode_wav(uid: str,
wav_path: str,
prefix: str="./prompt/dev/",
source_lang: str='english',
target_lang: str='english',
model_name: str="conformer",
full_origin_str: str="",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
use_pt_vocoder: bool=False,
sid: str=None,
non_autoreg: bool=True):
wav_org, input_feat, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
uid=uid,
prefix=prefix,
source_lang=source_lang,
target_lang=target_lang,
model_name=model_name,
wav_path=wav_path,
old_str=old_str,
new_str=new_str,
duration_preditor_path=duration_preditor_path,
use_teacher_forcing=non_autoreg,
sid=sid)
masked_feat = output_feat[new_span_boundary[0]:new_span_boundary[
1]].detach().float().cpu().numpy()
masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
if target_language == 'english':
if target_lang == 'english':
if use_pt_vocoder:
output_feat = output_feat.detach().float().cpu().numpy()
output_feat = output_feat.cpu().numpy()
output_feat = torch.tensor(output_feat, dtype=torch.float)
vocoder = load_vocoder('vctk_parallel_wavegan.v1.long')
replaced_wav = vocoder(
output_feat).detach().float().data.cpu().numpy()
replaced_wav = vocoder(output_feat).cpu().numpy()
else:
output_feat_np = output_feat.detach().float().cpu().numpy()
replaced_wav = get_voc_out(output_feat_np, target_language)
replaced_wav = get_voc_out(output_feat, target_lang)
elif target_language == 'chinese':
output_feat_np = output_feat.detach().float().cpu().numpy()
replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat,
target_language)
elif target_lang == 'chinese':
replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, target_lang)
old_time_boundary = [hop_length * x for x in old_span_boundary]
new_time_boundary = [hop_length * x for x in new_span_boundary]
old_time_bdy = [hop_length * x for x in old_span_bdy]
new_time_bdy = [hop_length * x for x in new_span_bdy]
if target_language == 'english':
if target_lang == 'english':
wav_org_replaced_paddle_voc = np.concatenate([
wav_org[:old_time_boundary[0]],
replaced_wav[new_time_boundary[0]:new_time_boundary[1]],
wav_org[old_time_boundary[1]:]
wav_org[:old_time_bdy[0]],
replaced_wav[new_time_bdy[0]:new_time_bdy[1]],
wav_org[old_time_bdy[1]:]
])
data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc}
elif target_language == 'chinese':
elif target_lang == 'chinese':
wav_org_replaced_only_mask_fst2_voc = np.concatenate([
wav_org[:old_time_boundary[0]], replaced_wav_only_mask_fst2_voc,
wav_org[old_time_boundary[1]:]
wav_org[:old_time_bdy[0]], replaced_wav_only_mask_fst2_voc,
wav_org[old_time_bdy[1]:]
])
data_dict = {
"origin": wav_org,
"output": wav_org_replaced_only_mask_fst2_voc,
}
return data_dict, old_span_boundary
return data_dict, old_span_bdy
def get_unk_phns(word_str):
def get_unk_phns(word_str: str):
tmpbase = '/tmp/tp.'
f = open(tmpbase + 'temp.words', 'w')
f.write(word_str)
......@@ -160,9 +148,8 @@ def get_unk_phns(word_str):
return phns
def words2phns(line):
def words2phns(line: str):
dictfile = MODEL_DIR_EN + '/dict'
tmpbase = '/tmp/tp.'
line = line.strip()
words = []
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
......@@ -200,9 +187,8 @@ def words2phns(line):
return phns, wrd2phns
def words2phns_zh(line):
def words2phns_zh(line: str):
dictfile = MODEL_DIR_ZH + '/dict'
tmpbase = '/tmp/tp.'
line = line.strip()
words = []
for pun in [
......@@ -242,7 +228,7 @@ def words2phns_zh(line):
return phns, wrd2phns
def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"):
def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"):
vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "")
vocoder_file = download_pretrained_model(vocoder_tag)
vocoder_config = Path(vocoder_file).parent / "config.yml"
......@@ -250,7 +236,7 @@ def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"):
return vocoder
def load_model(model_name):
def load_model(model_name: str):
config_path = './pretrained_model/{}/config.yaml'.format(model_name)
model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
mlm_model, args = build_model_from_file(
......@@ -258,7 +244,7 @@ def load_model(model_name):
return mlm_model, args
def read_data(uid, prefix):
def read_data(uid: str, prefix: str):
mfa_text = read_2column_text(prefix + '/text')[uid]
mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid]
if 'mnt' not in mfa_wav_path:
......@@ -266,7 +252,7 @@ def read_data(uid, prefix):
return mfa_text, mfa_wav_path
def get_align_data(uid, prefix):
def get_align_data(uid: str, prefix: str):
mfa_path = prefix + "mfa_"
mfa_text = read_2column_text(mfa_path + 'text')[uid]
mfa_start = load_num_sequence_text(
......@@ -277,43 +263,45 @@ def get_align_data(uid, prefix):
return mfa_text, mfa_start, mfa_end, mfa_wav_path
def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length,
span_tobe_replaced):
def get_masked_mel_bdy(mfa_start: List[float],
mfa_end: List[float],
fs: int,
hop_length: int,
span_to_repl: List[List[int]]):
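# Convert MFA phone times (seconds) to mel-frame indices and return the
# [start, end] frame boundary of the span to be replaced; if the span starts
# beyond the last phone, both boundaries fall on the final frame.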
align_start = paddle.to_tensor(mfa_start).unsqueeze(0)
align_end = paddle.to_tensor(mfa_end).unsqueeze(0)
align_start = paddle.floor(fs * align_start / hop_length).int()
align_end = paddle.floor(fs * align_end / hop_length).int()
if span_tobe_replaced[0] >= len(mfa_start):
span_boundary = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]]
if span_to_repl[0] >= len(mfa_start):
span_bdy = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]]
else:
span_boundary = [
align_start[0].tolist()[span_tobe_replaced[0]],
align_end[0].tolist()[span_tobe_replaced[1] - 1]
span_bdy = [
align_start[0].tolist()[span_to_repl[0]],
align_end[0].tolist()[span_to_repl[1] - 1]
]
return span_boundary
return span_bdy
def recover_dict(word2phns, tp_word2phns):
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
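# Merge the 'sp' (silence) entries found by forced alignment (word2phns) into
# the text-derived mapping (tp_word2phns), re-numbering word indices so the two
# dictionaries line up.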
dic = {}
need_del_key = []
exist_index = []
keys_to_del = []
exist_idx = []
sp_count = 0
add_sp_count = 0
for key in word2phns.keys():
idx, wrd = key.split('_')
if wrd == 'sp':
sp_count += 1
exist_index.append(int(idx))
exist_idx.append(int(idx))
else:
need_del_key.append(key)
keys_to_del.append(key)
for key in need_del_key:
for key in keys_to_del:
del word2phns[key]
cur_id = 0
for key in tp_word2phns.keys():
# print("debug: ",key)
if cur_id in exist_index:
if cur_id in exist_idx:
dic[str(cur_id) + "_sp"] = 'sp'
cur_id += 1
add_sp_count += 1
......@@ -329,14 +317,17 @@ def recover_dict(word2phns, tp_word2phns):
return dic
def get_phns_and_spans(wav_path, old_str, new_str, source_language,
clone_target_language):
def get_phns_and_spans(wav_path: str,
old_str: str="",
new_str: str="",
source_lang: str="english",
target_lang: str="english"):
append_new_str = (old_str == new_str[:len(old_str)])
old_phns, mfa_start, mfa_end = [], [], []
if source_language == "english":
if source_lang == "english":
times2, word2phns = alignment(wav_path, old_str)
elif source_language == "chinese":
elif source_lang == "chinese":
times2, word2phns = alignment_zh(wav_path, old_str)
_, tp_word2phns = words2phns_zh(old_str)
......@@ -348,14 +339,14 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
word2phns = recover_dict(word2phns, tp_word2phns)
else:
assert source_language == "chinese" or source_language == "english", "source_language is wrong..."
assert source_lang == "chinese" or source_lang == "english", "source_lang is wrong..."
for item in times2:
mfa_start.append(float(item[1]))
mfa_end.append(float(item[2]))
old_phns.append(item[0])
if append_new_str and (source_language != clone_target_language):
if append_new_str and (source_lang != target_lang):
is_cross_lingual_clone = True
else:
is_cross_lingual_clone = False
......@@ -364,18 +355,21 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
new_str_origin = new_str[:len(old_str)]
new_str_append = new_str[len(old_str):]
if clone_target_language == "chinese":
if target_lang == "chinese":
new_phns_origin, new_origin_word2phns = words2phns(new_str_origin)
new_phns_append, temp_new_append_word2phns = words2phns_zh(
new_str_append)
elif clone_target_language == "english":
elif target_lang == "english":
# original sentence
new_phns_origin, new_origin_word2phns = words2phns_zh(
new_str_origin)  # original sentence
new_str_origin)
# cloned sentence
new_phns_append, temp_new_append_word2phns = words2phns(
new_str_append)  # cloned sentence
new_str_append)
else:
assert clone_target_language == "chinese" or clone_target_language == "english", "cloning is not supported for this language, please check it."
assert target_lang == "chinese" or target_lang == "english", \
"cloning is not supported for this language, please check it."
new_phns = new_phns_origin + new_phns_append
......@@ -390,16 +384,17 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
new_append_word2phns.items()))
else:
if source_language == clone_target_language and clone_target_language == "english":
if source_lang == target_lang and target_lang == "english":
new_phns, new_word2phns = words2phns(new_str)
elif source_language == clone_target_language and clone_target_language == "chinese":
elif source_lang == target_lang and target_lang == "chinese":
new_phns, new_word2phns = words2phns_zh(new_str)
else:
assert source_language == clone_target_language, "the source language must be the same as the target language..."
assert source_lang == target_lang, \
"the source language must be the same as the target language..."
span_tobe_replaced = [0, len(old_phns) - 1]
span_tobe_added = [0, len(new_phns) - 1]
left_index = 0
span_to_repl = [0, len(old_phns) - 1]
span_to_add = [0, len(new_phns) - 1]
left_idx = 0
new_phns_left = []
sp_count = 0
# find the left different index
......@@ -411,27 +406,27 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
else:
idx = str(int(idx) - sp_count)
if idx + '_' + wrd in new_word2phns:
left_index += len(new_word2phns[idx + '_' + wrd])
left_idx += len(new_word2phns[idx + '_' + wrd])
new_phns_left.extend(word2phns[key].split())
else:
span_tobe_replaced[0] = len(new_phns_left)
span_tobe_added[0] = len(new_phns_left)
span_to_repl[0] = len(new_phns_left)
span_to_add[0] = len(new_phns_left)
break
# reverse word2phns and new_word2phns
right_index = 0
right_idx = 0
new_phns_right = []
sp_count = 0
word2phns_max_index = int(list(word2phns.keys())[-1].split('_')[0])
new_word2phns_max_index = int(list(new_word2phns.keys())[-1].split('_')[0])
new_phns_middle = []
word2phns_max_idx = int(list(word2phns.keys())[-1].split('_')[0])
new_word2phns_max_idx = int(list(new_word2phns.keys())[-1].split('_')[0])
new_phns_mid = []
if append_new_str:
new_phns_right = []
new_phns_middle = new_phns[left_index:]
span_tobe_replaced[0] = len(new_phns_left)
span_tobe_added[0] = len(new_phns_left)
span_tobe_added[1] = len(new_phns_left) + len(new_phns_middle)
span_tobe_replaced[1] = len(old_phns) - len(new_phns_right)
new_phns_mid = new_phns[left_idx:]
span_to_repl[0] = len(new_phns_left)
span_to_add[0] = len(new_phns_left)
span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
span_to_repl[1] = len(old_phns) - len(new_phns_right)
else:
for key in list(word2phns.keys())[::-1]:
idx, wrd = key.split('_')
......@@ -439,33 +434,31 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language,
sp_count += 1
new_phns_right = ['sp'] + new_phns_right
else:
idx = str(new_word2phns_max_index - (word2phns_max_index - int(
idx) - sp_count))
idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx)
- sp_count))
if idx + '_' + wrd in new_word2phns:
right_index -= len(new_word2phns[idx + '_' + wrd])
right_idx -= len(new_word2phns[idx + '_' + wrd])
new_phns_right = word2phns[key].split() + new_phns_right
else:
span_tobe_replaced[1] = len(old_phns) - len(new_phns_right)
new_phns_middle = new_phns[left_index:right_index]
span_tobe_added[1] = len(new_phns_left) + len(
new_phns_middle)
if len(new_phns_middle) == 0:
span_tobe_added[1] = min(span_tobe_added[1] + 1,
len(new_phns))
span_tobe_added[0] = max(0, span_tobe_added[0] - 1)
span_tobe_replaced[0] = max(0,
span_tobe_replaced[0] - 1)
span_tobe_replaced[1] = min(span_tobe_replaced[1] + 1,
span_to_repl[1] = len(old_phns) - len(new_phns_right)
new_phns_mid = new_phns[left_idx:right_idx]
span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
if len(new_phns_mid) == 0:
span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
span_to_add[0] = max(0, span_to_add[0] - 1)
span_to_repl[0] = max(0, span_to_repl[0] - 1)
span_to_repl[1] = min(span_to_repl[1] + 1,
len(old_phns))
break
new_phns = new_phns_left + new_phns_middle + new_phns_right
new_phns = new_phns_left + new_phns_mid + new_phns_right
return mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added
return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
def duration_adjust_factor(original_dur, pred_dur, phns):
def duration_adjust_factor(original_dur: List[int],
pred_dur: List[int],
phns: List[str]):
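# Estimate a global speaking-rate factor between the aligned (original) and
# predicted durations; zero-length predictions and 'sp' are skipped, and the
# returned value is an average over a trimmed slice of the per-phone factors.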
length = 0
accumulate = 0
factor_list = []
for ori, pred, phn in zip(original_dur, pred_dur, phns):
if pred == 0 or phn == 'sp':
......@@ -481,242 +474,224 @@ def duration_adjust_factor(original_dur, pred_dur, phns):
return np.average(factor_list[length:-length])
def prepare_features_with_duration(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
old_str,
new_str,
wav_path,
duration_preditor_path,
sid=None,
mask_reconstruct=False,
duration_adjust=True,
start_end_sp=False,
def prepare_features_with_duration(uid: str,
prefix: str,
wav_path: str,
mlm_model: nn.Layer,
source_lang: str="English",
target_lang: str="English",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
sid: str=None,
mask_reconstruct: bool=False,
duration_adjust: bool=True,
start_end_sp: bool=False,
train_args=None):
wav_org, rate = librosa.load(
wav_path, sr=train_args.feats_extract_conf['fs'])
fs = train_args.feats_extract_conf['fs']
hop_length = train_args.feats_extract_conf['hop_length']
mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans(
wav_path, old_str, new_str, source_language, target_language)
mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
wav_path=wav_path,
old_str=old_str,
new_str=new_str,
source_lang=source_lang,
target_lang=target_lang)
if start_end_sp:
if new_phns[-1] != 'sp':
new_phns = new_phns + ['sp']
if target_language == "english":
old_durations = evaluate_durations(
old_phns, target_language=target_language)
if target_lang == "english":
old_durations = evaluate_durations(old_phns, target_lang=target_lang)
elif target_language == "chinese":
elif target_lang == "chinese":
if source_language == "english":
if source_lang == "english":
old_durations = evaluate_durations(
old_phns, target_language=source_language)
old_phns, target_lang=source_lang)
elif source_language == "chinese":
elif source_lang == "chinese":
old_durations = evaluate_durations(
old_phns, target_language=source_language)
old_phns, target_lang=source_lang)
else:
assert target_language == "chinese" or target_language == "english", "duration prediction is not supported for this language..."
assert target_lang == "chinese" or target_lang == "english", "duration prediction is not supported for this language..."
original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)]
if '[MASK]' in new_str:
new_phns = old_phns
span_tobe_added = span_tobe_replaced
span_to_add = span_to_repl
d_factor_left = duration_adjust_factor(
original_old_durations[:span_tobe_replaced[0]],
old_durations[:span_tobe_replaced[0]],
old_phns[:span_tobe_replaced[0]])
original_old_durations[:span_to_repl[0]],
old_durations[:span_to_repl[0]], old_phns[:span_to_repl[0]])
d_factor_right = duration_adjust_factor(
original_old_durations[span_tobe_replaced[1]:],
old_durations[span_tobe_replaced[1]:],
old_phns[span_tobe_replaced[1]:])
original_old_durations[span_to_repl[1]:],
old_durations[span_to_repl[1]:], old_phns[span_to_repl[1]:])
d_factor = (d_factor_left + d_factor_right) / 2
new_durations_adjusted = [d_factor * i for i in old_durations]
else:
if duration_adjust:
d_factor = duration_adjust_factor(original_old_durations,
old_durations, old_phns)
d_factor_paddle = duration_adjust_factor(original_old_durations,
old_durations, old_phns)
d_factor = d_factor * 1.25
else:
d_factor = 1
if target_language == "english":
if target_lang == "english":
new_durations = evaluate_durations(
new_phns, target_language=target_language)
new_phns, target_lang=target_lang)
elif target_language == "chinese":
elif target_lang == "chinese":
new_durations = evaluate_durations(
new_phns, target_language=target_language)
new_phns, target_lang=target_lang)
new_durations_adjusted = [d_factor * i for i in new_durations]
if span_tobe_replaced[0] < len(old_phns) and old_phns[
span_tobe_replaced[0]] == new_phns[span_tobe_added[0]]:
new_durations_adjusted[span_tobe_added[0]] = original_old_durations[
span_tobe_replaced[0]]
if span_tobe_replaced[1] < len(old_phns) and span_tobe_added[1] < len(
new_phns):
if old_phns[span_tobe_replaced[1]] == new_phns[span_tobe_added[1]]:
new_durations_adjusted[span_tobe_added[
1]] = original_old_durations[span_tobe_replaced[1]]
if span_to_repl[0] < len(old_phns) and old_phns[span_to_repl[
0]] == new_phns[span_to_add[0]]:
new_durations_adjusted[span_to_add[0]] = original_old_durations[
span_to_repl[0]]
if span_to_repl[1] < len(old_phns) and span_to_add[1] < len(new_phns):
if old_phns[span_to_repl[1]] == new_phns[span_to_add[1]]:
new_durations_adjusted[span_to_add[1]] = original_old_durations[
span_to_repl[1]]
new_span_duration_sum = sum(
new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]])
new_durations_adjusted[span_to_add[0]:span_to_add[1]])
old_span_duration_sum = sum(
original_old_durations[span_tobe_replaced[0]:span_tobe_replaced[1]])
original_old_durations[span_to_repl[0]:span_to_repl[1]])
duration_offset = new_span_duration_sum - old_span_duration_sum
new_mfa_start = mfa_start[:span_tobe_replaced[0]]
new_mfa_end = mfa_end[:span_tobe_replaced[0]]
for i in new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]]:
new_mfa_start = mfa_start[:span_to_repl[0]]
new_mfa_end = mfa_end[:span_to_repl[0]]
for i in new_durations_adjusted[span_to_add[0]:span_to_add[1]]:
if len(new_mfa_end) == 0:
new_mfa_start.append(0)
new_mfa_end.append(i)
else:
new_mfa_start.append(new_mfa_end[-1])
new_mfa_end.append(new_mfa_end[-1] + i)
new_mfa_start += [
i + duration_offset for i in mfa_start[span_tobe_replaced[1]:]
]
new_mfa_end += [
i + duration_offset for i in mfa_end[span_tobe_replaced[1]:]
]
new_mfa_start += [i + duration_offset for i in mfa_start[span_to_repl[1]:]]
new_mfa_end += [i + duration_offset for i in mfa_end[span_to_repl[1]:]]
# 3. get new wav
if span_tobe_replaced[0] >= len(mfa_start):
left_index = len(wav_org)
right_index = left_index
if span_to_repl[0] >= len(mfa_start):
left_idx = len(wav_org)
right_idx = left_idx
else:
left_index = int(np.floor(mfa_start[span_tobe_replaced[0]] * fs))
right_index = int(np.ceil(mfa_end[span_tobe_replaced[1] - 1] * fs))
left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
new_blank_wav = np.zeros(
(int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype)
new_wav_org = np.concatenate(
[wav_org[:left_index], new_blank_wav, wav_org[right_index:]])
[wav_org[:left_idx], new_blank_wav, wav_org[right_idx:]])
# 4. get old and new mel span to be mask
old_span_boundary = get_masked_mel_boundary(
mfa_start, mfa_end, fs, hop_length, span_tobe_replaced) # [92, 92]
new_span_boundary = get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs,
hop_length,
span_tobe_added) # [92, 174]
return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_boundary, new_span_boundary
def prepare_features(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
# [92, 92]
old_span_bdy = get_masked_mel_bdy(mfa_start, mfa_end, fs, hop_length,
span_to_repl)
# [92, 174]
new_span_bdy = get_masked_mel_bdy(new_mfa_start, new_mfa_end, fs,
hop_length, span_to_add)
return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy
def prepare_features(uid: str,
mlm_model: nn.Layer,
processor,
wav_path,
old_str,
new_str,
duration_preditor_path,
sid=None,
duration_adjust=True,
start_end_sp=False,
mask_reconstruct=False,
wav_path: str,
prefix: str="./prompt/dev/",
source_lang: str="english",
target_lang: str="english",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
sid: str=None,
duration_adjust: bool=True,
start_end_sp: bool=False,
mask_reconstruct: bool=False,
train_args=None):
wav_org, phns_list, mfa_start, mfa_end, old_span_boundary, new_span_boundary = prepare_features_with_duration(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
old_str,
new_str,
wav_path,
duration_preditor_path,
wav_org, phns_list, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prepare_features_with_duration(
uid=uid,
prefix=prefix,
source_lang=source_lang,
target_lang=target_lang,
mlm_model=mlm_model,
old_str=old_str,
new_str=new_str,
wav_path=wav_path,
duration_preditor_path=duration_preditor_path,
sid=sid,
duration_adjust=duration_adjust,
start_end_sp=start_end_sp,
mask_reconstruct=mask_reconstruct,
train_args=train_args)
speech = np.array(wav_org, dtype=np.float32)
speech = wav_org
align_start = np.array(mfa_start)
align_end = np.array(mfa_end)
token_to_id = {item: i for i, item in enumerate(train_args.token_list)}
text = np.array(
list(
map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns_list)))
# print('unk id is', token_to_id['<unk>'])
# text = np.array(processor(uid='1', data={'text':" ".join(phns_list)})['text'])
span_boundary = np.array(new_span_boundary)
span_bdy = np.array(new_span_bdy)
batch = [('1', {
"speech": speech,
"align_start": align_start,
"align_end": align_end,
"text": text,
"span_boundary": span_boundary
"span_bdy": span_bdy
})]
return batch, old_span_boundary, new_span_boundary
return batch, old_span_bdy, new_span_bdy
def decode_with_model(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
def decode_with_model(uid: str,
mlm_model: nn.Layer,
processor,
collate_fn,
wav_path,
old_str,
new_str,
duration_preditor_path,
sid=None,
decoder=False,
use_teacher_forcing=False,
duration_adjust=True,
start_end_sp=False,
wav_path: str,
prefix: str="./prompt/dev/",
source_lang: str="english",
target_lang: str="english",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
sid: str=None,
decoder: bool=False,
use_teacher_forcing: bool=False,
duration_adjust: bool=True,
start_end_sp: bool=False,
train_args=None):
fs, hop_length = train_args.feats_extract_conf[
'fs'], train_args.feats_extract_conf['hop_length']
batch, old_span_boundary, new_span_boundary = prepare_features(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
processor,
wav_path,
old_str,
new_str,
duration_preditor_path,
sid,
batch, old_span_bdy, new_span_bdy = prepare_features(
uid=uid,
prefix=prefix,
source_lang=source_lang,
target_lang=target_lang,
mlm_model=mlm_model,
processor=processor,
wav_path=wav_path,
old_str=old_str,
new_str=new_str,
duration_preditor_path=duration_preditor_path,
sid=sid,
duration_adjust=duration_adjust,
start_end_sp=start_end_sp,
train_args=train_args)
feats = collate_fn(batch)[1]
if 'text_masked_position' in feats.keys():
feats.pop('text_masked_position')
if 'text_masked_pos' in feats.keys():
feats.pop('text_masked_pos')
for k, v in feats.items():
feats[k] = paddle.to_tensor(v)
rtn = mlm_model.inference(
**feats,
span_boundary=new_span_boundary,
use_teacher_forcing=use_teacher_forcing)
**feats, span_bdy=new_span_bdy, use_teacher_forcing=use_teacher_forcing)
output = rtn['feat_gen']
if 0 in output[0].shape and 0 not in output[-1].shape:
output_feat = paddle.concat(
......@@ -731,12 +706,9 @@ def decode_with_model(uid,
[output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)],
axis=0).cpu()
wav_org, rate = librosa.load(
wav_org, _ = librosa.load(
wav_path, sr=train_args.feats_extract_conf['fs'])
origin_speech = paddle.to_tensor(
np.array(wav_org, dtype=np.float32)).unsqueeze(0)
speech_lengths = paddle.to_tensor(len(wav_org)).unsqueeze(0)
return wav_org, None, output_feat, old_span_boundary, new_span_boundary, fs, hop_length
return wav_org, None, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
class MLMCollateFn:
......@@ -800,33 +772,15 @@ def mlm_collate_fn(
sega_emb: bool=False,
duration_collect: bool=False,
text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
Examples:
>>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler,
>>> import espnet2.tasks.abs_task
>>> from espnet2.train.dataset import ESPnetDataset
>>> sampler = ConstantBatchSampler(...)
>>> dataset = ESPnetDataset(...)
>>> keys = next(iter(sampler)
>>> batch = [dataset[key] for key in keys]
>>> batch = common_collate_fn(batch)
>>> model(**batch)
Note that the dict-keys of batch are propagated from
that of the dataset as they are.
"""
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(not k.endswith("_lengths")
for k in data[0]), f"*_lengths is reserved: {list(data[0])}"
assert all(not k.endswith("_lens")
for k in data[0]), f"*_lens is reserved: {list(data[0])}"
output = {}
for key in data[0]:
# NOTE(kamo):
# Each models, which accepts these values finally, are responsible
# to repaint the pad_value to the desired value for each tasks.
if data[0][key].dtype.kind == "i":
......@@ -846,37 +800,35 @@ def mlm_collate_fn(
# lens: (Batch,)
if key not in not_sequence:
lens = paddle.to_tensor(
[d[key].shape[0] for d in data], dtype=paddle.long)
output[key + "_lengths"] = lens
[d[key].shape[0] for d in data], dtype=paddle.int64)
output[key + "_lens"] = lens
feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
feats = paddle.to_tensor(feats)
# print('out shape', paddle.shape(feats))
feats_lengths = paddle.shape(feats)[0]
feats_lens = paddle.shape(feats)[0]
feats = paddle.unsqueeze(feats, 0)
batch_size = paddle.shape(feats)[0]
if 'text' not in output:
text = paddle.zeros_like(feats_lengths.unsqueeze(-1)) - 2
text_lengths = paddle.zeros_like(feats_lengths) + 1
text = paddle.zeros(paddle.shape(feats_lens.unsqueeze(-1))) - 2
text_lens = paddle.zeros(paddle.shape(feats_lens)) + 1
max_tlen = 1
align_start = paddle.zeros_like(text)
align_end = paddle.zeros_like(text)
align_start_lengths = paddle.zeros_like(feats_lengths)
align_end_lengths = paddle.zeros_like(feats_lengths)
align_start = paddle.zeros(paddle.shape(text))
align_end = paddle.zeros(paddle.shape(text))
align_start_lens = paddle.zeros(paddle.shape(feats_lens))
sega_emb = False
mean_phn_span = 0
mlm_prob = 0.15
else:
text, text_lengths = output["text"], output["text_lengths"]
align_start, align_start_lengths, align_end, align_end_lengths = output[
"align_start"], output["align_start_lengths"], output[
"align_end"], output["align_end_lengths"]
text = output["text"]
text_lens = output["text_lens"]
align_start = output["align_start"]
align_start_lens = output["align_start_lens"]
align_end = output["align_end"]
align_start = paddle.floor(feats_extract.sr * align_start /
feats_extract.hop_length).int()
align_end = paddle.floor(feats_extract.sr * align_end /
feats_extract.hop_length).int()
max_tlen = max(text_lengths).item()
max_slen = max(feats_lengths).item()
max_tlen = max(text_lens)
max_slen = max(feats_lens)
speech_pad = feats[:, :max_slen]
if attention_window > 0 and pad_speech:
speech_pad, max_slen = pad_to_longformer_att_window(
......@@ -888,51 +840,49 @@ def mlm_collate_fn(
else:
text_pad = text
text_mask = make_non_pad_mask(
text_lengths.tolist(), text_pad, length_dim=1).unsqueeze(-2)
text_lens, text_pad, length_dim=1).unsqueeze(-2)
if attention_window > 0:
text_mask = text_mask * 2
speech_mask = make_non_pad_mask(
feats_lengths.tolist(), speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
span_boundary = None
if 'span_boundary' in output.keys():
span_boundary = output['span_boundary']
feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
span_bdy = None
if 'span_bdy' in output.keys():
span_bdy = output['span_bdy']
if text_masking:
masked_position, text_masked_position, _ = phones_text_masking(
masked_pos, text_masked_pos, _ = phones_text_masking(
speech_pad, speech_mask, text_pad, text_mask, align_start,
align_end, align_start_lengths, mlm_prob, mean_phn_span,
span_boundary)
align_end, align_start_lens, mlm_prob, mean_phn_span, span_bdy)
else:
text_masked_position = np.zeros(text_pad.size())
masked_position, _ = phones_masking(
speech_pad, speech_mask, align_start, align_end,
align_start_lengths, mlm_prob, mean_phn_span, span_boundary)
text_masked_pos = paddle.zeros(paddle.shape(text_pad))
masked_pos, _ = phones_masking(speech_pad, speech_mask, align_start,
align_end, align_start_lens, mlm_prob,
mean_phn_span, span_bdy)
output_dict = {}
if duration_collect and 'text' in output:
reordered_index, speech_segment_pos, text_segment_pos, durations, feats_lengths = get_segment_pos_reduce_duration(
speech_pad, text_pad, align_start, align_end, align_start_lengths,
sega_emb, masked_position, feats_lengths)
reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens = get_seg_pos_reduce_duration(
speech_pad, text_pad, align_start, align_end, align_start_lens,
sega_emb, masked_pos, feats_lens)
speech_mask = make_non_pad_mask(
feats_lengths.tolist(),
speech_pad[:, :reordered_index.shape[1], 0],
feats_lens, speech_pad[:, :reordered_idx.shape[1], 0],
length_dim=1).unsqueeze(-2)
output_dict['durations'] = durations
output_dict['reordered_index'] = reordered_index
output_dict['reordered_idx'] = reordered_idx
else:
speech_segment_pos, text_segment_pos = get_segment_pos(
speech_pad, text_pad, align_start, align_end, align_start_lengths,
sega_emb)
speech_seg_pos, text_seg_pos = get_seg_pos(speech_pad, text_pad,
align_start, align_end,
align_start_lens, sega_emb)
output_dict['speech'] = speech_pad
output_dict['text'] = text_pad
output_dict['masked_position'] = masked_position
output_dict['text_masked_position'] = text_masked_position
output_dict['masked_pos'] = masked_pos
output_dict['text_masked_pos'] = text_masked_pos
output_dict['speech_mask'] = speech_mask
output_dict['text_mask'] = text_mask
output_dict['speech_segment_pos'] = speech_segment_pos
output_dict['text_segment_pos'] = text_segment_pos
output_dict['speech_lengths'] = output["speech_lengths"]
output_dict['text_lengths'] = text_lengths
output_dict['speech_seg_pos'] = speech_seg_pos
output_dict['text_seg_pos'] = text_seg_pos
output_dict['speech_lens'] = output["speech_lens"]
output_dict['text_lens'] = text_lens
output = (uttids, output_dict)
return output
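To make the padding and `*_lens` convention above concrete, here is a minimal standalone sketch (a hypothetical helper, not part of this patch) that pads a batch of speech features the same way:

import numpy as np
import paddle

def simple_collate(batch, pad_value=0.0):
    # batch: list of (uttid, {"speech": np.ndarray of shape (T, D)}) pairs
    uttids = [u for u, _ in batch]
    feats = [d["speech"] for _, d in batch]
    # record the true length of each item before padding
    lens = paddle.to_tensor([f.shape[0] for f in feats], dtype=paddle.int64)
    max_len = max(f.shape[0] for f in feats)
    padded = np.full((len(feats), max_len, feats[0].shape[1]), pad_value,
                     dtype=np.float32)
    for i, f in enumerate(feats):
        padded[i, :f.shape[0]] = f
    return uttids, {"speech": paddle.to_tensor(padded), "speech_lens": lens}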
......@@ -940,13 +890,13 @@ def mlm_collate_fn(
def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
# -> Callable[
# [Collection[Tuple[str, Dict[str, np.ndarray]]]],
# Tuple[List[str], Dict[str, torch.Tensor]],
# Tuple[List[str], Dict[str, Tensor]],
# ]:
# assert check_argument_types()
# return CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
feats_extract_class = LogMelFBank
if args.feats_extract_conf['win_length'] == None:
if args.feats_extract_conf['win_length'] is None:
args.feats_extract_conf['win_length'] = args.feats_extract_conf['n_fft']
args_dic = {}
......@@ -955,7 +905,6 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
args_dic['sr'] = v
else:
args_dic[k] = v
# feats_extract = feats_extract_class(**args.feats_extract_conf)
feats_extract = feats_extract_class(**args_dic)
sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False
......@@ -969,8 +918,7 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_probs = [1.0, 1.0, 0.7, 0.6, 0.5]
mlm_prob_factor = 0.8 #mlm_probs[epoch // 100]
mlm_prob_factor = 0.8
if 'duration_predictor_layers' in args.model_conf.keys(
) and args.model_conf['duration_predictor_layers'] > 0:
duration_collect = True
......@@ -989,42 +937,37 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1):
duration_collect=duration_collect)
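How the returned collate function is consumed (editor's sketch; `dataset` here is an assumed paddle.io.Dataset, and the wiring is illustrative rather than taken from this patch):

from paddle.io import DataLoader

collate_fn = build_collate_fn(train_args, False)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
# each iteration yields (uttids, output_dict) with padded tensors and *_lens entries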
def get_mlm_output(uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
model_name,
wav_path,
old_str,
new_str,
duration_preditor_path,
sid=None,
decoder=False,
use_teacher_forcing=False,
dynamic_eval=(0, 0),
duration_adjust=True,
start_end_sp=False):
def get_mlm_output(uid: str,
wav_path: str,
prefix: str="./prompt/dev/",
model_name: str="conformer",
source_lang: str="english",
target_lang: str="english",
old_str: str="",
new_str: str="",
duration_preditor_path: str=None,
sid: str=None,
decoder: bool=False,
use_teacher_forcing: bool=False,
duration_adjust: bool=True,
start_end_sp: bool=False):
mlm_model, train_args = load_model(model_name)
mlm_model.eval()
processor = None
collate_fn = build_collate_fn(train_args, False)
return decode_with_model(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
mlm_model,
processor,
collate_fn,
wav_path,
old_str,
new_str,
duration_preditor_path,
uid=uid,
prefix=prefix,
source_lang=source_lang,
target_lang=target_lang,
mlm_model=mlm_model,
processor=processor,
collate_fn=collate_fn,
wav_path=wav_path,
old_str=old_str,
new_str=new_str,
duration_preditor_path=duration_preditor_path,
sid=sid,
decoder=decoder,
use_teacher_forcing=use_teacher_forcing,
......@@ -1033,23 +976,20 @@ def get_mlm_output(uid,
train_args=train_args)
def test_vctk(uid,
clone_uid,
clone_prefix,
source_language,
target_language,
vocoder,
prefix='dump/raw/dev',
model_name="conformer",
old_str="",
new_str="",
prompt_decoding=False,
dynamic_eval=(0, 0),
task_name=None):
def evaluate(uid: str,
source_lang: str="english",
target_lang: str="english",
use_pt_vocoder: bool=False,
prefix: str="./prompt/dev/",
model_name: str="conformer",
old_str: str="",
new_str: str="",
prompt_decoding: bool=False,
task_name: str=None):
duration_preditor_path = None
spemd = None
full_origin_str, wav_path = read_data(uid, prefix)
full_origin_str, wav_path = read_data(uid=uid, prefix=prefix)
if task_name == 'edit':
new_str = new_str
......@@ -1065,19 +1005,17 @@ def test_vctk(uid,
old_str = full_origin_str
results_dict, old_span = plot_mel_and_vocode_wav(
uid,
prefix,
clone_uid,
clone_prefix,
source_language,
target_language,
model_name,
wav_path,
full_origin_str,
old_str,
new_str,
vocoder,
duration_preditor_path,
uid=uid,
prefix=prefix,
source_lang=source_lang,
target_lang=target_lang,
model_name=model_name,
wav_path=wav_path,
full_origin_str=full_origin_str,
old_str=old_str,
new_str=new_str,
use_pt_vocoder=use_pt_vocoder,
duration_preditor_path=duration_preditor_path,
sid=spemd)
return results_dict
......@@ -1086,17 +1024,14 @@ if __name__ == "__main__":
# parse config and args
args = parse_args()
data_dict = test_vctk(
args.uid,
args.clone_uid,
args.clone_prefix,
args.source_language,
args.target_language,
args.use_pt_vocoder,
args.prefix,
args.model_name,
data_dict = evaluate(
uid=args.uid,
source_lang=args.source_lang,
target_lang=args.target_lang,
use_pt_vocoder=args.use_pt_vocoder,
prefix=args.prefix,
model_name=args.model_name,
new_str=args.new_str,
task_name=args.task_name)
sf.write(args.output_name, data_dict['output'], samplerate=24000)
print("finished...")
# exit()
......@@ -121,12 +121,10 @@ class NewMaskInputLayer(nn.Layer):
default_initializer=paddle.nn.initializer.Assign(
paddle.normal(shape=(1, 1, out_features))))
def forward(self, input: paddle.Tensor,
masked_position=None) -> paddle.Tensor:
masked_position = paddle.expand_as(
paddle.unsqueeze(masked_position, -1), input)
masked_input = masked_fill(input, masked_position, 0) + masked_fill(
paddle.expand_as(self.mask_feature, input), ~masked_position, 0)
def forward(self, input: paddle.Tensor, masked_pos=None) -> paddle.Tensor:
masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
return masked_input
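The two masked_fill calls above implement a per-frame select between the original input and the learned mask vector; an equivalent sketch using paddle.where (assumed shapes noted in the comments):

import paddle

def mask_frames(feat, masked_pos, mask_vec):
    # feat: (B, T, D) float, masked_pos: (B, T) bool, mask_vec: (1, 1, D)
    pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), feat)
    # keep original frames where pos is False, learned vector where True
    return paddle.where(pos, paddle.expand_as(mask_vec, feat), feat)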
......@@ -443,37 +441,34 @@ class MLMEncoder(nn.Layer):
def forward(self,
speech_pad,
text_pad,
masked_position,
masked_pos,
speech_mask=None,
text_mask=None,
speech_segment_pos=None,
text_segment_pos=None):
speech_seg_pos=None,
text_seg_pos=None):
"""Encode input sequence.
"""
if masked_position is not None:
speech_pad = self.speech_embed(speech_pad, masked_position)
if masked_pos is not None:
speech_pad = self.speech_embed(speech_pad, masked_pos)
else:
speech_pad = self.speech_embed(speech_pad)
# pure speech input
if -2 in np.array(text_pad):
text_pad = text_pad + 3
text_mask = paddle.unsqueeze(bool(text_pad), 1)
text_segment_pos = paddle.zeros_like(text_pad)
text_seg_pos = paddle.zeros_like(text_pad)
text_pad = self.text_embed(text_pad)
text_pad = (text_pad[0] + self.segment_emb(text_segment_pos),
text_pad = (text_pad[0] + self.segment_emb(text_seg_pos),
text_pad[1])
text_segment_pos = None
text_seg_pos = None
elif text_pad is not None:
text_pad = self.text_embed(text_pad)
segment_emb = None
if speech_segment_pos is not None and text_segment_pos is not None and self.segment_emb:
speech_segment_emb = self.segment_emb(speech_segment_pos)
text_segment_emb = self.segment_emb(text_segment_pos)
text_pad = (text_pad[0] + text_segment_emb, text_pad[1])
speech_pad = (speech_pad[0] + speech_segment_emb, speech_pad[1])
segment_emb = paddle.concat(
[speech_segment_emb, text_segment_emb], axis=1)
if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
speech_seg_emb = self.segment_emb(speech_seg_pos)
text_seg_emb = self.segment_emb(text_seg_pos)
text_pad = (text_pad[0] + text_seg_emb, text_pad[1])
speech_pad = (speech_pad[0] + speech_seg_emb, speech_pad[1])
if self.pre_speech_encoders:
speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask)
......@@ -493,11 +488,11 @@ class MLMEncoder(nn.Layer):
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks #, segment_emb
return xs, masks
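For readers unfamiliar with sega (segment) embeddings: each position carries a segment id whose embedding is added onto the content embedding before the joint speech+text sequence enters the encoder. A self-contained illustration (all dimensions below are made up):

import paddle
import paddle.nn as nn

seg_emb = nn.Embedding(512, 256)                              # shared table
speech_seg_pos = paddle.zeros([1, 100], dtype=paddle.int64)   # one id per frame
text_seg_pos = paddle.ones([1, 20], dtype=paddle.int64)       # one id per phone
speech_x = paddle.randn([1, 100, 256]) + seg_emb(speech_seg_pos)
text_x = paddle.randn([1, 20, 256]) + seg_emb(text_seg_pos)
xs = paddle.concat([speech_x, text_x], axis=1)                # joint encoder input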
class MLMDecoder(MLMEncoder):
def forward(self, xs, masks, masked_position=None, segment_emb=None):
def forward(self, xs, masks, masked_pos=None, segment_emb=None):
"""Encode input sequence.
Args:
......@@ -509,9 +504,8 @@ class MLMDecoder(MLMEncoder):
paddle.Tensor: Mask tensor (#batch, time).
"""
emb, mlm_position = None, None
if not self.training:
masked_position = None
masked_pos = None
xs = self.embed(xs)
if segment_emb:
xs = (xs[0] + segment_emb, xs[1])
......@@ -632,18 +626,18 @@ class MLMModel(nn.Layer):
def collect_feats(self,
speech,
speech_lengths,
speech_lens,
text,
text_lengths,
masked_position,
text_lens,
masked_pos,
speech_mask,
text_mask,
speech_segment_pos,
text_segment_pos,
speech_seg_pos,
text_seg_pos,
y_masks=None) -> Dict[str, paddle.Tensor]:
return {"feats": speech, "feats_lengths": speech_lengths}
return {"feats": speech, "feats_lens": speech_lens}
def forward(self, batch, speech_segment_pos, y_masks=None):
def forward(self, batch, speech_seg_pos, y_masks=None):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder = batch['speech_pad']
......@@ -654,7 +648,7 @@ class MLMModel(nn.Layer):
if self.decoder is not None:
zs, _ = self.decoder(ys_in, y_masks, encoder_out,
bool(h_masks),
self.encoder.segment_emb(speech_segment_pos))
self.encoder.segment_emb(speech_seg_pos))
speech_hidden_states = zs
else:
speech_hidden_states = encoder_out[:, :paddle.shape(batch[
......@@ -672,21 +666,21 @@ class MLMModel(nn.Layer):
else:
after_outs = None
return before_outs, after_outs, speech_pad_placeholder, batch[
'masked_position']
'masked_pos']
def inference(
self,
speech,
text,
masked_position,
masked_pos,
speech_mask,
text_mask,
speech_segment_pos,
text_segment_pos,
span_boundary,
speech_seg_pos,
text_seg_pos,
span_bdy,
y_masks=None,
speech_lengths=None,
text_lengths=None,
speech_lens=None,
text_lens=None,
feats: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
sids: Optional[paddle.Tensor]=None,
......@@ -699,24 +693,24 @@ class MLMModel(nn.Layer):
batch = dict(
speech_pad=speech,
text_pad=text,
masked_position=masked_position,
masked_pos=masked_pos,
speech_mask=speech_mask,
text_mask=text_mask,
speech_segment_pos=speech_segment_pos,
text_segment_pos=text_segment_pos, )
speech_seg_pos=speech_seg_pos,
text_seg_pos=text_seg_pos, )
# # inference with teacher forcing
# hs, h_masks = self.encoder(**batch)
outs = [batch['speech_pad'][:, :span_boundary[0]]]
outs = [batch['speech_pad'][:, :span_bdy[0]]]
z_cache = None
if use_teacher_forcing:
before, zs, _, _ = self.forward(
batch, speech_segment_pos, y_masks=y_masks)
batch, speech_seg_pos, y_masks=y_masks)
if zs is None:
zs = before
outs += [zs[0][span_boundary[0]:span_boundary[1]]]
outs += [batch['speech_pad'][:, span_boundary[1]:]]
outs += [zs[0][span_bdy[0]:span_bdy[1]]]
outs += [batch['speech_pad'][:, span_bdy[1]:]]
return dict(feat_gen=outs)
return None
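The inference path above regenerates only the masked span and splices it between the untouched prefix and suffix; conceptually (a hypothetical helper, not in the patch):

import paddle

def splice_span(orig_feats, gen_feats, span_bdy):
    # orig_feats: (1, T, D) original mel, gen_feats: (T_span, D) regenerated
    # frames, span_bdy: (start, end) frame indices of the edited region
    start, end = span_bdy
    return paddle.concat(
        [orig_feats[0, :start], gen_feats, orig_feats[0, end:]], axis=0)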
......@@ -733,7 +727,7 @@ class MLMModel(nn.Layer):
class MLMEncAsDecoderModel(MLMModel):
def forward(self, batch, speech_segment_pos, y_masks=None):
def forward(self, batch, speech_seg_pos, y_masks=None):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder = batch['speech_pad']
......@@ -756,7 +750,7 @@ class MLMEncAsDecoderModel(MLMModel):
else:
after_outs = None
return before_outs, after_outs, speech_pad_placeholder, batch[
'masked_position']
'masked_pos']
class MLMDualMaksingModel(MLMModel):
......@@ -767,9 +761,9 @@ class MLMDualMaksingModel(MLMModel):
batch):
xs_pad = batch['speech_pad']
text_pad = batch['text_pad']
masked_position = batch['masked_position']
text_masked_position = batch['text_masked_position']
mlm_loss_position = masked_position > 0
masked_pos = batch['masked_pos']
text_masked_pos = batch['text_masked_pos']
mlm_loss_pos = masked_pos > 0
loss = paddle.sum(
self.l1_loss_func(
paddle.reshape(before_outs, (-1, self.odim)),
......@@ -782,19 +776,17 @@ class MLMDualMaksingModel(MLMModel):
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
loss_mlm = paddle.sum((loss * paddle.reshape(
mlm_loss_position, [-1]))) / paddle.sum((mlm_loss_position) + 1e-10)
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
loss_text = paddle.sum((self.text_mlm_loss(
paddle.reshape(text_outs, (-1, self.vocab_size)),
paddle.reshape(text_pad, (-1))) * paddle.reshape(
text_masked_position,
(-1)))) / paddle.sum((text_masked_position) + 1e-10)
text_masked_pos, (-1)))) / paddle.sum((text_masked_pos) + 1e-10)
return loss_mlm, loss_text
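Both losses above are masked averages: the per-position loss is summed only over masked positions and divided by their count, with a small epsilon guarding the case where nothing was masked. A generic sketch of that reduction:

import paddle

def masked_mean(per_pos_loss, mask, eps=1e-10):
    # per_pos_loss and mask must flatten to the same length
    mask = paddle.cast(paddle.reshape(mask, [-1]), per_pos_loss.dtype)
    return paddle.sum(paddle.reshape(per_pos_loss, [-1]) * mask) / (
        paddle.sum(mask) + eps)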
def forward(self, batch, speech_segment_pos, y_masks=None):
def forward(self, batch, speech_seg_pos, y_masks=None):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
speech_pad_placeholder = batch['speech_pad']
encoder_out, h_masks = self.encoder(**batch) # segment_emb
if self.decoder is not None:
zs, _ = self.decoder(encoder_out, h_masks)
......@@ -819,7 +811,7 @@ class MLMDualMaksingModel(MLMModel):
[0, 2, 1])
else:
after_outs = None
return before_outs, after_outs, text_outs, None #, speech_pad_placeholder, batch['masked_position'],batch['text_masked_position']
return before_outs, after_outs, text_outs, None #, speech_pad_placeholder, batch['masked_pos'],batch['text_masked_pos']
def build_model_from_file(config_file, model_file):
......
......@@ -38,7 +38,7 @@ def pad_list(xs, pad_value):
"""
n_batch = len(xs)
max_len = max(x.shape[0] for x in xs)
pad = paddle.full([n_batch, max_len, *xs[0].shape[1:]], pad_value)
pad = paddle.full([n_batch, max_len, *xs[0].shape[1:]], pad_value, dtype=xs[0].dtype)
for i in range(n_batch):
pad[i, :xs[i].shape[0]] = xs[i]
......@@ -46,11 +46,16 @@ def pad_list(xs, pad_value):
return pad
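Quick usage check for pad_list (with the change above, the pad buffer's dtype now follows the inputs, so integer id sequences stay integral):

import paddle

xs = [paddle.to_tensor([1, 2, 3]), paddle.to_tensor([4, 5])]
padded = pad_list(xs, 0)   # shape (2, 3), int64, second row padded with 0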
def make_pad_mask(lengths, length_dim=-1):
def make_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of padded part.
Args:
lengths (Tensor(int64)): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor.
See the example.
Returns:
Tensor(bool): Mask tensor containing indices of padded part bool.
......@@ -63,21 +68,96 @@ def make_pad_mask(lengths, length_dim=-1):
masks = [[0, 0, 0, 0, 0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = paddle.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]])
>>> xs = paddle.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]])
With the reference tensor and dimension indicator.
>>> xs = paddle.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]])
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]])
"""
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
bs = paddle.shape(lengths)[0]
if xs is None:
maxlen = lengths.max()
else:
maxlen = paddle.shape(xs)[length_dim]
seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
if xs is not None:
assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)
if length_dim < 0:
length_dim = len(paddle.shape(xs)) + length_dim
# ind = (:, None, ..., None, :, None, ..., None)
ind = tuple(
slice(None) if i in (0, length_dim) else None
for i in range(len(paddle.shape(xs))))
mask = paddle.expand(mask[ind], paddle.shape(xs))
return mask
def make_non_pad_mask(lengths, length_dim=-1):
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of non-padded part.
Args:
......@@ -98,8 +178,70 @@ def make_non_pad_mask(lengths, length_dim=-1):
masks = [[1, 1, 1, 1, 1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = paddle.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]])
>>> xs = paddle.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]])
With the reference tensor and dimension indicator.
>>> xs = paddle.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]])
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]])
"""
return paddle.logical_not(make_pad_mask(lengths, length_dim))
return paddle.logical_not(make_pad_mask(lengths, xs, length_dim))
def initialize(model: nn.Layer, init: str):
......
......@@ -10,8 +10,8 @@ python inference.py \
--uid=Prompt_003_new \
--new_str='今天天气很好.' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=chinese \
--source_lang=english \
--target_lang=chinese \
--output_name=pred_clone.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
......
......@@ -9,8 +9,8 @@ python inference.py \
--uid=p299_096 \
--new_str='I enjoy my life, do you?' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=english \
--source_lang=english \
--target_lang=english \
--output_name=pred_gen.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
......
......@@ -10,8 +10,8 @@ python inference.py \
--uid=p243_new \
--new_str='for that reason cover is impossible to be given.' \
--prefix='./prompt/dev/' \
--source_language=english \
--target_language=english \
--source_lang=english \
--target_lang=english \
--output_name=pred_edit.wav \
--use_pt_vocoder=False \
--voc=pwgan_aishell3 \
......
......@@ -80,10 +80,8 @@ def parse_args():
parser.add_argument("--uid", type=str, help="uid")
parser.add_argument("--new_str", type=str, help="new string")
parser.add_argument("--prefix", type=str, help="prefix")
parser.add_argument("--clone_prefix", type=str, default=None, help="clone prefix")
parser.add_argument("--clone_uid", type=str, default=None, help="clone uid")
parser.add_argument("--source_language", type=str, help="source language")
parser.add_argument("--target_language", type=str, help="target language")
parser.add_argument("--source_lang", type=str, default="english", help="source language")
parser.add_argument("--target_lang", type=str, default="english", help="target language")
parser.add_argument("--output_name", type=str, help="output name")
parser.add_argument("--task_name", type=str, help="task name")
parser.add_argument(
......
......@@ -9,7 +9,7 @@ import torch
import yaml
class ParallelWaveGANPretrainedVocoder(torch.nn.Module):
class TorchPWGAN(torch.nn.Module):
"""Wrapper class to load the vocoder trained with parallel_wavegan repo."""
def __init__(
......
import os
from typing import List
from typing import Optional
import numpy as np
import paddle
import yaml
......@@ -5,11 +9,8 @@ from sedit_arg_parser import parse_args
from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.modules.normalizer import ZScore
from tools.parallel_wavegan_pretrained_vocoder import ParallelWaveGANPretrainedVocoder
# new add
from tools.torch_pwgan import TorchPWGAN
model_alias = {
# acoustic model
......@@ -25,6 +26,10 @@ model_alias = {
"paddlespeech.t2s.models.tacotron2:Tacotron2",
"tacotron2_inference":
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
"pwgan_inference":
"paddlespeech.t2s.models.parallel_wavegan:PWGInference",
}
......@@ -43,60 +48,65 @@ def build_vocoder_from_file(
# Build vocoder
if str(vocoder_file).endswith(".pkl"):
# If the extension is ".pkl", the model is trained with parallel_wavegan
vocoder = ParallelWaveGANPretrainedVocoder(vocoder_file,
vocoder_config_file)
vocoder = TorchPWGAN(vocoder_file, vocoder_config_file)
return vocoder.to(device)
else:
raise ValueError(f"{vocoder_file} is not supported format.")
def get_voc_out(mel, target_language="chinese"):
def get_voc_out(mel, target_lang: str="chinese"):
# vocoder
args = parse_args()
assert target_language == "chinese" or target_language == "english", "In get_voc_out function, target_language is illegal..."
assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..."
# print("current vocoder: ", args.voc)
with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f))
# print(voc_config)
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
mel = paddle.to_tensor(mel)
# print("masked_mel: ", mel.shape)
with paddle.no_grad():
wav = voc_inference(mel)
# print("shepe of wav (time x n_channels):%s"%wav.shape)
return np.squeeze(wav)
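Example call (the mel below is random and purely illustrative; the actual vocoder is selected by the CLI flags --voc, --voc_config, --voc_ckpt and --voc_stat parsed inside):

import numpy as np

mel = np.random.randn(200, 80).astype(np.float32)  # (n_frames, n_mels)
wav = get_voc_out(mel, target_lang="chinese")      # 1-D waveform array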
# dygraph
def get_am_inference(args, am_config):
with open(args.phones_dict, "r") as f:
def get_am_inference(am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
return_am: bool=False):
with open(phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
# print("vocab_size:", vocab_size)
print("vocab_size:", vocab_size)
tone_size = None
if 'tones_dict' in args and args.tones_dict:
with open(args.tones_dict, "r") as f:
if tones_dict is not None:
with open(tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id)
print("tone_size:", tone_size)
spk_num = None
if 'speaker_dict' in args and args.speaker_dict:
with open(args.speaker_dict, 'rt') as f:
if speaker_dict is not None:
with open(speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
odim = am_config.n_mels
# model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias)
......@@ -113,39 +123,61 @@ def get_am_inference(args, am_config):
elif am_name == 'tacotron2':
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
am.set_state_dict(paddle.load(am_ckpt)["main_params"])
am.eval()
am_mu, am_std = np.load(args.am_stat)
am_mu, am_std = np.load(am_stat)
am_mu = paddle.to_tensor(am_mu)
am_std = paddle.to_tensor(am_std)
am_normalizer = ZScore(am_mu, am_std)
am_inference = am_inference_class(am_normalizer, am)
am_inference.eval()
print("acoustic model done!")
return am, am_inference, am_name, am_dataset, phn_id
if return_am:
return am_inference, am
else:
return am_inference
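Example construction using the Chinese checkpoint paths that appear later in this patch (a hypothetical call, assuming am_config was already loaded as a CfgNode):

am_inference = get_am_inference(
    am="fastspeech2_csmsc",
    am_config=am_config,
    am_ckpt="download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz",
    am_stat="download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy",
    phones_dict="download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt")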
def evaluate_durations(phns,
target_language="chinese",
fs=24000,
hop_length=300):
def get_voc_inference(
voc: str='pwgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None, ):
# model: {model_name}_{dataset}
voc_name = voc[:voc.rindex('_')]
voc_class = dynamic_import(voc_name, model_alias)
voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
if voc_name != 'wavernn':
voc = voc_class(**voc_config["generator_params"])
voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"])
voc.remove_weight_norm()
voc.eval()
else:
voc = voc_class(**voc_config["model"])
voc.set_state_dict(paddle.load(voc_ckpt)["main_params"])
voc.eval()
voc_mu, voc_std = np.load(voc_stat)
voc_mu = paddle.to_tensor(voc_mu)
voc_std = paddle.to_tensor(voc_std)
voc_normalizer = ZScore(voc_mu, voc_std)
voc_inference = voc_inference_class(voc_normalizer, voc)
voc_inference.eval()
print("voc done!")
return voc_inference
def evaluate_durations(phns: List[str],
target_lang: str="chinese",
fs: int=24000,
hop_length: int=300):
args = parse_args()
if target_language == 'english':
if target_lang == 'english':
args.lang = 'en'
args.am = "fastspeech2_ljspeech"
args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_language == 'chinese':
elif target_lang == 'chinese':
args.lang = 'zh'
args.am = "fastspeech2_csmsc"
args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
# args = parser.parse_args(args=[])
if args.ngpu == 0:
......@@ -155,23 +187,28 @@ def evaluate_durations(phns,
else:
print("ngpu should >= 0 !")
assert target_language == "chinese" or target_language == "english", "In evaluate_durations function, target_language is illegal..."
assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..."
# Init body.
with open(args.am_config) as f:
am_config = CfgNode(yaml.safe_load(f))
# print("========Config========")
# print(am_config)
# print("---------------------")
# acoustic model
am, am_inference, am_name, am_dataset, phn_id = get_am_inference(args,
am_config)
am_inference, am = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict,
return_am=True)
torch_phns = phns
vocab_phones = {}
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
for tone, id in phn_id:
vocab_phones[tone] = int(id)
# print("vocab_phones: ", len(vocab_phones))
vocab_size = len(vocab_phones)
phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns]
......@@ -185,59 +222,3 @@ def evaluate_durations(phns,
phoneme_durations_new = pre_d_outs * hop_length / fs
phoneme_durations_new = phoneme_durations_new.tolist()[:-1]
return phoneme_durations_new
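Example (the phone sequence is illustrative; real inputs come from the aligner/frontend):

phns = ["sp", "n", "i3", "h", "ao3", "sp"]
durs = evaluate_durations(phns, target_lang="chinese", fs=24000, hop_length=300)
# durs: one duration in seconds per phone, trailing entry dropped as above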
def sentence2phns(sentence, target_language="en"):
args = parse_args()
if target_language == 'en':
args.lang = 'en'
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_language == 'zh':
args.lang = 'zh'
args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
else:
print("target_language should in {'zh', 'en'}!")
frontend = get_frontend(args)
merge_sentences = True
get_tone_ids = False
if target_language == 'zh':
input_ids = frontend.get_input_ids(
sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
print_info=False)
phone_ids = input_ids["phone_ids"]
phonemes = frontend.get_phonemes(
sentence, merge_sentences=merge_sentences, print_info=False)
return phonemes[0], input_ids["phone_ids"][0]
elif target_language == 'en':
phonemes = frontend.phoneticize(sentence)
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
phones_list = []
vocab_phones = {}
punc = ":,;。?!“”‘’':,;.?!"
with open(args.phones_dict, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
vocab_phones[phn] = int(id)
phones = phonemes[1:-1]
phones = [phn for phn in phones if not phn.isspace()]
# replace unk phone with sp
phones = [
phn if (phn in vocab_phones and phn not in punc) else "sp"
for phn in phones
]
phones_list.append(phones)
return phones_list[0], input_ids["phone_ids"][0]
else:
print("lang should in {'zh', 'en'}!")