diff --git a/ernie-sat/README.md b/ernie-sat/README.md index db8cc07504d45a143557476bc800fcfc075ed706..bfecccc5500431ad515802a430987a71f1bb89be 100644 --- a/ernie-sat/README.md +++ b/ernie-sat/README.md @@ -113,8 +113,8 @@ prompt/dev 8. ` --uid` 特定提示(prompt)语音的 id 9. ` --new_str` 输入的文本(本次开源暂时先设置特定的文本) 10. ` --prefix` 特定音频对应的文本、音素相关文件的地址 -11. ` --source_language` , 源语言 -12. ` --target_language` , 目标语言 +11. ` --source_lang` , 源语言 +12. ` --target_lang` , 目标语言 13. ` --output_name` , 合成语音名称 14. ` --task_name` , 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务 15. ` --use_pt_vocoder`, 英文场景下是否使用 torch 版本的 vocoder, 默认情况下为 False; 设置为 False 则在英文场景下使用 paddle 版本 vocoder diff --git a/ernie-sat/align_english.py b/ernie-sat/align.py similarity index 59% rename from ernie-sat/align_english.py rename to ernie-sat/align.py index aff47fe0514e6f03090cbed9dddaf961537eb186..5c7144f439d887fac1b7763df8f4376cd2ec73a1 100755 --- a/ernie-sat/align_english.py +++ b/ernie-sat/align.py @@ -1,6 +1,6 @@ #!/usr/bin/env python """ Usage: - align_english.py wavfile trsfile outwordfile outphonefile + align.py wavfile trsfile outwordfile outphonefile """ import multiprocessing as mp import os @@ -9,12 +9,45 @@ import sys from tqdm import tqdm PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme' -MODEL_DIR = 'tools/aligner/english' +MODEL_DIR_EN = 'tools/aligner/english' +MODEL_DIR_ZH = 'tools/aligner/mandarin' HVITE = 'tools/htk/HTKTools/HVite' HCOPY = 'tools/htk/HTKTools/HCopy' -def prep_txt(line, tmpbase, dictfile): +def prep_txt_zh(line: str, tmpbase: str, dictfile: str): + + words = [] + line = line.strip() + for pun in [ + ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', + u'。', u':', u';', u'!', u'?', u'(', u')' + ]: + line = line.replace(pun, ' ') + for wrd in line.split(): + if (wrd[-1] == '-'): + wrd = wrd[:-1] + if (wrd[0] == "'"): + wrd = wrd[1:] + if wrd: + words.append(wrd) + + ds = set([]) + with open(dictfile, 'r') as fid: + for line in fid: + ds.add(line.split()[0]) + + unk_words = set([]) + with open(tmpbase + '.txt', 'w') as fwid: + for wrd in words: + if (wrd not in ds): + unk_words.add(wrd) + fwid.write(wrd + ' ') + fwid.write('\n') + return unk_words + + +def prep_txt_en(line: str, tmpbase, dictfile): words = [] @@ -97,7 +130,7 @@ def prep_txt(line, tmpbase, dictfile): fw.close() -def prep_mlf(txt, tmpbase): +def prep_mlf(txt: str, tmpbase: str): with open(tmpbase + '.mlf', 'w') as fwid: fwid.write('#!MLF!#\n') @@ -110,7 +143,55 @@ def prep_mlf(txt, tmpbase): fwid.write('.\n') -def gen_res(tmpbase, outfile1, outfile2): +def _get_user(): + return os.path.expanduser('~').split("/")[-1] + + +def alignment(wav_path: str, text: str): + tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) + + #prepare wav and trs files + try: + os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -') + except: + print('sox error!') + return None + + #prepare clean_transcript file + try: + prep_txt_en(text, tmpbase, MODEL_DIR_EN + '/dict') + except: + print('prep_txt error!') + return None + + #prepare mlf file + try: + with open(tmpbase + '.txt', 'r') as fid: + txt = fid.readline() + prep_mlf(txt, tmpbase) + except: + print('prep_mlf error!') + return None + + #prepare scp + try: + os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase + + '.wav' + ' ' + tmpbase + '.plp') + except: + print('HCopy error!') + return None + + #run alignment + try: + os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + + '.mlf -H ' + MODEL_DIR_EN + '/16000/macros -H ' + MODEL_DIR_EN + + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase + + '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase + + '.plp 2>&1 > /dev/null') + except: + print('HVite error!') + return None + with open(tmpbase + '.txt', 'r') as fid: words = fid.readline().strip().split() words = txt.strip().split() @@ -119,59 +200,47 @@ def gen_res(tmpbase, outfile1, outfile2): with open(tmpbase + '.aligned', 'r') as fid: lines = fid.readlines() i = 2 - times1 = [] times2 = [] + word2phns = {} + current_word = '' + index = 0 while (i < len(lines)): - if (len(lines[i].split()) >= 4) and ( - lines[i].split()[0] != lines[i].split()[1]): - phn = lines[i].split()[2] - pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000 - pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000 + splited_line = lines[i].strip().split() + if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]): + phn = splited_line[2] + pst = (int(splited_line[0]) / 1000 + 125) / 10000 + pen = (int(splited_line[1]) / 1000 + 125) / 10000 times2.append([phn, pst, pen]) - if (len(lines[i].split()) == 5): - if (lines[i].split()[0] != lines[i].split()[1]): - wrd = lines[i].split()[-1].strip() - st = (int(lines[i].split()[0]) / 1000 + 125) / 10000 - j = i + 1 - while (lines[j] != '.\n') and (len(lines[j].split()) != 5): - j += 1 - en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000 - times1.append([wrd, st, en]) + # splited_line[-1]!='sp' + if len(splited_line) == 5: + current_word = str(index) + '_' + splited_line[-1] + word2phns[current_word] = phn + index += 1 + elif len(splited_line) == 4: + word2phns[current_word] += ' ' + phn i += 1 - - with open(outfile1, 'w') as fwid: - for item in times1: - if (item[0] == 'sp'): - fwid.write(str(item[1]) + ' ' + str(item[2]) + ' SIL\n') - else: - wrd = words.pop() - fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + wrd + '\n') - if words: - print('not matched::' + alignfile) - sys.exit(1) - - with open(outfile2, 'w') as fwid: - for item in times2: - fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n') - - -def _get_user(): - return os.path.expanduser('~').split("/")[-1] + return times2, word2phns -def alignment(wav_path, text_string): +def alignment_zh(wav_path, text_string): tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) #prepare wav and trs files try: - os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -') + os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase + + '.wav remix -') + except: print('sox error!') return None #prepare clean_transcript file try: - prep_txt(text_string, tmpbase, MODEL_DIR + '/dict') + unk_words = prep_txt_zh(text_string, tmpbase, MODEL_DIR_ZH + '/dict') + if unk_words: + print('Error! Please add the following words to dictionary:') + for unk in unk_words: + print("非法words: ", unk) except: print('prep_txt error!') return None @@ -187,7 +256,7 @@ def alignment(wav_path, text_string): #prepare scp try: - os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase + + os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase + '.wav' + ' ' + tmpbase + '.plp') except: print('HCopy error!') @@ -196,10 +265,11 @@ def alignment(wav_path, text_string): #run alignment try: os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + - '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR + - '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase + - '.dict ' + MODEL_DIR + '/monophones ' + tmpbase + + '.mlf -H ' + MODEL_DIR_ZH + '/16000/macros -H ' + MODEL_DIR_ZH + + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR_ZH + + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase + '.plp 2>&1 > /dev/null') + except: print('HVite error!') return None @@ -211,6 +281,7 @@ def alignment(wav_path, text_string): with open(tmpbase + '.aligned', 'r') as fid: lines = fid.readlines() + i = 2 times2 = [] word2phns = {} diff --git a/ernie-sat/align_mandarin.py b/ernie-sat/align_mandarin.py deleted file mode 100755 index fae2a2aea5d66f0853b003b6e5eae81488ff04a2..0000000000000000000000000000000000000000 --- a/ernie-sat/align_mandarin.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python -""" Usage: - align_mandarin.py wavfile trsfile outwordfile putphonefile -""" -import multiprocessing as mp -import os -import sys - -from tqdm import tqdm - -MODEL_DIR = 'tools/aligner/mandarin' -HVITE = 'tools/htk/HTKTools/HVite' -HCOPY = 'tools/htk/HTKTools/HCopy' - - -def prep_txt(line, tmpbase, dictfile): - - words = [] - line = line.strip() - for pun in [ - ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', - u'。', u':', u';', u'!', u'?', u'(', u')' - ]: - line = line.replace(pun, ' ') - for wrd in line.split(): - if (wrd[-1] == '-'): - wrd = wrd[:-1] - if (wrd[0] == "'"): - wrd = wrd[1:] - if wrd: - words.append(wrd) - - ds = set([]) - with open(dictfile, 'r') as fid: - for line in fid: - ds.add(line.split()[0]) - - unk_words = set([]) - with open(tmpbase + '.txt', 'w') as fwid: - for wrd in words: - if (wrd not in ds): - unk_words.add(wrd) - fwid.write(wrd + ' ') - fwid.write('\n') - return unk_words - - -def prep_mlf(txt, tmpbase): - - with open(tmpbase + '.mlf', 'w') as fwid: - fwid.write('#!MLF!#\n') - fwid.write('"' + tmpbase + '.lab"\n') - fwid.write('sp\n') - wrds = txt.split() - for wrd in wrds: - fwid.write(wrd.upper() + '\n') - fwid.write('sp\n') - fwid.write('.\n') - - -def gen_res(tmpbase, outfile1, outfile2): - with open(tmpbase + '.txt', 'r') as fid: - words = fid.readline().strip().split() - words = txt.strip().split() - words.reverse() - - with open(tmpbase + '.aligned', 'r') as fid: - lines = fid.readlines() - i = 2 - times1 = [] - times2 = [] - while (i < len(lines)): - if (len(lines[i].split()) >= 4) and ( - lines[i].split()[0] != lines[i].split()[1]): - phn = lines[i].split()[2] - pst = (int(lines[i].split()[0]) / 1000 + 125) / 10000 - pen = (int(lines[i].split()[1]) / 1000 + 125) / 10000 - times2.append([phn, pst, pen]) - if (len(lines[i].split()) == 5): - if (lines[i].split()[0] != lines[i].split()[1]): - wrd = lines[i].split()[-1].strip() - st = (int(lines[i].split()[0]) / 1000 + 125) / 10000 - j = i + 1 - while (lines[j] != '.\n') and (len(lines[j].split()) != 5): - j += 1 - en = (int(lines[j - 1].split()[1]) / 1000 + 125) / 10000 - times1.append([wrd, st, en]) - i += 1 - - with open(outfile1, 'w') as fwid: - for item in times1: - if (item[0] == 'sp'): - fwid.write(str(item[1]) + ' ' + str(item[2]) + ' SIL\n') - else: - wrd = words.pop() - fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + wrd + '\n') - if words: - print('not matched::' + alignfile) - sys.exit(1) - - with open(outfile2, 'w') as fwid: - for item in times2: - fwid.write(str(item[1]) + ' ' + str(item[2]) + ' ' + item[0] + '\n') - - -def alignment_zh(wav_path, text_string): - tmpbase = '/tmp/' + os.environ['USER'] + '_' + str(os.getpid()) - - #prepare wav and trs files - try: - os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase + - '.wav remix -') - - except: - print('sox error!') - return None - - #prepare clean_transcript file - try: - unk_words = prep_txt(text_string, tmpbase, MODEL_DIR + '/dict') - if unk_words: - print('Error! Please add the following words to dictionary:') - for unk in unk_words: - print("非法words: ", unk) - except: - print('prep_txt error!') - return None - - #prepare mlf file - try: - with open(tmpbase + '.txt', 'r') as fid: - txt = fid.readline() - prep_mlf(txt, tmpbase) - except: - print('prep_mlf error!') - return None - - #prepare scp - try: - os.system(HCOPY + ' -C ' + MODEL_DIR + '/16000/config ' + tmpbase + - '.wav' + ' ' + tmpbase + '.plp') - except: - print('HCopy error!') - return None - - #run alignment - try: - os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase + - '.mlf -H ' + MODEL_DIR + '/16000/macros -H ' + MODEL_DIR + - '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR + - '/dict ' + MODEL_DIR + '/monophones ' + tmpbase + - '.plp 2>&1 > /dev/null') - - except: - print('HVite error!') - return None - - with open(tmpbase + '.txt', 'r') as fid: - words = fid.readline().strip().split() - words = txt.strip().split() - words.reverse() - - with open(tmpbase + '.aligned', 'r') as fid: - lines = fid.readlines() - - i = 2 - times2 = [] - word2phns = {} - current_word = '' - index = 0 - while (i < len(lines)): - splited_line = lines[i].strip().split() - if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]): - phn = splited_line[2] - pst = (int(splited_line[0]) / 1000 + 125) / 10000 - pen = (int(splited_line[1]) / 1000 + 125) / 10000 - times2.append([phn, pst, pen]) - # splited_line[-1]!='sp' - if len(splited_line) == 5: - current_word = str(index) + '_' + splited_line[-1] - word2phns[current_word] = phn - index += 1 - elif len(splited_line) == 4: - word2phns[current_word] += ' ' + phn - i += 1 - return times2, word2phns diff --git a/ernie-sat/dataset.py b/ernie-sat/dataset.py index cb1a9b25296117ad4bf924df012578ff8e7dd8ce..d8b896ada57380809f57d7e7f8329fe3947c92a9 100644 --- a/ernie-sat/dataset.py +++ b/ernie-sat/dataset.py @@ -4,37 +4,180 @@ import numpy as np import paddle -def pad_list(xs, pad_value): - """Perform padding for the list of tensors. +def phones_text_masking(xs_pad: paddle.Tensor, + src_mask: paddle.Tensor, + text_pad: paddle.Tensor, + text_mask: paddle.Tensor, + align_start: paddle.Tensor, + align_end: paddle.Tensor, + align_start_lens: paddle.Tensor, + mlm_prob: float, + mean_phn_span: float, + span_bdy: paddle.Tensor=None): + bz, sent_len, _ = paddle.shape(xs_pad) + masked_pos = paddle.zeros((bz, sent_len)) + _, text_len = paddle.shape(text_pad) + text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5) + text_masked_pos = paddle.zeros((bz, text_len)) + y_masks = None + if mlm_prob == 1.0: + masked_pos += 1 + # y_masks = tril_masks + elif mean_phn_span == 0: + # only speech + length = sent_len + mean_phn_span = min(length * mlm_prob // 3, 50) + masked_phn_idxs = random_spans_noise_mask(length, mlm_prob, + mean_phn_span).nonzero() + masked_pos[:, masked_phn_idxs] = 1 + else: + for idx in range(bz): + if span_bdy is not None: + for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]): + masked_pos[idx, s:e] = 1 + else: + length = align_start_lens[idx] + if length < 2: + continue + masked_phn_idxs = random_spans_noise_mask( + length, mlm_prob, mean_phn_span).nonzero() + unmasked_phn_idxs = list( + set(range(length)) - set(masked_phn_idxs[0].tolist())) + np.random.shuffle(unmasked_phn_idxs) + masked_text_idxs = unmasked_phn_idxs[:text_mask_num_lower] + text_masked_pos[idx][masked_text_idxs] = 1 + masked_start = align_start[idx][masked_phn_idxs].tolist() + masked_end = align_end[idx][masked_phn_idxs].tolist() + for s, e in zip(masked_start, masked_end): + masked_pos[idx, s:e] = 1 + non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]) + masked_pos = masked_pos * non_eos_mask + non_eos_text_mask = paddle.reshape(text_mask, paddle.shape(xs_pad)[:2]) + text_masked_pos = text_masked_pos * non_eos_text_mask + masked_pos = paddle.cast(masked_pos, 'bool') + text_masked_pos = paddle.cast(text_masked_pos, 'bool') + + return masked_pos, text_masked_pos, y_masks + + +def get_seg_pos_reduce_duration( + speech_pad: paddle.Tensor, + text_pad: paddle.Tensor, + align_start: paddle.Tensor, + align_end: paddle.Tensor, + align_start_lens: paddle.Tensor, + sega_emb: bool, + masked_pos: paddle.Tensor, + feats_lens: paddle.Tensor, ): + bz, speech_len, _ = paddle.shape(speech_pad) + text_seg_pos = paddle.zeros(paddle.shape(text_pad)) + speech_seg_pos = paddle.zeros((bz, speech_len), dtype=text_pad.dtype) + + reordered_idx = paddle.zeros((bz, speech_len), dtype=align_start_lens.dtype) + + durations = paddle.ones((bz, speech_len), dtype=align_start_lens.dtype) + max_reduced_length = 0 + if not sega_emb: + return speech_pad, masked_pos, speech_seg_pos, text_seg_pos, durations + for idx in range(bz): + first_idx = [] + last_idx = [] + align_length = align_start_lens[idx] + for j in range(align_length): + s, e = align_start[idx][j], align_end[idx][j] + if j == 0: + if paddle.sum(masked_pos[idx][0:s]) == 0: + first_idx.extend(range(0, s)) + else: + first_idx.extend([0]) + last_idx.extend(range(1, s)) + if paddle.sum(masked_pos[idx][s:e]) == 0: + first_idx.extend(range(s, e)) + else: + first_idx.extend([s]) + last_idx.extend(range(s + 1, e)) + durations[idx][s] = e - s + speech_seg_pos[idx][s:e] = j + 1 + text_seg_pos[idx][j] = j + 1 + max_reduced_length = max( + len(first_idx) + feats_lens[idx] - e, max_reduced_length) + first_idx.extend(range(e, speech_len)) + reordered_idx[idx] = paddle.to_tensor( + (first_idx + last_idx), dtype=align_start_lens.dtype) + feats_lens[idx] = len(first_idx) + reordered_idx = reordered_idx[:, :max_reduced_length] + + return reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens + + +def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float): + """This function is copy of `random_spans_helper + `__ . + Noise mask consisting of random spans of noise tokens. + The number of noise tokens and the number of noise spans and non-noise spans + are determined deterministically as follows: + num_noise_tokens = round(length * noise_density) + num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length) + Spans alternate between non-noise and noise, beginning with non-noise. + Subject to the above restrictions, all masks are equally likely. Args: - xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. - pad_value (float): Value for padding. - + length: an int32 scalar (length of the incoming token sequence) + noise_density: a float - approximate density of output mask + mean_noise_span_length: a number Returns: - Tensor: Padded tensor (B, Tmax, `*`). - - Examples: - >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] - >>> x - [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] - >>> pad_list(x, 0) - tensor([[1., 1., 1., 1.], - [1., 1., 0., 0.], - [1., 0., 0., 0.]]) - + a boolean tensor with shape [length] """ - n_batch = len(xs) - max_len = max(paddle.shape(x)[0] for x in xs) - pad = paddle.full((n_batch, max_len), pad_value, dtype=xs[0].dtype) - - for i in range(n_batch): - pad[i, :paddle.shape(xs[i])[0]] = xs[i] - return pad + orig_length = length + + num_noise_tokens = int(np.round(length * mlm_prob)) + # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. + num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) + num_noise_spans = int(np.round(num_noise_tokens / mean_phn_span)) + + # avoid degeneracy by ensuring positive number of noise spans + num_noise_spans = max(num_noise_spans, 1) + num_nonnoise_tokens = length - num_noise_tokens + + # pick the lengths of the noise spans and the non-noise spans + def _random_seg(num_items, num_segs): + """Partition a sequence of items randomly into non-empty segments. + Args: + num_items: an integer scalar > 0 + num_segs: an integer scalar in [1, num_items] + Returns: + a Tensor with shape [num_segs] containing positive integers that add + up to num_items + """ + mask_idxs = np.arange(num_items - 1) < (num_segs - 1) + np.random.shuffle(mask_idxs) + first_in_seg = np.pad(mask_idxs, [[1, 0]]) + segment_id = np.cumsum(first_in_seg) + # count length of sub segments assuming that list is sorted + _, segment_length = np.unique(segment_id, return_counts=True) + return segment_length + + noise_span_lens = _random_seg(num_noise_tokens, num_noise_spans) + nonnoise_span_lens = _random_seg(num_nonnoise_tokens, num_noise_spans) + + interleaved_span_lens = np.reshape( + np.stack([nonnoise_span_lens, noise_span_lens], axis=1), + [num_noise_spans * 2]) + span_starts = np.cumsum(interleaved_span_lens)[:-1] + span_start_indicator = np.zeros((length, ), dtype=np.int8) + span_start_indicator[span_starts] = True + span_num = np.cumsum(span_start_indicator) + is_noise = np.equal(span_num % 2, 1) + + return is_noise[:orig_length] + + +def pad_to_longformer_att_window(text: paddle.Tensor, + max_len: int, + max_tlen: int, + attention_window: int=0): - -def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window): round = max_len % attention_window if round != 0: max_tlen += (attention_window - round) @@ -48,286 +191,67 @@ def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window): return text_pad, max_tlen -def make_pad_mask(lengths, xs=None, length_dim=-1): - """Make mask tensor containing indices of padded part. - - Args: - lengths (LongTensor or List): Batch of lengths (B,). - xs (Tensor, optional): The reference tensor. - If set, masks will be the same shape as this tensor. - length_dim (int, optional): Dimension indicator of the above tensor. - See the example. - - Returns: - Tensor: Mask tensor containing indices of padded part. - dtype=torch.uint8 in PyTorch 1.2- - dtype=torch.bool in PyTorch 1.2+ (including 1.2) - - Examples: - With only lengths. - - >>> lengths = [5, 3, 2] - >>> make_non_pad_mask(lengths) - masks = [[0, 0, 0, 0 ,0], - [0, 0, 0, 1, 1], - [0, 0, 1, 1, 1]] - - With the reference tensor. - - >>> xs = torch.zeros((3, 2, 4)) - >>> make_pad_mask(lengths, xs) - tensor([[[0, 0, 0, 0], - [0, 0, 0, 0]], - [[0, 0, 0, 1], - [0, 0, 0, 1]], - [[0, 0, 1, 1], - [0, 0, 1, 1]]], dtype=torch.uint8) - >>> xs = torch.zeros((3, 2, 6)) - >>> make_pad_mask(lengths, xs) - tensor([[[0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1]], - [[0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1]], - [[0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) - - With the reference tensor and dimension indicator. - - >>> xs = torch.zeros((3, 6, 6)) - >>> make_pad_mask(lengths, xs, 1) - tensor([[[0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1]], - [[0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1]], - [[0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) - >>> make_pad_mask(lengths, xs, 2) - tensor([[[0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1]], - [[0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1]], - [[0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) - - """ - if length_dim == 0: - raise ValueError("length_dim cannot be 0: {}".format(length_dim)) - - if not isinstance(lengths, list): - lengths = list(lengths) - bs = int(len(lengths)) - if xs is None: - maxlen = int(max(lengths)) - else: - maxlen = paddle.shape(xs)[length_dim] - - seq_range = paddle.arange(0, maxlen, dtype=paddle.int64) - seq_range_expand = paddle.expand( - paddle.unsqueeze(seq_range, 0), (bs, maxlen)) - seq_length_expand = paddle.unsqueeze(paddle.to_tensor(lengths), -1) - mask = seq_range_expand >= seq_length_expand - - if xs is not None: - assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs) - - if length_dim < 0: - length_dim = len(paddle.shape(xs)) + length_dim - # ind = (:, None, ..., None, :, , None, ..., None) - ind = tuple( - slice(None) if i in (0, length_dim) else None - for i in range(len(paddle.shape(xs)))) - mask = paddle.expand(mask[ind], paddle.shape(xs)) - return mask - - -def make_non_pad_mask(lengths, xs=None, length_dim=-1): - """Make mask tensor containing indices of non-padded part. - - Args: - lengths (LongTensor or List): Batch of lengths (B,). - xs (Tensor, optional): The reference tensor. - If set, masks will be the same shape as this tensor. - length_dim (int, optional): Dimension indicator of the above tensor. - See the example. - - Returns: - ByteTensor: mask tensor containing indices of padded part. - dtype=torch.uint8 in PyTorch 1.2- - dtype=torch.bool in PyTorch 1.2+ (including 1.2) - - Examples: - With only lengths. - - >>> lengths = [5, 3, 2] - >>> make_non_pad_mask(lengths) - masks = [[1, 1, 1, 1 ,1], - [1, 1, 1, 0, 0], - [1, 1, 0, 0, 0]] - - With the reference tensor. - - >>> xs = torch.zeros((3, 2, 4)) - >>> make_non_pad_mask(lengths, xs) - tensor([[[1, 1, 1, 1], - [1, 1, 1, 1]], - [[1, 1, 1, 0], - [1, 1, 1, 0]], - [[1, 1, 0, 0], - [1, 1, 0, 0]]], dtype=torch.uint8) - >>> xs = torch.zeros((3, 2, 6)) - >>> make_non_pad_mask(lengths, xs) - tensor([[[1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0]], - [[1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0]], - [[1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) - - With the reference tensor and dimension indicator. - - >>> xs = torch.zeros((3, 6, 6)) - >>> make_non_pad_mask(lengths, xs, 1) - tensor([[[1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0]], - [[1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]], - [[1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) - >>> make_non_pad_mask(lengths, xs, 2) - tensor([[[1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0]], - [[1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0]], - [[1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) - - """ - return ~make_pad_mask(lengths, xs, length_dim) - - -def phones_masking(xs_pad, - src_mask, - align_start, - align_end, - align_start_lengths, - mlm_prob, - mean_phn_span, - span_boundary=None): +def phones_masking(xs_pad: paddle.Tensor, + src_mask: paddle.Tensor, + align_start: paddle.Tensor, + align_end: paddle.Tensor, + align_start_lens: paddle.Tensor, + mlm_prob: float, + mean_phn_span: int, + span_bdy: paddle.Tensor=None): bz, sent_len, _ = paddle.shape(xs_pad) - mask_num_lower = math.ceil(sent_len * mlm_prob) - masked_position = np.zeros((bz, sent_len)) + masked_pos = paddle.zeros((bz, sent_len)) y_masks = None - # y_masks = torch.ones(bz,sent_len,sent_len,device=xs_pad.device,dtype=xs_pad.dtype) - # tril_masks = torch.tril(y_masks) if mlm_prob == 1.0: - masked_position += 1 - # y_masks = tril_masks + masked_pos += 1 elif mean_phn_span == 0: # only speech length = sent_len mean_phn_span = min(length * mlm_prob // 3, 50) - masked_phn_indices = random_spans_noise_mask(length, mlm_prob, - mean_phn_span).nonzero() - masked_position[:, masked_phn_indices] = 1 + masked_phn_idxs = random_spans_noise_mask(length, mlm_prob, + mean_phn_span).nonzero() + masked_pos[:, masked_phn_idxs] = 1 else: for idx in range(bz): - if span_boundary is not None: - for s, e in zip(span_boundary[idx][::2], - span_boundary[idx][1::2]): - masked_position[idx, s:e] = 1 - - # y_masks[idx, :, s:e] = tril_masks[idx, :, s:e] - # y_masks[idx, e:, s:e ] = 0 + if span_bdy is not None: + for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]): + masked_pos[idx, s:e] = 1 else: - length = align_start_lengths[idx].item() + length = align_start_lens[idx] if length < 2: continue - masked_phn_indices = random_spans_noise_mask( + masked_phn_idxs = random_spans_noise_mask( length, mlm_prob, mean_phn_span).nonzero() - masked_start = align_start[idx][masked_phn_indices].tolist() - masked_end = align_end[idx][masked_phn_indices].tolist() + masked_start = align_start[idx][masked_phn_idxs].tolist() + masked_end = align_end[idx][masked_phn_idxs].tolist() for s, e in zip(masked_start, masked_end): - masked_position[idx, s:e] = 1 - # y_masks[idx, :, s:e] = tril_masks[idx, :, s:e] - # y_masks[idx, e:, s:e ] = 0 - non_eos_mask = np.array( - paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]).float().cpu()) - masked_position = masked_position * non_eos_mask - # y_masks = src_mask & y_masks.bool() + masked_pos[idx, s:e] = 1 + non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]) + masked_pos = masked_pos * non_eos_mask + masked_pos = paddle.cast(masked_pos, 'bool') - return paddle.cast(paddle.to_tensor(masked_position), paddle.bool), y_masks + return masked_pos, y_masks -def get_segment_pos(speech_pad, text_pad, align_start, align_end, - align_start_lengths, sega_emb): - bz, speech_len, _ = speech_pad.size() - _, text_len = text_pad.size() +def get_seg_pos(speech_pad: paddle.Tensor, + text_pad: paddle.Tensor, + align_start: paddle.Tensor, + align_end: paddle.Tensor, + align_start_lens: paddle.Tensor, + sega_emb: bool): + bz, speech_len, _ = paddle.shape(speech_pad) + _, text_len = paddle.shape(text_pad) - # text_segment_pos = paddle.zeros_like(text_pad) - # speech_segment_pos = paddle.zeros((bz, speech_len),dtype=text_pad.dtype) - text_segment_pos = np.zeros((bz, text_len)).astype('int64') - speech_segment_pos = np.zeros((bz, speech_len)).astype('int64') + text_seg_pos = paddle.zeros((bz, text_len), dtype='int64') + speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64') if not sega_emb: - text_segment_pos = paddle.to_tensor(text_segment_pos) - speech_segment_pos = paddle.to_tensor(speech_segment_pos) - return speech_segment_pos, text_segment_pos + return speech_seg_pos, text_seg_pos for idx in range(bz): - align_length = align_start_lengths[idx].item() + align_length = align_start_lens[idx] for j in range(align_length): - s, e = align_start[idx][j].item(), align_end[idx][j].item() - speech_segment_pos[idx][s:e] = j + 1 - text_segment_pos[idx][j] = j + 1 - - text_segment_pos = paddle.to_tensor(text_segment_pos) - speech_segment_pos = paddle.to_tensor(speech_segment_pos) + s, e = align_start[idx][j], align_end[idx][j] + speech_seg_pos[idx, s:e] = j + 1 + text_seg_pos[idx, j] = j + 1 - return speech_segment_pos, text_segment_pos + return speech_seg_pos, text_seg_pos diff --git a/ernie-sat/inference.py b/ernie-sat/inference.py index e0ad021e968bba65af532f7680d9d3bb0e8f60a8..ee702d2b1e7ec0dc34382472465fa4845ae84f35 100644 --- a/ernie-sat/inference.py +++ b/ernie-sat/inference.py @@ -1,11 +1,7 @@ #!/usr/bin/env python3 import argparse -import math import os -import pickle import random -import string -import sys from pathlib import Path from typing import Collection from typing import Dict @@ -18,17 +14,17 @@ import numpy as np import paddle import soundfile as sf import torch +from paddle import nn from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model -from align_english import alignment -from align_mandarin import alignment_zh -from dataset import get_segment_pos -from dataset import make_non_pad_mask -from dataset import make_pad_mask -from dataset import pad_list +from align import alignment +from align import alignment_zh +from dataset import get_seg_pos +from dataset import get_seg_pos_reduce_duration from dataset import pad_to_longformer_att_window from dataset import phones_masking +from dataset import phones_text_masking from model_paddle import build_model_from_file from read_text import load_num_sequence_text from read_text import read_2column_text @@ -37,8 +33,9 @@ from utils import build_vocoder_from_file from utils import evaluate_durations from utils import get_voc_out from utils import is_chinese -from utils import sentence2phns from paddlespeech.t2s.datasets.get_feats import LogMelFBank +from paddlespeech.t2s.modules.nets_utils import pad_list +from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask random.seed(0) np.random.seed(0) @@ -47,81 +44,72 @@ MODEL_DIR_EN = 'tools/aligner/english' MODEL_DIR_ZH = 'tools/aligner/mandarin' -def plot_mel_and_vocode_wav(uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - model_name, - wav_path, - full_origin_str, - old_str, - new_str, - use_pt_vocoder, - duration_preditor_path, - sid=None, - non_autoreg=True): - wav_org, input_feat, output_feat, old_span_boundary, new_span_boundary, fs, hop_length = get_mlm_output( - uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - model_name, - wav_path, - old_str, - new_str, - duration_preditor_path, +def plot_mel_and_vocode_wav(uid: str, + wav_path: str, + prefix: str="./prompt/dev/", + source_lang: str='english', + target_lang: str='english', + model_name: str="conformer", + full_origin_str: str="", + old_str: str="", + new_str: str="", + duration_preditor_path: str=None, + use_pt_vocoder: bool=False, + sid: str=None, + non_autoreg: bool=True): + wav_org, input_feat, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output( + uid=uid, + prefix=prefix, + source_lang=source_lang, + target_lang=target_lang, + model_name=model_name, + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + duration_preditor_path=duration_preditor_path, use_teacher_forcing=non_autoreg, sid=sid) - masked_feat = output_feat[new_span_boundary[0]:new_span_boundary[ - 1]].detach().float().cpu().numpy() + masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]] - if target_language == 'english': + if target_lang == 'english': if use_pt_vocoder: - output_feat = output_feat.detach().float().cpu().numpy() + output_feat = output_feat.cpu().numpy() output_feat = torch.tensor(output_feat, dtype=torch.float) vocoder = load_vocoder('vctk_parallel_wavegan.v1.long') - replaced_wav = vocoder( - output_feat).detach().float().data.cpu().numpy() + replaced_wav = vocoder(output_feat).cpu().numpy() else: - output_feat_np = output_feat.detach().float().cpu().numpy() - replaced_wav = get_voc_out(output_feat_np, target_language) + replaced_wav = get_voc_out(output_feat, target_lang) - elif target_language == 'chinese': - output_feat_np = output_feat.detach().float().cpu().numpy() - replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, - target_language) + elif target_lang == 'chinese': + replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, target_lang) - old_time_boundary = [hop_length * x for x in old_span_boundary] - new_time_boundary = [hop_length * x for x in new_span_boundary] + old_time_bdy = [hop_length * x for x in old_span_bdy] + new_time_bdy = [hop_length * x for x in new_span_bdy] - if target_language == 'english': + if target_lang == 'english': wav_org_replaced_paddle_voc = np.concatenate([ - wav_org[:old_time_boundary[0]], - replaced_wav[new_time_boundary[0]:new_time_boundary[1]], - wav_org[old_time_boundary[1]:] + wav_org[:old_time_bdy[0]], + replaced_wav[new_time_bdy[0]:new_time_bdy[1]], + wav_org[old_time_bdy[1]:] ]) data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc} - elif target_language == 'chinese': + elif target_lang == 'chinese': wav_org_replaced_only_mask_fst2_voc = np.concatenate([ - wav_org[:old_time_boundary[0]], replaced_wav_only_mask_fst2_voc, - wav_org[old_time_boundary[1]:] + wav_org[:old_time_bdy[0]], replaced_wav_only_mask_fst2_voc, + wav_org[old_time_bdy[1]:] ]) data_dict = { "origin": wav_org, "output": wav_org_replaced_only_mask_fst2_voc, } - return data_dict, old_span_boundary + return data_dict, old_span_bdy -def get_unk_phns(word_str): +def get_unk_phns(word_str: str): tmpbase = '/tmp/tp.' f = open(tmpbase + 'temp.words', 'w') f.write(word_str) @@ -160,9 +148,8 @@ def get_unk_phns(word_str): return phns -def words2phns(line): +def words2phns(line: str): dictfile = MODEL_DIR_EN + '/dict' - tmpbase = '/tmp/tp.' line = line.strip() words = [] for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']: @@ -200,9 +187,8 @@ def words2phns(line): return phns, wrd2phns -def words2phns_zh(line): +def words2phns_zh(line: str): dictfile = MODEL_DIR_ZH + '/dict' - tmpbase = '/tmp/tp.' line = line.strip() words = [] for pun in [ @@ -242,7 +228,7 @@ def words2phns_zh(line): return phns, wrd2phns -def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"): +def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"): vocoder_tag = vocoder_tag.replace("parallel_wavegan/", "") vocoder_file = download_pretrained_model(vocoder_tag) vocoder_config = Path(vocoder_file).parent / "config.yml" @@ -250,7 +236,7 @@ def load_vocoder(vocoder_tag="vctk_parallel_wavegan.v1.long"): return vocoder -def load_model(model_name): +def load_model(model_name: str): config_path = './pretrained_model/{}/config.yaml'.format(model_name) model_path = './pretrained_model/{}/model.pdparams'.format(model_name) mlm_model, args = build_model_from_file( @@ -258,7 +244,7 @@ def load_model(model_name): return mlm_model, args -def read_data(uid, prefix): +def read_data(uid: str, prefix: str): mfa_text = read_2column_text(prefix + '/text')[uid] mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid] if 'mnt' not in mfa_wav_path: @@ -266,7 +252,7 @@ def read_data(uid, prefix): return mfa_text, mfa_wav_path -def get_align_data(uid, prefix): +def get_align_data(uid: str, prefix: str): mfa_path = prefix + "mfa_" mfa_text = read_2column_text(mfa_path + 'text')[uid] mfa_start = load_num_sequence_text( @@ -277,43 +263,45 @@ def get_align_data(uid, prefix): return mfa_text, mfa_start, mfa_end, mfa_wav_path -def get_masked_mel_boundary(mfa_start, mfa_end, fs, hop_length, - span_tobe_replaced): +def get_masked_mel_bdy(mfa_start: List[float], + mfa_end: List[float], + fs: int, + hop_length: int, + span_to_repl: List[List[int]]): align_start = paddle.to_tensor(mfa_start).unsqueeze(0) align_end = paddle.to_tensor(mfa_end).unsqueeze(0) align_start = paddle.floor(fs * align_start / hop_length).int() align_end = paddle.floor(fs * align_end / hop_length).int() - if span_tobe_replaced[0] >= len(mfa_start): - span_boundary = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]] + if span_to_repl[0] >= len(mfa_start): + span_bdy = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]] else: - span_boundary = [ - align_start[0].tolist()[span_tobe_replaced[0]], - align_end[0].tolist()[span_tobe_replaced[1] - 1] + span_bdy = [ + align_start[0].tolist()[span_to_repl[0]], + align_end[0].tolist()[span_to_repl[1] - 1] ] - return span_boundary + return span_bdy -def recover_dict(word2phns, tp_word2phns): +def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]): dic = {} - need_del_key = [] - exist_index = [] + keys_to_del = [] + exist_idx = [] sp_count = 0 add_sp_count = 0 for key in word2phns.keys(): idx, wrd = key.split('_') if wrd == 'sp': sp_count += 1 - exist_index.append(int(idx)) + exist_idx.append(int(idx)) else: - need_del_key.append(key) + keys_to_del.append(key) - for key in need_del_key: + for key in keys_to_del: del word2phns[key] cur_id = 0 for key in tp_word2phns.keys(): - # print("debug: ",key) - if cur_id in exist_index: + if cur_id in exist_idx: dic[str(cur_id) + "_sp"] = 'sp' cur_id += 1 add_sp_count += 1 @@ -329,14 +317,17 @@ def recover_dict(word2phns, tp_word2phns): return dic -def get_phns_and_spans(wav_path, old_str, new_str, source_language, - clone_target_language): +def get_phns_and_spans(wav_path: str, + old_str: str="", + new_str: str="", + source_lang: str="english", + target_lang: str="english"): append_new_str = (old_str == new_str[:len(old_str)]) old_phns, mfa_start, mfa_end = [], [], [] - if source_language == "english": + if source_lang == "english": times2, word2phns = alignment(wav_path, old_str) - elif source_language == "chinese": + elif source_lang == "chinese": times2, word2phns = alignment_zh(wav_path, old_str) _, tp_word2phns = words2phns_zh(old_str) @@ -348,14 +339,14 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, word2phns = recover_dict(word2phns, tp_word2phns) else: - assert source_language == "chinese" or source_language == "english", "source_language is wrong..." + assert source_lang == "chinese" or source_lang == "english", "source_lang is wrong..." for item in times2: mfa_start.append(float(item[1])) mfa_end.append(float(item[2])) old_phns.append(item[0]) - if append_new_str and (source_language != clone_target_language): + if append_new_str and (source_lang != target_lang): is_cross_lingual_clone = True else: is_cross_lingual_clone = False @@ -364,18 +355,21 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, new_str_origin = new_str[:len(old_str)] new_str_append = new_str[len(old_str):] - if clone_target_language == "chinese": + if target_lang == "chinese": new_phns_origin, new_origin_word2phns = words2phns(new_str_origin) new_phns_append, temp_new_append_word2phns = words2phns_zh( new_str_append) - elif clone_target_language == "english": + elif target_lang == "english": + # 原始句子 new_phns_origin, new_origin_word2phns = words2phns_zh( - new_str_origin) # 原始句子 + new_str_origin) + # clone句子 new_phns_append, temp_new_append_word2phns = words2phns( - new_str_append) # clone句子 + new_str_append) else: - assert clone_target_language == "chinese" or clone_target_language == "english", "cloning is not support for this language, please check it." + assert target_lang == "chinese" or target_lang == "english", \ + "cloning is not support for this language, please check it." new_phns = new_phns_origin + new_phns_append @@ -390,16 +384,17 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, new_append_word2phns.items())) else: - if source_language == clone_target_language and clone_target_language == "english": + if source_lang == target_lang and target_lang == "english": new_phns, new_word2phns = words2phns(new_str) - elif source_language == clone_target_language and clone_target_language == "chinese": + elif source_lang == target_lang and target_lang == "chinese": new_phns, new_word2phns = words2phns_zh(new_str) else: - assert source_language == clone_target_language, "source language is not same with target language..." + assert source_lang == target_lang, \ + "source language is not same with target language..." - span_tobe_replaced = [0, len(old_phns) - 1] - span_tobe_added = [0, len(new_phns) - 1] - left_index = 0 + span_to_repl = [0, len(old_phns) - 1] + span_to_add = [0, len(new_phns) - 1] + left_idx = 0 new_phns_left = [] sp_count = 0 # find the left different index @@ -411,27 +406,27 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, else: idx = str(int(idx) - sp_count) if idx + '_' + wrd in new_word2phns: - left_index += len(new_word2phns[idx + '_' + wrd]) + left_idx += len(new_word2phns[idx + '_' + wrd]) new_phns_left.extend(word2phns[key].split()) else: - span_tobe_replaced[0] = len(new_phns_left) - span_tobe_added[0] = len(new_phns_left) + span_to_repl[0] = len(new_phns_left) + span_to_add[0] = len(new_phns_left) break # reverse word2phns and new_word2phns - right_index = 0 + right_idx = 0 new_phns_right = [] sp_count = 0 - word2phns_max_index = int(list(word2phns.keys())[-1].split('_')[0]) - new_word2phns_max_index = int(list(new_word2phns.keys())[-1].split('_')[0]) - new_phns_middle = [] + word2phns_max_idx = int(list(word2phns.keys())[-1].split('_')[0]) + new_word2phns_max_idx = int(list(new_word2phns.keys())[-1].split('_')[0]) + new_phns_mid = [] if append_new_str: new_phns_right = [] - new_phns_middle = new_phns[left_index:] - span_tobe_replaced[0] = len(new_phns_left) - span_tobe_added[0] = len(new_phns_left) - span_tobe_added[1] = len(new_phns_left) + len(new_phns_middle) - span_tobe_replaced[1] = len(old_phns) - len(new_phns_right) + new_phns_mid = new_phns[left_idx:] + span_to_repl[0] = len(new_phns_left) + span_to_add[0] = len(new_phns_left) + span_to_add[1] = len(new_phns_left) + len(new_phns_mid) + span_to_repl[1] = len(old_phns) - len(new_phns_right) else: for key in list(word2phns.keys())[::-1]: idx, wrd = key.split('_') @@ -439,33 +434,31 @@ def get_phns_and_spans(wav_path, old_str, new_str, source_language, sp_count += 1 new_phns_right = ['sp'] + new_phns_right else: - idx = str(new_word2phns_max_index - (word2phns_max_index - int( - idx) - sp_count)) + idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx) + - sp_count)) if idx + '_' + wrd in new_word2phns: - right_index -= len(new_word2phns[idx + '_' + wrd]) + right_idx -= len(new_word2phns[idx + '_' + wrd]) new_phns_right = word2phns[key].split() + new_phns_right else: - span_tobe_replaced[1] = len(old_phns) - len(new_phns_right) - new_phns_middle = new_phns[left_index:right_index] - span_tobe_added[1] = len(new_phns_left) + len( - new_phns_middle) - if len(new_phns_middle) == 0: - span_tobe_added[1] = min(span_tobe_added[1] + 1, - len(new_phns)) - span_tobe_added[0] = max(0, span_tobe_added[0] - 1) - span_tobe_replaced[0] = max(0, - span_tobe_replaced[0] - 1) - span_tobe_replaced[1] = min(span_tobe_replaced[1] + 1, - len(old_phns)) + span_to_repl[1] = len(old_phns) - len(new_phns_right) + new_phns_mid = new_phns[left_idx:right_idx] + span_to_add[1] = len(new_phns_left) + len(new_phns_mid) + if len(new_phns_mid) == 0: + span_to_add[1] = min(span_to_add[1] + 1, len(new_phns)) + span_to_add[0] = max(0, span_to_add[0] - 1) + span_to_repl[0] = max(0, span_to_repl[0] - 1) + span_to_repl[1] = min(span_to_repl[1] + 1, + len(old_phns)) break - new_phns = new_phns_left + new_phns_middle + new_phns_right + new_phns = new_phns_left + new_phns_mid + new_phns_right - return mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added + return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add -def duration_adjust_factor(original_dur, pred_dur, phns): +def duration_adjust_factor(original_dur: List[int], + pred_dur: List[int], + phns: List[str]): length = 0 - accumulate = 0 factor_list = [] for ori, pred, phn in zip(original_dur, pred_dur, phns): if pred == 0 or phn == 'sp': @@ -481,242 +474,224 @@ def duration_adjust_factor(original_dur, pred_dur, phns): return np.average(factor_list[length:-length]) -def prepare_features_with_duration(uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - mlm_model, - old_str, - new_str, - wav_path, - duration_preditor_path, - sid=None, - mask_reconstruct=False, - duration_adjust=True, - start_end_sp=False, +def prepare_features_with_duration(uid: str, + prefix: str, + wav_path: str, + mlm_model: nn.Layer, + source_lang: str="English", + target_lang: str="English", + old_str: str="", + new_str: str="", + duration_preditor_path: str=None, + sid: str=None, + mask_reconstruct: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False, train_args=None): wav_org, rate = librosa.load( wav_path, sr=train_args.feats_extract_conf['fs']) fs = train_args.feats_extract_conf['fs'] hop_length = train_args.feats_extract_conf['hop_length'] - mfa_start, mfa_end, old_phns, new_phns, span_tobe_replaced, span_tobe_added = get_phns_and_spans( - wav_path, old_str, new_str, source_language, target_language) + mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans( + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + source_lang=source_lang, + target_lang=target_lang) if start_end_sp: if new_phns[-1] != 'sp': new_phns = new_phns + ['sp'] - if target_language == "english": - old_durations = evaluate_durations( - old_phns, target_language=target_language) + if target_lang == "english": + old_durations = evaluate_durations(old_phns, target_lang=target_lang) - elif target_language == "chinese": + elif target_lang == "chinese": - if source_language == "english": + if source_lang == "english": old_durations = evaluate_durations( - old_phns, target_language=source_language) + old_phns, target_lang=source_lang) - elif source_language == "chinese": + elif source_lang == "chinese": old_durations = evaluate_durations( - old_phns, target_language=source_language) + old_phns, target_lang=source_lang) else: - assert target_language == "chinese" or target_language == "english", "calculate duration_predict is not support for this language..." + assert target_lang == "chinese" or target_lang == "english", "calculate duration_predict is not support for this language..." original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)] if '[MASK]' in new_str: new_phns = old_phns - span_tobe_added = span_tobe_replaced + span_to_add = span_to_repl d_factor_left = duration_adjust_factor( - original_old_durations[:span_tobe_replaced[0]], - old_durations[:span_tobe_replaced[0]], - old_phns[:span_tobe_replaced[0]]) + original_old_durations[:span_to_repl[0]], + old_durations[:span_to_repl[0]], old_phns[:span_to_repl[0]]) d_factor_right = duration_adjust_factor( - original_old_durations[span_tobe_replaced[1]:], - old_durations[span_tobe_replaced[1]:], - old_phns[span_tobe_replaced[1]:]) + original_old_durations[span_to_repl[1]:], + old_durations[span_to_repl[1]:], old_phns[span_to_repl[1]:]) d_factor = (d_factor_left + d_factor_right) / 2 new_durations_adjusted = [d_factor * i for i in old_durations] else: if duration_adjust: d_factor = duration_adjust_factor(original_old_durations, old_durations, old_phns) - d_factor_paddle = duration_adjust_factor(original_old_durations, - old_durations, old_phns) d_factor = d_factor * 1.25 else: d_factor = 1 - if target_language == "english": + if target_lang == "english": new_durations = evaluate_durations( - new_phns, target_language=target_language) + new_phns, target_lang=target_lang) - elif target_language == "chinese": + elif target_lang == "chinese": new_durations = evaluate_durations( - new_phns, target_language=target_language) + new_phns, target_lang=target_lang) new_durations_adjusted = [d_factor * i for i in new_durations] - if span_tobe_replaced[0] < len(old_phns) and old_phns[ - span_tobe_replaced[0]] == new_phns[span_tobe_added[0]]: - new_durations_adjusted[span_tobe_added[0]] = original_old_durations[ - span_tobe_replaced[0]] - if span_tobe_replaced[1] < len(old_phns) and span_tobe_added[1] < len( - new_phns): - if old_phns[span_tobe_replaced[1]] == new_phns[span_tobe_added[1]]: - new_durations_adjusted[span_tobe_added[ - 1]] = original_old_durations[span_tobe_replaced[1]] + if span_to_repl[0] < len(old_phns) and old_phns[span_to_repl[ + 0]] == new_phns[span_to_add[0]]: + new_durations_adjusted[span_to_add[0]] = original_old_durations[ + span_to_repl[0]] + if span_to_repl[1] < len(old_phns) and span_to_add[1] < len(new_phns): + if old_phns[span_to_repl[1]] == new_phns[span_to_add[1]]: + new_durations_adjusted[span_to_add[1]] = original_old_durations[ + span_to_repl[1]] new_span_duration_sum = sum( - new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]]) + new_durations_adjusted[span_to_add[0]:span_to_add[1]]) old_span_duration_sum = sum( - original_old_durations[span_tobe_replaced[0]:span_tobe_replaced[1]]) + original_old_durations[span_to_repl[0]:span_to_repl[1]]) duration_offset = new_span_duration_sum - old_span_duration_sum - new_mfa_start = mfa_start[:span_tobe_replaced[0]] - new_mfa_end = mfa_end[:span_tobe_replaced[0]] - for i in new_durations_adjusted[span_tobe_added[0]:span_tobe_added[1]]: + new_mfa_start = mfa_start[:span_to_repl[0]] + new_mfa_end = mfa_end[:span_to_repl[0]] + for i in new_durations_adjusted[span_to_add[0]:span_to_add[1]]: if len(new_mfa_end) == 0: new_mfa_start.append(0) new_mfa_end.append(i) else: new_mfa_start.append(new_mfa_end[-1]) new_mfa_end.append(new_mfa_end[-1] + i) - new_mfa_start += [ - i + duration_offset for i in mfa_start[span_tobe_replaced[1]:] - ] - new_mfa_end += [ - i + duration_offset for i in mfa_end[span_tobe_replaced[1]:] - ] + new_mfa_start += [i + duration_offset for i in mfa_start[span_to_repl[1]:]] + new_mfa_end += [i + duration_offset for i in mfa_end[span_to_repl[1]:]] # 3. get new wav - if span_tobe_replaced[0] >= len(mfa_start): - left_index = len(wav_org) - right_index = left_index + if span_to_repl[0] >= len(mfa_start): + left_idx = len(wav_org) + right_idx = left_idx else: - left_index = int(np.floor(mfa_start[span_tobe_replaced[0]] * fs)) - right_index = int(np.ceil(mfa_end[span_tobe_replaced[1] - 1] * fs)) + left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs)) + right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs)) new_blank_wav = np.zeros( (int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype) new_wav_org = np.concatenate( - [wav_org[:left_index], new_blank_wav, wav_org[right_index:]]) + [wav_org[:left_idx], new_blank_wav, wav_org[right_idx:]]) # 4. get old and new mel span to be mask - old_span_boundary = get_masked_mel_boundary( - mfa_start, mfa_end, fs, hop_length, span_tobe_replaced) # [92, 92] - new_span_boundary = get_masked_mel_boundary(new_mfa_start, new_mfa_end, fs, - hop_length, - span_tobe_added) # [92, 174] - - return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_boundary, new_span_boundary - - -def prepare_features(uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - mlm_model, + # [92, 92] + old_span_bdy = get_masked_mel_bdy(mfa_start, mfa_end, fs, hop_length, + span_to_repl) + # [92, 174] + new_span_bdy = get_masked_mel_bdy(new_mfa_start, new_mfa_end, fs, + hop_length, span_to_add) + + return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy + + +def prepare_features(uid: str, + mlm_model: nn.Layer, processor, - wav_path, - old_str, - new_str, - duration_preditor_path, - sid=None, - duration_adjust=True, - start_end_sp=False, - mask_reconstruct=False, + wav_path: str, + prefix: str="./prompt/dev/", + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + duration_preditor_path: str=None, + sid: str=None, + duration_adjust: bool=True, + start_end_sp: bool=False, + mask_reconstruct: bool=False, train_args=None): - wav_org, phns_list, mfa_start, mfa_end, old_span_boundary, new_span_boundary = prepare_features_with_duration( - uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - mlm_model, - old_str, - new_str, - wav_path, - duration_preditor_path, + wav_org, phns_list, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prepare_features_with_duration( + uid=uid, + prefix=prefix, + source_lang=source_lang, + target_lang=target_lang, + mlm_model=mlm_model, + old_str=old_str, + new_str=new_str, + wav_path=wav_path, + duration_preditor_path=duration_preditor_path, sid=sid, duration_adjust=duration_adjust, start_end_sp=start_end_sp, mask_reconstruct=mask_reconstruct, train_args=train_args) - speech = np.array(wav_org, dtype=np.float32) + speech = wav_org align_start = np.array(mfa_start) align_end = np.array(mfa_end) token_to_id = {item: i for i, item in enumerate(train_args.token_list)} text = np.array( list( map(lambda x: token_to_id.get(x, token_to_id['']), phns_list))) - # print('unk id is', token_to_id['']) - # text = np.array(processor(uid='1', data={'text':" ".join(phns_list)})['text']) - span_boundary = np.array(new_span_boundary) + + span_bdy = np.array(new_span_bdy) batch = [('1', { "speech": speech, "align_start": align_start, "align_end": align_end, "text": text, - "span_boundary": span_boundary + "span_bdy": span_bdy })] - return batch, old_span_boundary, new_span_boundary + return batch, old_span_bdy, new_span_bdy -def decode_with_model(uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - mlm_model, +def decode_with_model(uid: str, + mlm_model: nn.Layer, processor, collate_fn, - wav_path, - old_str, - new_str, - duration_preditor_path, - sid=None, - decoder=False, - use_teacher_forcing=False, - duration_adjust=True, - start_end_sp=False, + wav_path: str, + prefix: str="./prompt/dev/", + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + duration_preditor_path: str=None, + sid: str=None, + decoder: bool=False, + use_teacher_forcing: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False, train_args=None): fs, hop_length = train_args.feats_extract_conf[ 'fs'], train_args.feats_extract_conf['hop_length'] - batch, old_span_boundary, new_span_boundary = prepare_features( - uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - mlm_model, - processor, - wav_path, - old_str, - new_str, - duration_preditor_path, - sid, + batch, old_span_bdy, new_span_bdy = prepare_features( + uid=uid, + prefix=prefix, + source_lang=source_lang, + target_lang=target_lang, + mlm_model=mlm_model, + processor=processor, + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + duration_preditor_path=duration_preditor_path, + sid=sid, duration_adjust=duration_adjust, start_end_sp=start_end_sp, train_args=train_args) feats = collate_fn(batch)[1] - if 'text_masked_position' in feats.keys(): - feats.pop('text_masked_position') + if 'text_masked_pos' in feats.keys(): + feats.pop('text_masked_pos') for k, v in feats.items(): feats[k] = paddle.to_tensor(v) rtn = mlm_model.inference( - **feats, - span_boundary=new_span_boundary, - use_teacher_forcing=use_teacher_forcing) + **feats, span_bdy=new_span_bdy, use_teacher_forcing=use_teacher_forcing) output = rtn['feat_gen'] if 0 in output[0].shape and 0 not in output[-1].shape: output_feat = paddle.concat( @@ -731,12 +706,9 @@ def decode_with_model(uid, [output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)], axis=0).cpu() - wav_org, rate = librosa.load( + wav_org, _ = librosa.load( wav_path, sr=train_args.feats_extract_conf['fs']) - origin_speech = paddle.to_tensor( - np.array(wav_org, dtype=np.float32)).unsqueeze(0) - speech_lengths = paddle.to_tensor(len(wav_org)).unsqueeze(0) - return wav_org, None, output_feat, old_span_boundary, new_span_boundary, fs, hop_length + return wav_org, None, output_feat, old_span_bdy, new_span_bdy, fs, hop_length class MLMCollateFn: @@ -800,33 +772,15 @@ def mlm_collate_fn( sega_emb: bool=False, duration_collect: bool=False, text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]: - """Concatenate ndarray-list to an array and convert to torch.Tensor. - - Examples: - >>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler, - >>> import espnet2.tasks.abs_task - >>> from espnet2.train.dataset import ESPnetDataset - >>> sampler = ConstantBatchSampler(...) - >>> dataset = ESPnetDataset(...) - >>> keys = next(iter(sampler) - >>> batch = [dataset[key] for key in keys] - >>> batch = common_collate_fn(batch) - >>> model(**batch) - - Note that the dict-keys of batch are propagated from - that of the dataset as they are. - - """ uttids = [u for u, _ in data] data = [d for _, d in data] assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching" - assert all(not k.endswith("_lengths") - for k in data[0]), f"*_lengths is reserved: {list(data[0])}" + assert all(not k.endswith("_lens") + for k in data[0]), f"*_lens is reserved: {list(data[0])}" output = {} for key in data[0]: - # NOTE(kamo): # Each models, which accepts these values finally, are responsible # to repaint the pad_value to the desired value for each tasks. if data[0][key].dtype.kind == "i": @@ -846,37 +800,35 @@ def mlm_collate_fn( # lens: (Batch,) if key not in not_sequence: lens = paddle.to_tensor( - [d[key].shape[0] for d in data], dtype=paddle.long) - output[key + "_lengths"] = lens + [d[key].shape[0] for d in data], dtype=paddle.int64) + output[key + "_lens"] = lens feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0])) feats = paddle.to_tensor(feats) - # print('out shape', paddle.shape(feats)) - feats_lengths = paddle.shape(feats)[0] + feats_lens = paddle.shape(feats)[0] feats = paddle.unsqueeze(feats, 0) - batch_size = paddle.shape(feats)[0] if 'text' not in output: - text = paddle.zeros_like(feats_lengths.unsqueeze(-1)) - 2 - text_lengths = paddle.zeros_like(feats_lengths) + 1 + text = paddle.zeros(paddle.shape(feats_lens.unsqueeze(-1))) - 2 + text_lens = paddle.zeros(paddle.shape(feats_lens)) + 1 max_tlen = 1 - align_start = paddle.zeros_like(text) - align_end = paddle.zeros_like(text) - align_start_lengths = paddle.zeros_like(feats_lengths) - align_end_lengths = paddle.zeros_like(feats_lengths) + align_start = paddle.zeros(paddle.shape(text)) + align_end = paddle.zeros(paddle.shape(text)) + align_start_lens = paddle.zeros(paddle.shape(feats_lens)) sega_emb = False mean_phn_span = 0 mlm_prob = 0.15 else: - text, text_lengths = output["text"], output["text_lengths"] - align_start, align_start_lengths, align_end, align_end_lengths = output[ - "align_start"], output["align_start_lengths"], output[ - "align_end"], output["align_end_lengths"] + text = output["text"] + text_lens = output["text_lens"] + align_start = output["align_start"] + align_start_lens = output["align_start_lens"] + align_end = output["align_end"] align_start = paddle.floor(feats_extract.sr * align_start / feats_extract.hop_length).int() align_end = paddle.floor(feats_extract.sr * align_end / feats_extract.hop_length).int() - max_tlen = max(text_lengths).item() - max_slen = max(feats_lengths).item() + max_tlen = max(text_lens) + max_slen = max(feats_lens) speech_pad = feats[:, :max_slen] if attention_window > 0 and pad_speech: speech_pad, max_slen = pad_to_longformer_att_window( @@ -888,51 +840,49 @@ def mlm_collate_fn( else: text_pad = text text_mask = make_non_pad_mask( - text_lengths.tolist(), text_pad, length_dim=1).unsqueeze(-2) + text_lens, text_pad, length_dim=1).unsqueeze(-2) if attention_window > 0: text_mask = text_mask * 2 speech_mask = make_non_pad_mask( - feats_lengths.tolist(), speech_pad[:, :, 0], length_dim=1).unsqueeze(-2) - span_boundary = None - if 'span_boundary' in output.keys(): - span_boundary = output['span_boundary'] + feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2) + span_bdy = None + if 'span_bdy' in output.keys(): + span_bdy = output['span_bdy'] if text_masking: - masked_position, text_masked_position, _ = phones_text_masking( + masked_pos, text_masked_pos, _ = phones_text_masking( speech_pad, speech_mask, text_pad, text_mask, align_start, - align_end, align_start_lengths, mlm_prob, mean_phn_span, - span_boundary) + align_end, align_start_lens, mlm_prob, mean_phn_span, span_bdy) else: - text_masked_position = np.zeros(text_pad.size()) - masked_position, _ = phones_masking( - speech_pad, speech_mask, align_start, align_end, - align_start_lengths, mlm_prob, mean_phn_span, span_boundary) + text_masked_pos = paddle.zeros(paddle.shape(text_pad)) + masked_pos, _ = phones_masking(speech_pad, speech_mask, align_start, + align_end, align_start_lens, mlm_prob, + mean_phn_span, span_bdy) output_dict = {} if duration_collect and 'text' in output: - reordered_index, speech_segment_pos, text_segment_pos, durations, feats_lengths = get_segment_pos_reduce_duration( - speech_pad, text_pad, align_start, align_end, align_start_lengths, - sega_emb, masked_position, feats_lengths) + reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens = get_seg_pos_reduce_duration( + speech_pad, text_pad, align_start, align_end, align_start_lens, + sega_emb, masked_pos, feats_lens) speech_mask = make_non_pad_mask( - feats_lengths.tolist(), - speech_pad[:, :reordered_index.shape[1], 0], + feats_lens, speech_pad[:, :reordered_idx.shape[1], 0], length_dim=1).unsqueeze(-2) output_dict['durations'] = durations - output_dict['reordered_index'] = reordered_index + output_dict['reordered_idx'] = reordered_idx else: - speech_segment_pos, text_segment_pos = get_segment_pos( - speech_pad, text_pad, align_start, align_end, align_start_lengths, - sega_emb) + speech_seg_pos, text_seg_pos = get_seg_pos(speech_pad, text_pad, + align_start, align_end, + align_start_lens, sega_emb) output_dict['speech'] = speech_pad output_dict['text'] = text_pad - output_dict['masked_position'] = masked_position - output_dict['text_masked_position'] = text_masked_position + output_dict['masked_pos'] = masked_pos + output_dict['text_masked_pos'] = text_masked_pos output_dict['speech_mask'] = speech_mask output_dict['text_mask'] = text_mask - output_dict['speech_segment_pos'] = speech_segment_pos - output_dict['text_segment_pos'] = text_segment_pos - output_dict['speech_lengths'] = output["speech_lengths"] - output_dict['text_lengths'] = text_lengths + output_dict['speech_seg_pos'] = speech_seg_pos + output_dict['text_seg_pos'] = text_seg_pos + output_dict['speech_lens'] = output["speech_lens"] + output_dict['text_lens'] = text_lens output = (uttids, output_dict) return output @@ -940,13 +890,13 @@ def mlm_collate_fn( def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1): # -> Callable[ # [Collection[Tuple[str, Dict[str, np.ndarray]]]], - # Tuple[List[str], Dict[str, torch.Tensor]], + # Tuple[List[str], Dict[str, Tensor]], # ]: # assert check_argument_types() # return CommonCollateFn(float_pad_value=0.0, int_pad_value=0) feats_extract_class = LogMelFBank - if args.feats_extract_conf['win_length'] == None: + if args.feats_extract_conf['win_length'] is None: args.feats_extract_conf['win_length'] = args.feats_extract_conf['n_fft'] args_dic = {} @@ -955,7 +905,6 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1): args_dic['sr'] = v else: args_dic[k] = v - # feats_extract = feats_extract_class(**args.feats_extract_conf) feats_extract = feats_extract_class(**args_dic) sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False @@ -969,8 +918,7 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1): if epoch == -1: mlm_prob_factor = 1 else: - mlm_probs = [1.0, 1.0, 0.7, 0.6, 0.5] - mlm_prob_factor = 0.8 #mlm_probs[epoch // 100] + mlm_prob_factor = 0.8 if 'duration_predictor_layers' in args.model_conf.keys( ) and args.model_conf['duration_predictor_layers'] > 0: duration_collect = True @@ -989,42 +937,37 @@ def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1): duration_collect=duration_collect) -def get_mlm_output(uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - model_name, - wav_path, - old_str, - new_str, - duration_preditor_path, - sid=None, - decoder=False, - use_teacher_forcing=False, - dynamic_eval=(0, 0), - duration_adjust=True, - start_end_sp=False): +def get_mlm_output(uid: str, + wav_path: str, + prefix: str="./prompt/dev/", + model_name: str="conformer", + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + duration_preditor_path: str=None, + sid: str=None, + decoder: bool=False, + use_teacher_forcing: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False): mlm_model, train_args = load_model(model_name) mlm_model.eval() processor = None collate_fn = build_collate_fn(train_args, False) return decode_with_model( - uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - mlm_model, - processor, - collate_fn, - wav_path, - old_str, - new_str, - duration_preditor_path, + uid=uid, + prefix=prefix, + source_lang=source_lang, + target_lang=target_lang, + mlm_model=mlm_model, + processor=processor, + collate_fn=collate_fn, + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + duration_preditor_path=duration_preditor_path, sid=sid, decoder=decoder, use_teacher_forcing=use_teacher_forcing, @@ -1033,23 +976,20 @@ def get_mlm_output(uid, train_args=train_args) -def test_vctk(uid, - clone_uid, - clone_prefix, - source_language, - target_language, - vocoder, - prefix='dump/raw/dev', - model_name="conformer", - old_str="", - new_str="", - prompt_decoding=False, - dynamic_eval=(0, 0), - task_name=None): +def evaluate(uid: str, + source_lang: str="english", + target_lang: str="english", + use_pt_vocoder: bool=False, + prefix: str="./prompt/dev/", + model_name: str="conformer", + old_str: str="", + new_str: str="", + prompt_decoding: bool=False, + task_name: str=None): duration_preditor_path = None spemd = None - full_origin_str, wav_path = read_data(uid, prefix) + full_origin_str, wav_path = read_data(uid=uid, prefix=prefix) if task_name == 'edit': new_str = new_str @@ -1065,19 +1005,17 @@ def test_vctk(uid, old_str = full_origin_str results_dict, old_span = plot_mel_and_vocode_wav( - uid, - prefix, - clone_uid, - clone_prefix, - source_language, - target_language, - model_name, - wav_path, - full_origin_str, - old_str, - new_str, - vocoder, - duration_preditor_path, + uid=uid, + prefix=prefix, + source_lang=source_lang, + target_lang=target_lang, + model_name=model_name, + wav_path=wav_path, + full_origin_str=full_origin_str, + old_str=old_str, + new_str=new_str, + use_pt_vocoder=use_pt_vocoder, + duration_preditor_path=duration_preditor_path, sid=spemd) return results_dict @@ -1086,17 +1024,14 @@ if __name__ == "__main__": # parse config and args args = parse_args() - data_dict = test_vctk( - args.uid, - args.clone_uid, - args.clone_prefix, - args.source_language, - args.target_language, - args.use_pt_vocoder, - args.prefix, - args.model_name, + data_dict = evaluate( + uid=args.uid, + source_lang=args.source_lang, + target_lang=args.target_lang, + use_pt_vocoder=args.use_pt_vocoder, + prefix=args.prefix, + model_name=args.model_name, new_str=args.new_str, task_name=args.task_name) sf.write(args.output_name, data_dict['output'], samplerate=24000) print("finished...") - # exit() diff --git a/ernie-sat/model_paddle.py b/ernie-sat/model_paddle.py index a93c7a410595f03d7bebac5d1cbc822fe6fb7472..f33c49ed233db83e1f7785a68a36bfeaad6a37df 100644 --- a/ernie-sat/model_paddle.py +++ b/ernie-sat/model_paddle.py @@ -121,12 +121,10 @@ class NewMaskInputLayer(nn.Layer): default_initializer=paddle.nn.initializer.Assign( paddle.normal(shape=(1, 1, out_features)))) - def forward(self, input: paddle.Tensor, - masked_position=None) -> paddle.Tensor: - masked_position = paddle.expand_as( - paddle.unsqueeze(masked_position, -1), input) - masked_input = masked_fill(input, masked_position, 0) + masked_fill( - paddle.expand_as(self.mask_feature, input), ~masked_position, 0) + def forward(self, input: paddle.Tensor, masked_pos=None) -> paddle.Tensor: + masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input) + masked_input = masked_fill(input, masked_pos, 0) + masked_fill( + paddle.expand_as(self.mask_feature, input), ~masked_pos, 0) return masked_input @@ -443,37 +441,34 @@ class MLMEncoder(nn.Layer): def forward(self, speech_pad, text_pad, - masked_position, + masked_pos, speech_mask=None, text_mask=None, - speech_segment_pos=None, - text_segment_pos=None): + speech_seg_pos=None, + text_seg_pos=None): """Encode input sequence. """ - if masked_position is not None: - speech_pad = self.speech_embed(speech_pad, masked_position) + if masked_pos is not None: + speech_pad = self.speech_embed(speech_pad, masked_pos) else: speech_pad = self.speech_embed(speech_pad) # pure speech input if -2 in np.array(text_pad): text_pad = text_pad + 3 text_mask = paddle.unsqueeze(bool(text_pad), 1) - text_segment_pos = paddle.zeros_like(text_pad) + text_seg_pos = paddle.zeros_like(text_pad) text_pad = self.text_embed(text_pad) - text_pad = (text_pad[0] + self.segment_emb(text_segment_pos), + text_pad = (text_pad[0] + self.segment_emb(text_seg_pos), text_pad[1]) - text_segment_pos = None + text_seg_pos = None elif text_pad is not None: text_pad = self.text_embed(text_pad) - segment_emb = None - if speech_segment_pos is not None and text_segment_pos is not None and self.segment_emb: - speech_segment_emb = self.segment_emb(speech_segment_pos) - text_segment_emb = self.segment_emb(text_segment_pos) - text_pad = (text_pad[0] + text_segment_emb, text_pad[1]) - speech_pad = (speech_pad[0] + speech_segment_emb, speech_pad[1]) - segment_emb = paddle.concat( - [speech_segment_emb, text_segment_emb], axis=1) + if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb: + speech_seg_emb = self.segment_emb(speech_seg_pos) + text_seg_emb = self.segment_emb(text_seg_pos) + text_pad = (text_pad[0] + text_seg_emb, text_pad[1]) + speech_pad = (speech_pad[0] + speech_seg_emb, speech_pad[1]) if self.pre_speech_encoders: speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask) @@ -493,11 +488,11 @@ class MLMEncoder(nn.Layer): if self.normalize_before: xs = self.after_norm(xs) - return xs, masks #, segment_emb + return xs, masks class MLMDecoder(MLMEncoder): - def forward(self, xs, masks, masked_position=None, segment_emb=None): + def forward(self, xs, masks, masked_pos=None, segment_emb=None): """Encode input sequence. Args: @@ -509,9 +504,8 @@ class MLMDecoder(MLMEncoder): paddle.Tensor: Mask tensor (#batch, time). """ - emb, mlm_position = None, None if not self.training: - masked_position = None + masked_pos = None xs = self.embed(xs) if segment_emb: xs = (xs[0] + segment_emb, xs[1]) @@ -632,18 +626,18 @@ class MLMModel(nn.Layer): def collect_feats(self, speech, - speech_lengths, + speech_lens, text, - text_lengths, - masked_position, + text_lens, + masked_pos, speech_mask, text_mask, - speech_segment_pos, - text_segment_pos, + speech_seg_pos, + text_seg_pos, y_masks=None) -> Dict[str, paddle.Tensor]: - return {"feats": speech, "feats_lengths": speech_lengths} + return {"feats": speech, "feats_lens": speech_lens} - def forward(self, batch, speech_segment_pos, y_masks=None): + def forward(self, batch, speech_seg_pos, y_masks=None): # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) speech_pad_placeholder = batch['speech_pad'] @@ -654,7 +648,7 @@ class MLMModel(nn.Layer): if self.decoder is not None: zs, _ = self.decoder(ys_in, y_masks, encoder_out, bool(h_masks), - self.encoder.segment_emb(speech_segment_pos)) + self.encoder.segment_emb(speech_seg_pos)) speech_hidden_states = zs else: speech_hidden_states = encoder_out[:, :paddle.shape(batch[ @@ -672,21 +666,21 @@ class MLMModel(nn.Layer): else: after_outs = None return before_outs, after_outs, speech_pad_placeholder, batch[ - 'masked_position'] + 'masked_pos'] def inference( self, speech, text, - masked_position, + masked_pos, speech_mask, text_mask, - speech_segment_pos, - text_segment_pos, - span_boundary, + speech_seg_pos, + text_seg_pos, + span_bdy, y_masks=None, - speech_lengths=None, - text_lengths=None, + speech_lens=None, + text_lens=None, feats: Optional[paddle.Tensor]=None, spembs: Optional[paddle.Tensor]=None, sids: Optional[paddle.Tensor]=None, @@ -699,24 +693,24 @@ class MLMModel(nn.Layer): batch = dict( speech_pad=speech, text_pad=text, - masked_position=masked_position, + masked_pos=masked_pos, speech_mask=speech_mask, text_mask=text_mask, - speech_segment_pos=speech_segment_pos, - text_segment_pos=text_segment_pos, ) + speech_seg_pos=speech_seg_pos, + text_seg_pos=text_seg_pos, ) # # inference with teacher forcing # hs, h_masks = self.encoder(**batch) - outs = [batch['speech_pad'][:, :span_boundary[0]]] + outs = [batch['speech_pad'][:, :span_bdy[0]]] z_cache = None if use_teacher_forcing: before, zs, _, _ = self.forward( - batch, speech_segment_pos, y_masks=y_masks) + batch, speech_seg_pos, y_masks=y_masks) if zs is None: zs = before - outs += [zs[0][span_boundary[0]:span_boundary[1]]] - outs += [batch['speech_pad'][:, span_boundary[1]:]] + outs += [zs[0][span_bdy[0]:span_bdy[1]]] + outs += [batch['speech_pad'][:, span_bdy[1]:]] return dict(feat_gen=outs) return None @@ -733,7 +727,7 @@ class MLMModel(nn.Layer): class MLMEncAsDecoderModel(MLMModel): - def forward(self, batch, speech_segment_pos, y_masks=None): + def forward(self, batch, speech_seg_pos, y_masks=None): # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) speech_pad_placeholder = batch['speech_pad'] @@ -756,7 +750,7 @@ class MLMEncAsDecoderModel(MLMModel): else: after_outs = None return before_outs, after_outs, speech_pad_placeholder, batch[ - 'masked_position'] + 'masked_pos'] class MLMDualMaksingModel(MLMModel): @@ -767,9 +761,9 @@ class MLMDualMaksingModel(MLMModel): batch): xs_pad = batch['speech_pad'] text_pad = batch['text_pad'] - masked_position = batch['masked_position'] - text_masked_position = batch['text_masked_position'] - mlm_loss_position = masked_position > 0 + masked_pos = batch['masked_pos'] + text_masked_pos = batch['text_masked_pos'] + mlm_loss_pos = masked_pos > 0 loss = paddle.sum( self.l1_loss_func( paddle.reshape(before_outs, (-1, self.odim)), @@ -782,19 +776,17 @@ class MLMDualMaksingModel(MLMModel): paddle.reshape(xs_pad, (-1, self.odim))), axis=-1) loss_mlm = paddle.sum((loss * paddle.reshape( - mlm_loss_position, [-1]))) / paddle.sum((mlm_loss_position) + 1e-10) + mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10) loss_text = paddle.sum((self.text_mlm_loss( paddle.reshape(text_outs, (-1, self.vocab_size)), paddle.reshape(text_pad, (-1))) * paddle.reshape( - text_masked_position, - (-1)))) / paddle.sum((text_masked_position) + 1e-10) + text_masked_pos, (-1)))) / paddle.sum((text_masked_pos) + 1e-10) return loss_mlm, loss_text - def forward(self, batch, speech_segment_pos, y_masks=None): + def forward(self, batch, speech_seg_pos, y_masks=None): # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) - speech_pad_placeholder = batch['speech_pad'] encoder_out, h_masks = self.encoder(**batch) # segment_emb if self.decoder is not None: zs, _ = self.decoder(encoder_out, h_masks) @@ -819,7 +811,7 @@ class MLMDualMaksingModel(MLMModel): [0, 2, 1]) else: after_outs = None - return before_outs, after_outs, text_outs, None #, speech_pad_placeholder, batch['masked_position'],batch['text_masked_position'] + return before_outs, after_outs, text_outs, None #, speech_pad_placeholder, batch['masked_pos'],batch['text_masked_pos'] def build_model_from_file(config_file, model_file): diff --git a/ernie-sat/paddlespeech/t2s/modules/nets_utils.py b/ernie-sat/paddlespeech/t2s/modules/nets_utils.py index 4207d316c4d07922924a649b0cb5ae45f6032450..097bf1d335dc517913a8be40b8f02f28aed29054 100644 --- a/ernie-sat/paddlespeech/t2s/modules/nets_utils.py +++ b/ernie-sat/paddlespeech/t2s/modules/nets_utils.py @@ -38,7 +38,7 @@ def pad_list(xs, pad_value): """ n_batch = len(xs) max_len = max(x.shape[0] for x in xs) - pad = paddle.full([n_batch, max_len, *xs[0].shape[1:]], pad_value) + pad = paddle.full([n_batch, max_len, *xs[0].shape[1:]], pad_value, dtype=xs[0].dtype) for i in range(n_batch): pad[i, :xs[i].shape[0]] = xs[i] @@ -46,13 +46,18 @@ def pad_list(xs, pad_value): return pad -def make_pad_mask(lengths, length_dim=-1): + +def make_pad_mask(lengths, xs=None, length_dim=-1): """Make mask tensor containing indices of padded part. Args: lengths (Tensor(int64)): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. - Returns: + Returns: Tensor(bool): Mask tensor containing indices of padded part bool. Examples: @@ -61,23 +66,98 @@ def make_pad_mask(lengths, length_dim=-1): >>> lengths = [5, 3, 2] >>> make_non_pad_mask(lengths) masks = [[0, 0, 0, 0 ,0], - [0, 0, 0, 1, 1], - [0, 0, 1, 1, 1]] + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. + + >>> xs = paddle.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]]) + >>> xs = paddle.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]]) + + With the reference tensor and dimension indicator. + + >>> xs = paddle.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]]) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]],) + """ if length_dim == 0: raise ValueError("length_dim cannot be 0: {}".format(length_dim)) bs = paddle.shape(lengths)[0] - maxlen = lengths.max() + if xs is None: + maxlen = lengths.max() + else: + maxlen = paddle.shape(xs)[length_dim] + seq_range = paddle.arange(0, maxlen, dtype=paddle.int64) seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen]) seq_length_expand = lengths.unsqueeze(-1) mask = seq_range_expand >= seq_length_expand - return mask + if xs is not None: + assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs) + if length_dim < 0: + length_dim = len(paddle.shape(xs)) + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple( + slice(None) if i in (0, length_dim) else None + for i in range(len(paddle.shape(xs)))) + mask = paddle.expand(mask[ind], paddle.shape(xs)) + return mask -def make_non_pad_mask(lengths, length_dim=-1): +def make_non_pad_mask(lengths, xs=None, length_dim=-1): """Make mask tensor containing indices of non-padded part. Args: @@ -90,16 +170,78 @@ def make_non_pad_mask(lengths, length_dim=-1): Returns: Tensor(bool): mask tensor containing indices of padded part bool. - Examples: + Examples: With only lengths. >>> lengths = [5, 3, 2] >>> make_non_pad_mask(lengths) masks = [[1, 1, 1, 1 ,1], - [1, 1, 1, 0, 0], - [1, 1, 0, 0, 0]] + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = paddle.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]]) + >>> xs = paddle.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]]) + + With the reference tensor and dimension indicator. + + >>> xs = paddle.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]]) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]]) + """ - return paddle.logical_not(make_pad_mask(lengths, length_dim)) + return paddle.logical_not(make_pad_mask(lengths, xs, length_dim)) def initialize(model: nn.Layer, init: str): diff --git a/ernie-sat/run_clone_en_to_zh.sh b/ernie-sat/run_clone_en_to_zh.sh index 9266b67f1609d338b9fc2ac816dc29ba752eea78..b30b494b06f84f36c99463f17e068709db5370cb 100755 --- a/ernie-sat/run_clone_en_to_zh.sh +++ b/ernie-sat/run_clone_en_to_zh.sh @@ -10,8 +10,8 @@ python inference.py \ --uid=Prompt_003_new \ --new_str='今天天气很好.' \ --prefix='./prompt/dev/' \ - --source_language=english \ - --target_language=chinese \ + --source_lang=english \ + --target_lang=chinese \ --output_name=pred_clone.wav \ --use_pt_vocoder=False \ --voc=pwgan_aishell3 \ diff --git a/ernie-sat/run_gen_en.sh b/ernie-sat/run_gen_en.sh index 1f7fbf69068199dce3c526a9693b0c6f7bbbd333..ecf6bb82c8d912a7273f5dd86e658c750503ee16 100755 --- a/ernie-sat/run_gen_en.sh +++ b/ernie-sat/run_gen_en.sh @@ -9,8 +9,8 @@ python inference.py \ --uid=p299_096 \ --new_str='I enjoy my life, do you?' \ --prefix='./prompt/dev/' \ - --source_language=english \ - --target_language=english \ + --source_lang=english \ + --target_lang=english \ --output_name=pred_gen.wav \ --use_pt_vocoder=False \ --voc=pwgan_aishell3 \ diff --git a/ernie-sat/run_sedit_en.sh b/ernie-sat/run_sedit_en.sh index 4c2b248b9b3a1803815fe1d8be1903753efb84fb..7421f0a9c8d9b2739b208630886ec3bfb7048006 100755 --- a/ernie-sat/run_sedit_en.sh +++ b/ernie-sat/run_sedit_en.sh @@ -10,8 +10,8 @@ python inference.py \ --uid=p243_new \ --new_str='for that reason cover is impossible to be given.' \ --prefix='./prompt/dev/' \ - --source_language=english \ - --target_language=english \ + --source_lang=english \ + --target_lang=english \ --output_name=pred_edit.wav \ --use_pt_vocoder=False \ --voc=pwgan_aishell3 \ diff --git a/ernie-sat/sedit_arg_parser.py b/ernie-sat/sedit_arg_parser.py index a8b8f6feec80c6a08655c3a64e95fa816acaa0ca..01d0b47ef7831c973fe54f70c21c8f9600b53d52 100644 --- a/ernie-sat/sedit_arg_parser.py +++ b/ernie-sat/sedit_arg_parser.py @@ -80,10 +80,8 @@ def parse_args(): parser.add_argument("--uid", type=str, help="uid") parser.add_argument("--new_str", type=str, help="new string") parser.add_argument("--prefix", type=str, help="prefix") - parser.add_argument("--clone_prefix", type=str, default=None, help="clone prefix") - parser.add_argument("--clone_uid", type=str, default=None, help="clone uid") - parser.add_argument("--source_language", type=str, help="source language") - parser.add_argument("--target_language", type=str, help="target language") + parser.add_argument("--source_lang", type=str, default="english", help="source language") + parser.add_argument("--target_lang", type=str, default="english", help="target language") parser.add_argument("--output_name", type=str, help="output name") parser.add_argument("--task_name", type=str, help="task name") parser.add_argument( diff --git a/ernie-sat/tools/parallel_wavegan_pretrained_vocoder.py b/ernie-sat/tools/torch_pwgan.py similarity index 96% rename from ernie-sat/tools/parallel_wavegan_pretrained_vocoder.py rename to ernie-sat/tools/torch_pwgan.py index 856a0c0e8cb004b9b142ec6a88a6c4bcf235fde8..9295f65708ea6ef58b612aa9993855c6b70357cf 100644 --- a/ernie-sat/tools/parallel_wavegan_pretrained_vocoder.py +++ b/ernie-sat/tools/torch_pwgan.py @@ -9,7 +9,7 @@ import torch import yaml -class ParallelWaveGANPretrainedVocoder(torch.nn.Module): +class TorchPWGAN(torch.nn.Module): """Wrapper class to load the vocoder trained with parallel_wavegan repo.""" def __init__( diff --git a/ernie-sat/utils.py b/ernie-sat/utils.py index cb1e93c265b9baa1ae04d99b6d870a9f5f7fbbba..1cd4f0083afda00d0c0af27dfad7e21cdb96a903 100644 --- a/ernie-sat/utils.py +++ b/ernie-sat/utils.py @@ -1,3 +1,7 @@ +import os +from typing import List +from typing import Optional + import numpy as np import paddle import yaml @@ -5,11 +9,8 @@ from sedit_arg_parser import parse_args from yacs.config import CfgNode from paddlespeech.s2t.utils.dynamic_import import dynamic_import -from paddlespeech.t2s.exps.syn_utils import get_frontend -from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.modules.normalizer import ZScore -from tools.parallel_wavegan_pretrained_vocoder import ParallelWaveGANPretrainedVocoder -# new add +from tools.torch_pwgan import TorchPWGAN model_alias = { # acoustic model @@ -25,6 +26,10 @@ model_alias = { "paddlespeech.t2s.models.tacotron2:Tacotron2", "tacotron2_inference": "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", + "pwgan": + "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", + "pwgan_inference": + "paddlespeech.t2s.models.parallel_wavegan:PWGInference", } @@ -43,60 +48,65 @@ def build_vocoder_from_file( # Build vocoder if str(vocoder_file).endswith(".pkl"): # If the extension is ".pkl", the model is trained with parallel_wavegan - vocoder = ParallelWaveGANPretrainedVocoder(vocoder_file, - vocoder_config_file) + vocoder = TorchPWGAN(vocoder_file, vocoder_config_file) return vocoder.to(device) else: raise ValueError(f"{vocoder_file} is not supported format.") -def get_voc_out(mel, target_language="chinese"): +def get_voc_out(mel, target_lang: str="chinese"): # vocoder args = parse_args() - assert target_language == "chinese" or target_language == "english", "In get_voc_out function, target_language is illegal..." + assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..." # print("current vocoder: ", args.voc) with open(args.voc_config) as f: voc_config = CfgNode(yaml.safe_load(f)) - # print(voc_config) - - voc_inference = get_voc_inference(args, voc_config) + voc_inference = voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) - mel = paddle.to_tensor(mel) - # print("masked_mel: ", mel.shape) with paddle.no_grad(): wav = voc_inference(mel) - # print("shepe of wav (time x n_channels):%s"%wav.shape) return np.squeeze(wav) # dygraph -def get_am_inference(args, am_config): - with open(args.phones_dict, "r") as f: +def get_am_inference(am: str='fastspeech2_csmsc', + am_config: CfgNode=None, + am_ckpt: Optional[os.PathLike]=None, + am_stat: Optional[os.PathLike]=None, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, + return_am: bool=False): + with open(phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] vocab_size = len(phn_id) - # print("vocab_size:", vocab_size) + print("vocab_size:", vocab_size) tone_size = None - if 'tones_dict' in args and args.tones_dict: - with open(args.tones_dict, "r") as f: + if tones_dict is not None: + with open(tones_dict, "r") as f: tone_id = [line.strip().split() for line in f.readlines()] tone_size = len(tone_id) print("tone_size:", tone_size) spk_num = None - if 'speaker_dict' in args and args.speaker_dict: - with open(args.speaker_dict, 'rt') as f: + if speaker_dict is not None: + with open(speaker_dict, 'rt') as f: spk_id = [line.strip().split() for line in f.readlines()] spk_num = len(spk_id) print("spk_num:", spk_num) odim = am_config.n_mels # model: {model_name}_{dataset} - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] am_class = dynamic_import(am_name, model_alias) am_inference_class = dynamic_import(am_name + '_inference', model_alias) @@ -113,39 +123,61 @@ def get_am_inference(args, am_config): elif am_name == 'tacotron2': am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) - am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) + am.set_state_dict(paddle.load(am_ckpt)["main_params"]) am.eval() - am_mu, am_std = np.load(args.am_stat) + am_mu, am_std = np.load(am_stat) am_mu = paddle.to_tensor(am_mu) am_std = paddle.to_tensor(am_std) am_normalizer = ZScore(am_mu, am_std) am_inference = am_inference_class(am_normalizer, am) am_inference.eval() print("acoustic model done!") - return am, am_inference, am_name, am_dataset, phn_id + if return_am: + return am_inference, am + else: + return am_inference -def evaluate_durations(phns, - target_language="chinese", - fs=24000, - hop_length=300): +def get_voc_inference( + voc: str='pwgan_csmsc', + voc_config: Optional[os.PathLike]=None, + voc_ckpt: Optional[os.PathLike]=None, + voc_stat: Optional[os.PathLike]=None, ): + # model: {model_name}_{dataset} + voc_name = voc[:voc.rindex('_')] + voc_class = dynamic_import(voc_name, model_alias) + voc_inference_class = dynamic_import(voc_name + '_inference', model_alias) + if voc_name != 'wavernn': + voc = voc_class(**voc_config["generator_params"]) + voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"]) + voc.remove_weight_norm() + voc.eval() + else: + voc = voc_class(**voc_config["model"]) + voc.set_state_dict(paddle.load(voc_ckpt)["main_params"]) + voc.eval() + + voc_mu, voc_std = np.load(voc_stat) + voc_mu = paddle.to_tensor(voc_mu) + voc_std = paddle.to_tensor(voc_std) + voc_normalizer = ZScore(voc_mu, voc_std) + voc_inference = voc_inference_class(voc_normalizer, voc) + voc_inference.eval() + print("voc done!") + return voc_inference + + +def evaluate_durations(phns: List[str], + target_lang: str="chinese", + fs: int=24000, + hop_length: int=300): args = parse_args() - if target_language == 'english': + if target_lang == 'english': args.lang = 'en' - args.am = "fastspeech2_ljspeech" - args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml" - args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz" - args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy" - args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt" - elif target_language == 'chinese': + elif target_lang == 'chinese': args.lang = 'zh' - args.am = "fastspeech2_csmsc" - args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml" - args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz" - args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy" - args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt" # args = parser.parse_args(args=[]) if args.ngpu == 0: @@ -155,23 +187,28 @@ def evaluate_durations(phns, else: print("ngpu should >= 0 !") - assert target_language == "chinese" or target_language == "english", "In evaluate_durations function, target_language is illegal..." + assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..." # Init body. with open(args.am_config) as f: am_config = CfgNode(yaml.safe_load(f)) - # print("========Config========") - # print(am_config) - # print("---------------------") - # acoustic model - am, am_inference, am_name, am_dataset, phn_id = get_am_inference(args, - am_config) + + am_inference, am = get_am_inference( + am=args.am, + am_config=am_config, + am_ckpt=args.am_ckpt, + am_stat=args.am_stat, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict, + speaker_dict=args.speaker_dict, + return_am=True) torch_phns = phns vocab_phones = {} + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] for tone, id in phn_id: vocab_phones[tone] = int(id) - # print("vocab_phones: ", len(vocab_phones)) vocab_size = len(vocab_phones) phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns] @@ -185,59 +222,3 @@ def evaluate_durations(phns, phoneme_durations_new = pre_d_outs * hop_length / fs phoneme_durations_new = phoneme_durations_new.tolist()[:-1] return phoneme_durations_new - - -def sentence2phns(sentence, target_language="en"): - args = parse_args() - if target_language == 'en': - args.lang = 'en' - args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt" - elif target_language == 'zh': - args.lang = 'zh' - args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt" - else: - print("target_language should in {'zh', 'en'}!") - - frontend = get_frontend(args) - merge_sentences = True - get_tone_ids = False - - if target_language == 'zh': - input_ids = frontend.get_input_ids( - sentence, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids, - print_info=False) - phone_ids = input_ids["phone_ids"] - - phonemes = frontend.get_phonemes( - sentence, merge_sentences=merge_sentences, print_info=False) - - return phonemes[0], input_ids["phone_ids"][0] - - elif target_language == 'en': - phonemes = frontend.phoneticize(sentence) - input_ids = frontend.get_input_ids( - sentence, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - - phones_list = [] - vocab_phones = {} - punc = ":,;。?!“”‘’':,;.?!" - with open(args.phones_dict, 'rt') as f: - phn_id = [line.strip().split() for line in f.readlines()] - for phn, id in phn_id: - vocab_phones[phn] = int(id) - - phones = phonemes[1:-1] - phones = [phn for phn in phones if not phn.isspace()] - # replace unk phone with sp - phones = [ - phn if (phn in vocab_phones and phn not in punc) else "sp" - for phn in phones - ] - phones_list.append(phones) - return phones_list[0], input_ids["phone_ids"][0] - - else: - print("lang should in {'zh', 'en'}!")