diff --git a/ernie-sat/README.md b/ernie-sat/README.md index bfecccc5500431ad515802a430987a71f1bb89be..0d204a739cbc0338183bc15de30dde0a34f6de5e 100644 --- a/ernie-sat/README.md +++ b/ernie-sat/README.md @@ -39,9 +39,9 @@ ERNIE-SAT 中我们提出了两项创新: ### 2.预训练模型 预训练模型 ERNIE-SAT 的模型如下所示: -- [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz) -- [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz) -- [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz) +- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz) +- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz) +- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz) 创建 pretrained_model 文件夹,下载上述 ERNIE-SAT 预训练模型并将其解压: @@ -108,7 +108,7 @@ prompt/dev 3. `--voc` 声码器(vocoder)格式是否符合 {model_name}_{dataset} 4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` 是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。 5. `--lang` 对应模型的语言可以是 `zh` 或 `en` 。 -6. `--ngpu` 要使用的GPU数,如果 ngpu==0,则使用 cpu。 +6. `--ngpu` 要使用的 GPU 数,如果 ngpu==0,则使用 cpu。 7. ` --model_name` 模型名称 8. ` --uid` 特定提示(prompt)语音的 id 9. ` --new_str` 输入的文本(本次开源暂时先设置特定的文本) @@ -125,4 +125,3 @@ sh run_sedit_en.sh # 语音编辑任务(英文) sh run_gen_en.sh # 个性化语音合成任务(英文) sh run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆) ``` - diff --git a/ernie-sat/align.py b/ernie-sat/align.py index 5c7144f439d887fac1b7763df8f4376cd2ec73a1..025877ddf35dd72dd5a9e80bfe3714998374eff0 100755 --- a/ernie-sat/align.py +++ b/ernie-sat/align.py @@ -1,13 +1,9 @@ -#!/usr/bin/env python """ Usage: align.py wavfile trsfile outwordfile outphonefile """ -import multiprocessing as mp import os import sys -from tqdm import tqdm - PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme' MODEL_DIR_EN = 'tools/aligner/english' MODEL_DIR_ZH = 'tools/aligner/mandarin' @@ -15,6 +11,142 @@ HVITE = 'tools/htk/HTKTools/HVite' HCOPY = 'tools/htk/HTKTools/HCopy' +def get_unk_phns(word_str: str): + tmpbase = '/tmp/tp.' + f = open(tmpbase + 'temp.words', 'w') + f.write(word_str) + f.close() + os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase + + 'temp.phons') + f = open(tmpbase + 'temp.phons', 'r') + lines2 = f.readline().strip().split() + f.close() + phns = [] + for phn in lines2: + phons = phn.replace('\n', '').replace(' ', '') + seq = [] + j = 0 + while (j < len(phons)): + if (phons[j] > 'Z'): + if (phons[j] == 'j'): + seq.append('JH') + elif (phons[j] == 'h'): + seq.append('HH') + else: + seq.append(phons[j].upper()) + j += 1 + else: + p = phons[j:j + 2] + if (p == 'WH'): + seq.append('W') + elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']): + seq.append(p) + elif (p == 'AX'): + seq.append('AH0') + else: + seq.append(p + '1') + j += 2 + phns.extend(seq) + return phns + + +def words2phns(line: str): + ''' + Args: + line (str): input text. + eg: for that reason cover is impossible to be given. + Returns: + List[str]: phones of input text. 
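+            (a literal '[MASK]' token is kept as a phone as-is; words missing
+            from the dictionary are converted with get_unk_phns)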
+ eg: + ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0', + 'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1', + 'G', 'IH1', 'V', 'AH0', 'N'] + + Dict(str, str): key - idx_word + value - phones + eg: + {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'], + '3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'], + '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']} + ''' + dictfile = MODEL_DIR_EN + '/dict' + line = line.strip() + words = [] + for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']: + line = line.replace(pun, ' ') + for wrd in line.split(): + if (wrd[-1] == '-'): + wrd = wrd[:-1] + if (wrd[0] == "'"): + wrd = wrd[1:] + if wrd: + words.append(wrd) + ds = set([]) + word2phns_dict = {} + with open(dictfile, 'r') as fid: + for line in fid: + word = line.split()[0] + ds.add(word) + if word not in word2phns_dict.keys(): + word2phns_dict[word] = " ".join(line.split()[1:]) + + phns = [] + wrd2phns = {} + for index, wrd in enumerate(words): + if wrd == '[MASK]': + wrd2phns[str(index) + "_" + wrd] = [wrd] + phns.append(wrd) + elif (wrd.upper() not in ds): + wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd) + phns.extend(get_unk_phns(wrd)) + else: + wrd2phns[str(index) + + "_" + wrd.upper()] = word2phns_dict[wrd.upper()].split() + phns.extend(word2phns_dict[wrd.upper()].split()) + return phns, wrd2phns + + +def words2phns_zh(line: str): + dictfile = MODEL_DIR_ZH + '/dict' + line = line.strip() + words = [] + for pun in [ + ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', + u'。', u':', u';', u'!', u'?', u'(', u')' + ]: + line = line.replace(pun, ' ') + for wrd in line.split(): + if (wrd[-1] == '-'): + wrd = wrd[:-1] + if (wrd[0] == "'"): + wrd = wrd[1:] + if wrd: + words.append(wrd) + + ds = set([]) + word2phns_dict = {} + with open(dictfile, 'r') as fid: + for line in fid: + word = line.split()[0] + ds.add(word) + if word not in word2phns_dict.keys(): + word2phns_dict[word] = " ".join(line.split()[1:]) + + phns = [] + wrd2phns = {} + for index, wrd in enumerate(words): + if wrd == '[MASK]': + wrd2phns[str(index) + "_" + wrd] = [wrd] + phns.append(wrd) + elif (wrd.upper() not in ds): + print("出现非法词错误,请输入正确的文本...") + else: + wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split() + phns.extend(word2phns_dict[wrd].split()) + + return phns, wrd2phns + + def prep_txt_zh(line: str, tmpbase: str, dictfile: str): words = [] @@ -82,7 +214,7 @@ def prep_txt_en(line: str, tmpbase, dictfile): try: os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase + '_unk.phons') - except: + except Exception: print('english2phoneme error!') sys.exit(1) @@ -148,19 +280,22 @@ def _get_user(): def alignment(wav_path: str, text: str): + ''' + intervals: List[phn, start, end] + ''' tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) #prepare wav and trs files try: os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -') - except: + except Exception: print('sox error!') return None #prepare clean_transcript file try: - prep_txt_en(text, tmpbase, MODEL_DIR_EN + '/dict') - except: + prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict') + except Exception: print('prep_txt error!') return None @@ -169,7 +304,7 @@ def alignment(wav_path: str, text: str): with open(tmpbase + '.txt', 'r') as fid: txt 
= fid.readline() prep_mlf(txt, tmpbase) - except: + except Exception: print('prep_mlf error!') return None @@ -177,7 +312,7 @@ def alignment(wav_path: str, text: str): try: os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase + '.wav' + ' ' + tmpbase + '.plp') - except: + except Exception: print('HCopy error!') return None @@ -188,7 +323,7 @@ def alignment(wav_path: str, text: str): + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase + '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase + '.plp 2>&1 > /dev/null') - except: + except Exception: print('HVite error!') return None @@ -200,7 +335,7 @@ def alignment(wav_path: str, text: str): with open(tmpbase + '.aligned', 'r') as fid: lines = fid.readlines() i = 2 - times2 = [] + intervals = [] word2phns = {} current_word = '' index = 0 @@ -210,7 +345,7 @@ def alignment(wav_path: str, text: str): phn = splited_line[2] pst = (int(splited_line[0]) / 1000 + 125) / 10000 pen = (int(splited_line[1]) / 1000 + 125) / 10000 - times2.append([phn, pst, pen]) + intervals.append([phn, pst, pen]) # splited_line[-1]!='sp' if len(splited_line) == 5: current_word = str(index) + '_' + splited_line[-1] @@ -219,10 +354,10 @@ def alignment(wav_path: str, text: str): elif len(splited_line) == 4: word2phns[current_word] += ' ' + phn i += 1 - return times2, word2phns + return intervals, word2phns -def alignment_zh(wav_path, text_string): +def alignment_zh(wav_path: str, text: str): tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) #prepare wav and trs files @@ -230,18 +365,19 @@ def alignment_zh(wav_path, text_string): os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase + '.wav remix -') - except: + except Exception: print('sox error!') return None #prepare clean_transcript file try: - unk_words = prep_txt_zh(text_string, tmpbase, MODEL_DIR_ZH + '/dict') + unk_words = prep_txt_zh( + line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict') if unk_words: print('Error! 
Please add the following words to dictionary:') for unk in unk_words: print("非法words: ", unk) - except: + except Exception: print('prep_txt error!') return None @@ -250,7 +386,7 @@ def alignment_zh(wav_path, text_string): with open(tmpbase + '.txt', 'r') as fid: txt = fid.readline() prep_mlf(txt, tmpbase) - except: + except Exception: print('prep_mlf error!') return None @@ -258,7 +394,7 @@ def alignment_zh(wav_path, text_string): try: os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase + '.wav' + ' ' + tmpbase + '.plp') - except: + except Exception: print('HCopy error!') return None @@ -270,7 +406,7 @@ def alignment_zh(wav_path, text_string): + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase + '.plp 2>&1 > /dev/null') - except: + except Exception: print('HVite error!') return None @@ -283,7 +419,7 @@ def alignment_zh(wav_path, text_string): lines = fid.readlines() i = 2 - times2 = [] + intervals = [] word2phns = {} current_word = '' index = 0 @@ -293,7 +429,7 @@ def alignment_zh(wav_path, text_string): phn = splited_line[2] pst = (int(splited_line[0]) / 1000 + 125) / 10000 pen = (int(splited_line[1]) / 1000 + 125) / 10000 - times2.append([phn, pst, pen]) + intervals.append([phn, pst, pen]) # splited_line[-1]!='sp' if len(splited_line) == 5: current_word = str(index) + '_' + splited_line[-1] @@ -302,4 +438,4 @@ def alignment_zh(wav_path, text_string): elif len(splited_line) == 4: word2phns[current_word] += ' ' + phn i += 1 - return times2, word2phns + return intervals, word2phns diff --git a/ernie-sat/collect_fn.py b/ernie-sat/collect_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..363618dc8887faa74e5a70bed41813c2a03a3e56 --- /dev/null +++ b/ernie-sat/collect_fn.py @@ -0,0 +1,217 @@ +from typing import Collection +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union + +import numpy as np +import paddle + +from dataset import get_seg_pos +from dataset import phones_masking +from dataset import phones_text_masking +from paddlespeech.t2s.datasets.get_feats import LogMelFBank +from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask +from paddlespeech.t2s.modules.nets_utils import pad_list + + +class MLMCollateFn: + """Functor class of common_collate_fn()""" + + def __init__(self, + feats_extract, + float_pad_value: Union[float, int]=0.0, + int_pad_value: int=-32768, + not_sequence: Collection[str]=(), + mlm_prob: float=0.8, + mean_phn_span: int=8, + attention_window: int=0, + pad_speech: bool=False, + seg_emb: bool=False, + text_masking: bool=False): + self.mlm_prob = mlm_prob + self.mean_phn_span = mean_phn_span + self.feats_extract = feats_extract + self.float_pad_value = float_pad_value + self.int_pad_value = int_pad_value + self.not_sequence = set(not_sequence) + self.attention_window = attention_window + self.pad_speech = pad_speech + self.seg_emb = seg_emb + self.text_masking = text_masking + + def __repr__(self): + return (f"{self.__class__}(float_pad_value={self.float_pad_value}, " + f"int_pad_value={self.float_pad_value})") + + def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]] + ) -> Tuple[List[str], Dict[str, paddle.Tensor]]: + return mlm_collate_fn( + data, + float_pad_value=self.float_pad_value, + int_pad_value=self.int_pad_value, + not_sequence=self.not_sequence, + mlm_prob=self.mlm_prob, + mean_phn_span=self.mean_phn_span, + feats_extract=self.feats_extract, + attention_window=self.attention_window, + pad_speech=self.pad_speech, + seg_emb=self.seg_emb, + 
text_masking=self.text_masking) + + +def mlm_collate_fn( + data: Collection[Tuple[str, Dict[str, np.ndarray]]], + float_pad_value: Union[float, int]=0.0, + int_pad_value: int=-32768, + not_sequence: Collection[str]=(), + mlm_prob: float=0.8, + mean_phn_span: int=8, + feats_extract=None, + attention_window: int=0, + pad_speech: bool=False, + seg_emb: bool=False, + text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]: + uttids = [u for u, _ in data] + data = [d for _, d in data] + + assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching" + assert all(not k.endswith("_lens") + for k in data[0]), f"*_lens is reserved: {list(data[0])}" + + output = {} + for key in data[0]: + # Each models, which accepts these values finally, are responsible + # to repaint the pad_value to the desired value for each tasks. + if data[0][key].dtype.kind == "i": + pad_value = int_pad_value + else: + pad_value = float_pad_value + + array_list = [d[key] for d in data] + + # Assume the first axis is length: + # tensor_list: Batch x (Length, ...) + tensor_list = [paddle.to_tensor(a) for a in array_list] + # tensor: (Batch, Length, ...) + tensor = pad_list(tensor_list, pad_value) + output[key] = tensor + + # lens: (Batch,) + if key not in not_sequence: + lens = paddle.to_tensor( + [d[key].shape[0] for d in data], dtype=paddle.int64) + output[key + "_lens"] = lens + + feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0])) + feats = paddle.to_tensor(feats) + feats_lens = paddle.shape(feats)[0] + feats = paddle.unsqueeze(feats, 0) + + text = output["text"] + text_lens = output["text_lens"] + align_start = output["align_start"] + align_start_lens = output["align_start_lens"] + align_end = output["align_end"] + + max_tlen = max(text_lens) + max_slen = max(feats_lens) + + speech_pad = feats[:, :max_slen] + + text_pad = text + text_mask = make_non_pad_mask( + text_lens, text_pad, length_dim=1).unsqueeze(-2) + speech_mask = make_non_pad_mask( + feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2) + span_bdy = None + if 'span_bdy' in output.keys(): + span_bdy = output['span_bdy'] + + # dual_mask 的是混合中英时候同时 mask 语音和文本 + # ernie sat 在实现跨语言的时候都 mask 了 + if text_masking: + masked_pos, text_masked_pos = phones_text_masking( + xs_pad=speech_pad, + src_mask=speech_mask, + text_pad=text_pad, + text_mask=text_mask, + align_start=align_start, + align_end=align_end, + align_start_lens=align_start_lens, + mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span, + span_bdy=span_bdy) + # 训练纯中文和纯英文的 -> a3t 没有对 phoneme 做 mask, 只对语音 mask 了 + # a3t 和 ernie sat 的区别主要在于做 mask 的时候 + else: + masked_pos = phones_masking( + xs_pad=speech_pad, + src_mask=speech_mask, + align_start=align_start, + align_end=align_end, + align_start_lens=align_start_lens, + mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span, + span_bdy=span_bdy) + text_masked_pos = paddle.zeros(paddle.shape(text_pad)) + + output_dict = {} + + speech_seg_pos, text_seg_pos = get_seg_pos( + speech_pad=speech_pad, + text_pad=text_pad, + align_start=align_start, + align_end=align_end, + align_start_lens=align_start_lens, + seg_emb=seg_emb) + output_dict['speech'] = speech_pad + output_dict['text'] = text_pad + output_dict['masked_pos'] = masked_pos + output_dict['text_masked_pos'] = text_masked_pos + output_dict['speech_mask'] = speech_mask + output_dict['text_mask'] = text_mask + output_dict['speech_seg_pos'] = speech_seg_pos + output_dict['text_seg_pos'] = text_seg_pos + output = (uttids, output_dict) + return output + + +def 
build_collate_fn( + sr: int=24000, + n_fft: int=2048, + hop_length: int=300, + win_length: int=None, + n_mels: int=80, + fmin: int=80, + fmax: int=7600, + mlm_prob: float=0.8, + mean_phn_span: int=8, + train: bool=False, + seg_emb: bool=False, + epoch: int=-1, ): + feats_extract_class = LogMelFBank + + feats_extract = feats_extract_class( + sr=sr, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mels=n_mels, + fmin=fmin, + fmax=fmax) + + pad_speech = False + if epoch == -1: + mlm_prob_factor = 1 + else: + mlm_prob_factor = 0.8 + + return MLMCollateFn( + feats_extract=feats_extract, + float_pad_value=0.0, + int_pad_value=0, + mlm_prob=mlm_prob * mlm_prob_factor, + mean_phn_span=mean_phn_span, + pad_speech=pad_speech, + seg_emb=seg_emb) diff --git a/ernie-sat/dataset.py b/ernie-sat/dataset.py index d8b896ada57380809f57d7e7f8329fe3947c92a9..3bf2d8f12f1b49bfa3a8469d880c66f4b398ee0a 100644 --- a/ernie-sat/dataset.py +++ b/ernie-sat/dataset.py @@ -4,6 +4,68 @@ import numpy as np import paddle +# mask phones +def phones_masking(xs_pad: paddle.Tensor, + src_mask: paddle.Tensor, + align_start: paddle.Tensor, + align_end: paddle.Tensor, + align_start_lens: paddle.Tensor, + mlm_prob: float=0.8, + mean_phn_span: int=8, + span_bdy: paddle.Tensor=None): + ''' + Args: + xs_pad (paddle.Tensor): input speech (B, Tmax, D). + src_mask (paddle.Tensor): mask of speech (B, 1, Tmax). + align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2). + align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2). + align_start_lens (paddle.Tensor): length of align_start (B, ). + mlm_prob (float): + mean_phn_span (int): + span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2). + Returns: + paddle.Tensor[bool]: masked position of input speech (B, Tmax). + ''' + bz, sent_len, _ = paddle.shape(xs_pad) + masked_pos = paddle.zeros((bz, sent_len)) + if mlm_prob == 1.0: + masked_pos += 1 + elif mean_phn_span == 0: + # only speech + length = sent_len + mean_phn_span = min(length * mlm_prob // 3, 50) + masked_phn_idxs = random_spans_noise_mask( + length=length, mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span).nonzero() + masked_pos[:, masked_phn_idxs] = 1 + else: + for idx in range(bz): + # for inference + if span_bdy is not None: + for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]): + masked_pos[idx, s:e] = 1 + # for training + else: + length = align_start_lens[idx] + if length < 2: + continue + masked_phn_idxs = random_spans_noise_mask( + length=length, + mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span).nonzero() + masked_start = align_start[idx][masked_phn_idxs].tolist() + masked_end = align_end[idx][masked_phn_idxs].tolist() + + for s, e in zip(masked_start, masked_end): + masked_pos[idx, s:e] = 1 + non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]) + masked_pos = masked_pos * non_eos_mask + masked_pos = paddle.cast(masked_pos, 'bool') + + return masked_pos + + +# mask speech and phones def phones_text_masking(xs_pad: paddle.Tensor, src_mask: paddle.Tensor, text_pad: paddle.Tensor, @@ -11,37 +73,56 @@ def phones_text_masking(xs_pad: paddle.Tensor, align_start: paddle.Tensor, align_end: paddle.Tensor, align_start_lens: paddle.Tensor, - mlm_prob: float, - mean_phn_span: float, + mlm_prob: float=0.8, + mean_phn_span: int=8, span_bdy: paddle.Tensor=None): + ''' + Args: + xs_pad (paddle.Tensor): input speech (B, Tmax, D). + src_mask (paddle.Tensor): mask of speech (B, 1, Tmax). + text_pad (paddle.Tensor): input text (B, Tmax2). 
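+        (Tmax is the max number of mel frames in the batch, Tmax2 the max
+        number of phone tokens)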
+ text_mask (paddle.Tensor): mask of text (B, 1, Tmax2). + align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2). + align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2). + align_start_lens (paddle.Tensor): length of align_start (B, ). + mlm_prob (float): + mean_phn_span (int): + span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2). + Returns: + paddle.Tensor[bool]: masked position of input speech (B, Tmax). + paddle.Tensor[bool]: masked position of input text (B, Tmax2). + ''' bz, sent_len, _ = paddle.shape(xs_pad) masked_pos = paddle.zeros((bz, sent_len)) _, text_len = paddle.shape(text_pad) text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5) text_masked_pos = paddle.zeros((bz, text_len)) - y_masks = None if mlm_prob == 1.0: masked_pos += 1 - # y_masks = tril_masks elif mean_phn_span == 0: # only speech length = sent_len mean_phn_span = min(length * mlm_prob // 3, 50) - masked_phn_idxs = random_spans_noise_mask(length, mlm_prob, - mean_phn_span).nonzero() + masked_phn_idxs = random_spans_noise_mask( + length=length, mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span).nonzero() masked_pos[:, masked_phn_idxs] = 1 else: for idx in range(bz): + # for inference if span_bdy is not None: for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]): masked_pos[idx, s:e] = 1 + # for training else: length = align_start_lens[idx] if length < 2: continue masked_phn_idxs = random_spans_noise_mask( - length, mlm_prob, mean_phn_span).nonzero() + length=length, + mlm_prob=mlm_prob, + mean_phn_span=mean_phn_span).nonzero() unmasked_phn_idxs = list( set(range(length)) - set(masked_phn_idxs[0].tolist())) np.random.shuffle(unmasked_phn_idxs) @@ -58,60 +139,76 @@ def phones_text_masking(xs_pad: paddle.Tensor, masked_pos = paddle.cast(masked_pos, 'bool') text_masked_pos = paddle.cast(text_masked_pos, 'bool') - return masked_pos, text_masked_pos, y_masks + return masked_pos, text_masked_pos -def get_seg_pos_reduce_duration( - speech_pad: paddle.Tensor, - text_pad: paddle.Tensor, - align_start: paddle.Tensor, - align_end: paddle.Tensor, - align_start_lens: paddle.Tensor, - sega_emb: bool, - masked_pos: paddle.Tensor, - feats_lens: paddle.Tensor, ): +def get_seg_pos(speech_pad: paddle.Tensor, + text_pad: paddle.Tensor, + align_start: paddle.Tensor, + align_end: paddle.Tensor, + align_start_lens: paddle.Tensor, + seg_emb: bool=False): + ''' + Args: + speech_pad (paddle.Tensor): input speech (B, Tmax, D). + text_pad (paddle.Tensor): input text (B, Tmax2). + align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2). + align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2). + align_start_lens (paddle.Tensor): length of align_start (B, ). + seg_emb (bool): whether to use segment embedding. + Returns: + paddle.Tensor[int]: n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax). 
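+            (frames not covered by any phone alignment, e.g. padding frames,
+            keep the value 0)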
+ eg: + Tensor(shape=[1, 328], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , + 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , + 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , + 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , + 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 4 , 4 , + 5 , 5 , 5 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 7 , 7 , 7 , 7 , 7 , 7 , 7 , + 7 , 8 , 8 , 8 , 8 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 10, 10, 10, 10, 10, + 10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, + 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, + 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, + 17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, + 20, 20, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 23, 23, + 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, + 25, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29, + 29, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 32, + 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 35, 35, + 35, 35, 35, 35, 35, 35, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, + 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, + 38, 38, 0 , 0 ]]) + paddle.Tensor[int]: n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2). + eg: + Tensor(shape=[1, 38], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38]]) + ''' + bz, speech_len, _ = paddle.shape(speech_pad) - text_seg_pos = paddle.zeros(paddle.shape(text_pad)) - speech_seg_pos = paddle.zeros((bz, speech_len), dtype=text_pad.dtype) + _, text_len = paddle.shape(text_pad) - reordered_idx = paddle.zeros((bz, speech_len), dtype=align_start_lens.dtype) + text_seg_pos = paddle.zeros((bz, text_len), dtype='int64') + speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64') - durations = paddle.ones((bz, speech_len), dtype=align_start_lens.dtype) - max_reduced_length = 0 - if not sega_emb: - return speech_pad, masked_pos, speech_seg_pos, text_seg_pos, durations + if not seg_emb: + return speech_seg_pos, text_seg_pos for idx in range(bz): - first_idx = [] - last_idx = [] align_length = align_start_lens[idx] for j in range(align_length): s, e = align_start[idx][j], align_end[idx][j] - if j == 0: - if paddle.sum(masked_pos[idx][0:s]) == 0: - first_idx.extend(range(0, s)) - else: - first_idx.extend([0]) - last_idx.extend(range(1, s)) - if paddle.sum(masked_pos[idx][s:e]) == 0: - first_idx.extend(range(s, e)) - else: - first_idx.extend([s]) - last_idx.extend(range(s + 1, e)) - durations[idx][s] = e - s - speech_seg_pos[idx][s:e] = j + 1 - text_seg_pos[idx][j] = j + 1 - max_reduced_length = max( - len(first_idx) + feats_lens[idx] - e, max_reduced_length) - first_idx.extend(range(e, speech_len)) - reordered_idx[idx] = paddle.to_tensor( - (first_idx + last_idx), dtype=align_start_lens.dtype) - feats_lens[idx] = len(first_idx) - reordered_idx = reordered_idx[:, :max_reduced_length] + speech_seg_pos[idx, s:e] = j + 1 + text_seg_pos[idx, j] = j + 1 - return reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens + return speech_seg_pos, text_seg_pos -def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: 
float): +# randomly select the range of speech and text to mask during training +def random_spans_noise_mask(length: int, + mlm_prob: float=0.8, + mean_phn_span: float=8): """This function is copy of `random_spans_helper `__ . Noise mask consisting of random spans of noise tokens. @@ -126,7 +223,7 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float): noise_density: a float - approximate density of output mask mean_noise_span_length: a number Returns: - a boolean tensor with shape [length] + np.ndarray: a boolean tensor with shape [length] """ orig_length = length @@ -171,87 +268,3 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float): is_noise = np.equal(span_num % 2, 1) return is_noise[:orig_length] - - -def pad_to_longformer_att_window(text: paddle.Tensor, - max_len: int, - max_tlen: int, - attention_window: int=0): - - round = max_len % attention_window - if round != 0: - max_tlen += (attention_window - round) - n_batch = paddle.shape(text)[0] - text_pad = paddle.zeros( - (n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype) - for i in range(n_batch): - text_pad[i, :paddle.shape(text[i])[0]] = text[i] - else: - text_pad = text[:, :max_tlen] - return text_pad, max_tlen - - -def phones_masking(xs_pad: paddle.Tensor, - src_mask: paddle.Tensor, - align_start: paddle.Tensor, - align_end: paddle.Tensor, - align_start_lens: paddle.Tensor, - mlm_prob: float, - mean_phn_span: int, - span_bdy: paddle.Tensor=None): - bz, sent_len, _ = paddle.shape(xs_pad) - masked_pos = paddle.zeros((bz, sent_len)) - y_masks = None - if mlm_prob == 1.0: - masked_pos += 1 - elif mean_phn_span == 0: - # only speech - length = sent_len - mean_phn_span = min(length * mlm_prob // 3, 50) - masked_phn_idxs = random_spans_noise_mask(length, mlm_prob, - mean_phn_span).nonzero() - masked_pos[:, masked_phn_idxs] = 1 - else: - for idx in range(bz): - if span_bdy is not None: - for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]): - masked_pos[idx, s:e] = 1 - else: - length = align_start_lens[idx] - if length < 2: - continue - masked_phn_idxs = random_spans_noise_mask( - length, mlm_prob, mean_phn_span).nonzero() - masked_start = align_start[idx][masked_phn_idxs].tolist() - masked_end = align_end[idx][masked_phn_idxs].tolist() - for s, e in zip(masked_start, masked_end): - masked_pos[idx, s:e] = 1 - non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]) - masked_pos = masked_pos * non_eos_mask - masked_pos = paddle.cast(masked_pos, 'bool') - - return masked_pos, y_masks - - -def get_seg_pos(speech_pad: paddle.Tensor, - text_pad: paddle.Tensor, - align_start: paddle.Tensor, - align_end: paddle.Tensor, - align_start_lens: paddle.Tensor, - sega_emb: bool): - bz, speech_len, _ = paddle.shape(speech_pad) - _, text_len = paddle.shape(text_pad) - - text_seg_pos = paddle.zeros((bz, text_len), dtype='int64') - speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64') - - if not sega_emb: - return speech_seg_pos, text_seg_pos - for idx in range(bz): - align_length = align_start_lens[idx] - for j in range(align_length): - s, e = align_start[idx][j], align_end[idx][j] - speech_seg_pos[idx, s:e] = j + 1 - text_seg_pos[idx, j] = j + 1 - - return speech_seg_pos, text_seg_pos diff --git a/ernie-sat/inference.py b/ernie-sat/inference.py index ee702d2b1e7ec0dc34382472465fa4845ae84f35..1760fa4d5e1a23577b919fb39b483f6fc66f273f 100644 --- a/ernie-sat/inference.py +++ b/ernie-sat/inference.py @@ -1,13 +1,9 @@ #!/usr/bin/env python3 -import argparse import os 
import random from pathlib import Path -from typing import Collection from typing import Dict from typing import List -from typing import Tuple -from typing import Union import librosa import numpy as np @@ -15,217 +11,62 @@ import paddle import soundfile as sf import torch from paddle import nn - from ParallelWaveGAN.parallel_wavegan.utils.utils import download_pretrained_model from align import alignment from align import alignment_zh -from dataset import get_seg_pos -from dataset import get_seg_pos_reduce_duration -from dataset import pad_to_longformer_att_window -from dataset import phones_masking -from dataset import phones_text_masking -from model_paddle import build_model_from_file +from align import words2phns +from align import words2phns_zh +from collect_fn import build_collate_fn +from mlm import build_model_from_file from read_text import load_num_sequence_text -from read_text import read_2column_text +from read_text import read_2col_text from sedit_arg_parser import parse_args from utils import build_vocoder_from_file -from utils import evaluate_durations +from utils import eval_durs from utils import get_voc_out from utils import is_chinese -from paddlespeech.t2s.datasets.get_feats import LogMelFBank -from paddlespeech.t2s.modules.nets_utils import pad_list -from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask + random.seed(0) np.random.seed(0) -PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme' -MODEL_DIR_EN = 'tools/aligner/english' -MODEL_DIR_ZH = 'tools/aligner/mandarin' - - -def plot_mel_and_vocode_wav(uid: str, - wav_path: str, - prefix: str="./prompt/dev/", - source_lang: str='english', - target_lang: str='english', - model_name: str="conformer", - full_origin_str: str="", - old_str: str="", - new_str: str="", - duration_preditor_path: str=None, - use_pt_vocoder: bool=False, - sid: str=None, - non_autoreg: bool=True): - wav_org, input_feat, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output( - uid=uid, - prefix=prefix, + +def get_wav(wav_path: str, + source_lang: str='english', + target_lang: str='english', + model_name: str="paddle_checkpoint_en", + old_str: str="", + new_str: str="", + use_pt_vocoder: bool=False, + non_autoreg: bool=True): + wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output( source_lang=source_lang, target_lang=target_lang, model_name=model_name, wav_path=wav_path, old_str=old_str, new_str=new_str, - duration_preditor_path=duration_preditor_path, - use_teacher_forcing=non_autoreg, - sid=sid) + use_teacher_forcing=non_autoreg) masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]] - if target_lang == 'english': - if use_pt_vocoder: - output_feat = output_feat.cpu().numpy() - output_feat = torch.tensor(output_feat, dtype=torch.float) - vocoder = load_vocoder('vctk_parallel_wavegan.v1.long') - replaced_wav = vocoder(output_feat).cpu().numpy() - else: - replaced_wav = get_voc_out(output_feat, target_lang) + if target_lang == 'english' and use_pt_vocoder: + masked_feat = masked_feat.cpu().numpy() + masked_feat = torch.tensor(masked_feat, dtype=torch.float) + vocoder = load_vocoder('vctk_parallel_wavegan.v1.long') + alt_wav = vocoder(masked_feat).cpu().numpy() - elif target_lang == 'chinese': - replaced_wav_only_mask_fst2_voc = get_voc_out(masked_feat, target_lang) + else: + alt_wav = get_voc_out(masked_feat) old_time_bdy = [hop_length * x for x in old_span_bdy] - new_time_bdy = [hop_length * x for x in new_span_bdy] - - if target_lang == 'english': - 
wav_org_replaced_paddle_voc = np.concatenate([ - wav_org[:old_time_bdy[0]], - replaced_wav[new_time_bdy[0]:new_time_bdy[1]], - wav_org[old_time_bdy[1]:] - ]) - - data_dict = {"origin": wav_org, "output": wav_org_replaced_paddle_voc} - - elif target_lang == 'chinese': - wav_org_replaced_only_mask_fst2_voc = np.concatenate([ - wav_org[:old_time_bdy[0]], replaced_wav_only_mask_fst2_voc, - wav_org[old_time_bdy[1]:] - ]) - data_dict = { - "origin": wav_org, - "output": wav_org_replaced_only_mask_fst2_voc, - } - - return data_dict, old_span_bdy - - -def get_unk_phns(word_str: str): - tmpbase = '/tmp/tp.' - f = open(tmpbase + 'temp.words', 'w') - f.write(word_str) - f.close() - os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase + - 'temp.phons') - f = open(tmpbase + 'temp.phons', 'r') - lines2 = f.readline().strip().split() - f.close() - phns = [] - for phn in lines2: - phons = phn.replace('\n', '').replace(' ', '') - seq = [] - j = 0 - while (j < len(phons)): - if (phons[j] > 'Z'): - if (phons[j] == 'j'): - seq.append('JH') - elif (phons[j] == 'h'): - seq.append('HH') - else: - seq.append(phons[j].upper()) - j += 1 - else: - p = phons[j:j + 2] - if (p == 'WH'): - seq.append('W') - elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']): - seq.append(p) - elif (p == 'AX'): - seq.append('AH0') - else: - seq.append(p + '1') - j += 2 - phns.extend(seq) - return phns - - -def words2phns(line: str): - dictfile = MODEL_DIR_EN + '/dict' - line = line.strip() - words = [] - for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']: - line = line.replace(pun, ' ') - for wrd in line.split(): - if (wrd[-1] == '-'): - wrd = wrd[:-1] - if (wrd[0] == "'"): - wrd = wrd[1:] - if wrd: - words.append(wrd) - ds = set([]) - word2phns_dict = {} - with open(dictfile, 'r') as fid: - for line in fid: - word = line.split()[0] - ds.add(word) - if word not in word2phns_dict.keys(): - word2phns_dict[word] = " ".join(line.split()[1:]) - - phns = [] - wrd2phns = {} - for index, wrd in enumerate(words): - if wrd == '[MASK]': - wrd2phns[str(index) + "_" + wrd] = [wrd] - phns.append(wrd) - elif (wrd.upper() not in ds): - wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd) - phns.extend(get_unk_phns(wrd)) - else: - wrd2phns[str(index) + - "_" + wrd.upper()] = word2phns_dict[wrd.upper()].split() - phns.extend(word2phns_dict[wrd.upper()].split()) - - return phns, wrd2phns - - -def words2phns_zh(line: str): - dictfile = MODEL_DIR_ZH + '/dict' - line = line.strip() - words = [] - for pun in [ - ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', - u'。', u':', u';', u'!', u'?', u'(', u')' - ]: - line = line.replace(pun, ' ') - for wrd in line.split(): - if (wrd[-1] == '-'): - wrd = wrd[:-1] - if (wrd[0] == "'"): - wrd = wrd[1:] - if wrd: - words.append(wrd) - - ds = set([]) - word2phns_dict = {} - with open(dictfile, 'r') as fid: - for line in fid: - word = line.split()[0] - ds.add(word) - if word not in word2phns_dict.keys(): - word2phns_dict[word] = " ".join(line.split()[1:]) - - phns = [] - wrd2phns = {} - for index, wrd in enumerate(words): - if wrd == '[MASK]': - wrd2phns[str(index) + "_" + wrd] = [wrd] - phns.append(wrd) - elif (wrd.upper() not in ds): - print("出现非法词错误,请输入正确的文本...") - else: - wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split() - phns.extend(word2phns_dict[wrd].split()) - return phns, wrd2phns + wav_replaced = np.concatenate( + [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]]) + + data_dict = {"origin": wav_org, "output": wav_replaced} + + 
return data_dict def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"): @@ -236,50 +77,52 @@ def load_vocoder(vocoder_tag: str="vctk_parallel_wavegan.v1.long"): return vocoder -def load_model(model_name: str): +def load_model(model_name: str="paddle_checkpoint_en"): config_path = './pretrained_model/{}/config.yaml'.format(model_name) model_path = './pretrained_model/{}/model.pdparams'.format(model_name) - mlm_model, args = build_model_from_file( + mlm_model, conf = build_model_from_file( config_file=config_path, model_file=model_path) - return mlm_model, args + return mlm_model, conf -def read_data(uid: str, prefix: str): - mfa_text = read_2column_text(prefix + '/text')[uid] - mfa_wav_path = read_2column_text(prefix + '/wav.scp')[uid] - if 'mnt' not in mfa_wav_path: - mfa_wav_path = prefix.split('dump')[0] + mfa_wav_path +def read_data(uid: str, prefix: os.PathLike): + # 获取 uid 对应的文本 + mfa_text = read_2col_text(prefix + '/text')[uid] + # 获取 uid 对应的音频路径 + mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid] + if not os.path.isabs(mfa_wav_path): + mfa_wav_path = prefix + mfa_wav_path return mfa_text, mfa_wav_path -def get_align_data(uid: str, prefix: str): +def get_align_data(uid: str, prefix: os.PathLike): mfa_path = prefix + "mfa_" - mfa_text = read_2column_text(mfa_path + 'text')[uid] + mfa_text = read_2col_text(mfa_path + 'text')[uid] mfa_start = load_num_sequence_text( mfa_path + 'start', loader_type='text_float')[uid] mfa_end = load_num_sequence_text( mfa_path + 'end', loader_type='text_float')[uid] - mfa_wav_path = read_2column_text(mfa_path + 'wav.scp')[uid] + mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid] return mfa_text, mfa_start, mfa_end, mfa_wav_path +# 获取需要被 mask 的 mel 帧的范围 def get_masked_mel_bdy(mfa_start: List[float], mfa_end: List[float], fs: int, hop_length: int, span_to_repl: List[List[int]]): - align_start = paddle.to_tensor(mfa_start).unsqueeze(0) - align_end = paddle.to_tensor(mfa_end).unsqueeze(0) - align_start = paddle.floor(fs * align_start / hop_length).int() - align_end = paddle.floor(fs * align_end / hop_length).int() + align_start = np.array(mfa_start) + align_end = np.array(mfa_end) + align_start = np.floor(fs * align_start / hop_length).astype('int') + align_end = np.floor(fs * align_end / hop_length).astype('int') if span_to_repl[0] >= len(mfa_start): - span_bdy = [align_end[0].tolist()[-1], align_end[0].tolist()[-1]] + span_bdy = [align_end[-1], align_end[-1]] else: span_bdy = [ - align_start[0].tolist()[span_to_repl[0]], - align_end[0].tolist()[span_to_repl[1] - 1] + align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1] ] - return span_bdy + return span_bdy, align_start, align_end def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]): @@ -317,18 +160,22 @@ def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]): return dic +def get_max_idx(dic): + return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1] + + def get_phns_and_spans(wav_path: str, old_str: str="", new_str: str="", source_lang: str="english", target_lang: str="english"): - append_new_str = (old_str == new_str[:len(old_str)]) + is_append = (old_str == new_str[:len(old_str)]) old_phns, mfa_start, mfa_end = [], [], [] - + # source if source_lang == "english": - times2, word2phns = alignment(wav_path, old_str) + intervals, word2phns = alignment(wav_path, old_str) elif source_lang == "chinese": - times2, word2phns = alignment_zh(wav_path, old_str) + intervals, word2phns = alignment_zh(wav_path, old_str) _, tp_word2phns = 
words2phns_zh(old_str) for key, value in tp_word2phns.items(): @@ -337,51 +184,46 @@ def get_phns_and_spans(wav_path: str, tp_word2phns[key] = cur_val word2phns = recover_dict(word2phns, tp_word2phns) - else: - assert source_lang == "chinese" or source_lang == "english", "source_lang is wrong..." + assert source_lang == "chinese" or source_lang == "english", \ + "source_lang is wrong..." - for item in times2: + for item in intervals: + old_phns.append(item[0]) mfa_start.append(float(item[1])) mfa_end.append(float(item[2])) - old_phns.append(item[0]) - - if append_new_str and (source_lang != target_lang): - is_cross_lingual_clone = True + # target + if is_append and (source_lang != target_lang): + cross_lingual_clone = True else: - is_cross_lingual_clone = False + cross_lingual_clone = False - if is_cross_lingual_clone: - new_str_origin = new_str[:len(old_str)] - new_str_append = new_str[len(old_str):] + if cross_lingual_clone: + str_origin = new_str[:len(old_str)] + str_append = new_str[len(old_str):] if target_lang == "chinese": - new_phns_origin, new_origin_word2phns = words2phns(new_str_origin) - new_phns_append, temp_new_append_word2phns = words2phns_zh( - new_str_append) + phns_origin, origin_word2phns = words2phns(str_origin) + phns_append, append_word2phns_tmp = words2phns_zh(str_append) elif target_lang == "english": # 原始句子 - new_phns_origin, new_origin_word2phns = words2phns_zh( - new_str_origin) - # clone句子 - new_phns_append, temp_new_append_word2phns = words2phns( - new_str_append) + phns_origin, origin_word2phns = words2phns_zh(str_origin) + # clone 句子 + phns_append, append_word2phns_tmp = words2phns(str_append) else: assert target_lang == "chinese" or target_lang == "english", \ "cloning is not support for this language, please check it." - new_phns = new_phns_origin + new_phns_append + new_phns = phns_origin + phns_append - new_append_word2phns = {} - length = len(new_origin_word2phns) - for key, value in temp_new_append_word2phns.items(): + append_word2phns = {} + length = len(origin_word2phns) + for key, value in append_word2phns_tmp.items(): idx, wrd = key.split('_') - new_append_word2phns[str(int(idx) + length) + '_' + wrd] = value - - new_word2phns = dict( - list(new_origin_word2phns.items()) + list( - new_append_word2phns.items())) + append_word2phns[str(int(idx) + length) + '_' + wrd] = value + new_word2phns = origin_word2phns.copy() + new_word2phns.update(append_word2phns) else: if source_lang == target_lang and target_lang == "english": @@ -417,16 +259,17 @@ def get_phns_and_spans(wav_path: str, right_idx = 0 new_phns_right = [] sp_count = 0 - word2phns_max_idx = int(list(word2phns.keys())[-1].split('_')[0]) - new_word2phns_max_idx = int(list(new_word2phns.keys())[-1].split('_')[0]) + word2phns_max_idx = get_max_idx(word2phns) + new_word2phns_max_idx = get_max_idx(new_word2phns) new_phns_mid = [] - if append_new_str: + if is_append: new_phns_right = [] new_phns_mid = new_phns[left_idx:] span_to_repl[0] = len(new_phns_left) span_to_add[0] = len(new_phns_left) span_to_add[1] = len(new_phns_left) + len(new_phns_mid) span_to_repl[1] = len(old_phns) - len(new_phns_right) + # speech edit else: for key in list(word2phns.keys())[::-1]: idx, wrd = key.split('_') @@ -451,47 +294,57 @@ def get_phns_and_spans(wav_path: str, len(old_phns)) break new_phns = new_phns_left + new_phns_mid + new_phns_right - + ''' + For that reason cover should not be given. + For that reason cover is impossible to be given. 
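+    (first line: old_str, second line: new_str; the spans are phone-index
+    ranges into old_phns / new_phns respectively)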
+ span_to_repl: [17, 23] "should not" + span_to_add: [17, 30] "is impossible to" + ''' return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add -def duration_adjust_factor(original_dur: List[int], - pred_dur: List[int], - phns: List[str]): +# mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同 +# 此处获得一个缩放比例, 用于预测值和真实值之间的缩放 +def get_dur_adj_factor(orig_dur: List[int], + pred_dur: List[int], + phns: List[str]): length = 0 factor_list = [] - for ori, pred, phn in zip(original_dur, pred_dur, phns): + for orig, pred, phn in zip(orig_dur, pred_dur, phns): if pred == 0 or phn == 'sp': continue else: - factor_list.append(ori / pred) + factor_list.append(orig / pred) factor_list = np.array(factor_list) factor_list.sort() if len(factor_list) < 5: return 1 - length = 2 - return np.average(factor_list[length:-length]) - - -def prepare_features_with_duration(uid: str, - prefix: str, - wav_path: str, - mlm_model: nn.Layer, - source_lang: str="English", - target_lang: str="English", - old_str: str="", - new_str: str="", - duration_preditor_path: str=None, - sid: str=None, - mask_reconstruct: bool=False, - duration_adjust: bool=True, - start_end_sp: bool=False, - train_args=None): - wav_org, rate = librosa.load( - wav_path, sr=train_args.feats_extract_conf['fs']) - fs = train_args.feats_extract_conf['fs'] - hop_length = train_args.feats_extract_conf['hop_length'] + avg = np.average(factor_list[length:-length]) + return avg + + +def prep_feats_with_dur(wav_path: str, + mlm_model: nn.Layer, + source_lang: str="English", + target_lang: str="English", + old_str: str="", + new_str: str="", + mask_reconstruct: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False, + fs: int=24000, + hop_length: int=300): + ''' + Returns: + np.ndarray: new wav, replace the part to be edited in original wav with 0 + List[str]: new phones + List[float]: mfa start of new wav + List[float]: mfa end of new wav + List[int]: masked mel boundary of original wav + List[int]: masked mel boundary of new wav + ''' + wav_org, _ = librosa.load(wav_path, sr=fs) mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans( wav_path=wav_path, @@ -503,144 +356,129 @@ def prepare_features_with_duration(uid: str, if start_end_sp: if new_phns[-1] != 'sp': new_phns = new_phns + ['sp'] - - if target_lang == "english": - old_durations = evaluate_durations(old_phns, target_lang=target_lang) - - elif target_lang == "chinese": - - if source_lang == "english": - old_durations = evaluate_durations( - old_phns, target_lang=source_lang) - - elif source_lang == "chinese": - old_durations = evaluate_durations( - old_phns, target_lang=source_lang) - + # 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替 + if target_lang == "english" or target_lang == "chinese": + old_durs = eval_durs(old_phns, target_lang=source_lang) else: - assert target_lang == "chinese" or target_lang == "english", "calculate duration_predict is not support for this language..." + assert target_lang == "chinese" or target_lang == "english", \ + "calculate duration_predict is not support for this language..." 
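+    # old_durs are durations predicted by the FastSpeech2 duration predictor
+    # for the original phones; orig_old_durs below are the forced-alignment
+    # durations (mfa_end - mfa_start). get_dur_adj_factor averages their
+    # per-phone ratio, and that factor rescales the predicted durations of
+    # the new phones.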
- original_old_durations = [e - s for e, s in zip(mfa_end, mfa_start)] + orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)] if '[MASK]' in new_str: new_phns = old_phns span_to_add = span_to_repl - d_factor_left = duration_adjust_factor( - original_old_durations[:span_to_repl[0]], - old_durations[:span_to_repl[0]], old_phns[:span_to_repl[0]]) - d_factor_right = duration_adjust_factor( - original_old_durations[span_to_repl[1]:], - old_durations[span_to_repl[1]:], old_phns[span_to_repl[1]:]) + d_factor_left = get_dur_adj_factor( + orig_dur=orig_old_durs[:span_to_repl[0]], + pred_dur=old_durs[:span_to_repl[0]], + phns=old_phns[:span_to_repl[0]]) + d_factor_right = get_dur_adj_factor( + orig_dur=orig_old_durs[span_to_repl[1]:], + pred_dur=old_durs[span_to_repl[1]:], + phns=old_phns[span_to_repl[1]:]) d_factor = (d_factor_left + d_factor_right) / 2 - new_durations_adjusted = [d_factor * i for i in old_durations] + new_durs_adjusted = [d_factor * i for i in old_durs] else: if duration_adjust: - d_factor = duration_adjust_factor(original_old_durations, - old_durations, old_phns) + d_factor = get_dur_adj_factor( + orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns) d_factor = d_factor * 1.25 else: d_factor = 1 - if target_lang == "english": - new_durations = evaluate_durations( - new_phns, target_lang=target_lang) - - elif target_lang == "chinese": - new_durations = evaluate_durations( - new_phns, target_lang=target_lang) - - new_durations_adjusted = [d_factor * i for i in new_durations] - - if span_to_repl[0] < len(old_phns) and old_phns[span_to_repl[ - 0]] == new_phns[span_to_add[0]]: - new_durations_adjusted[span_to_add[0]] = original_old_durations[ - span_to_repl[0]] - if span_to_repl[1] < len(old_phns) and span_to_add[1] < len(new_phns): - if old_phns[span_to_repl[1]] == new_phns[span_to_add[1]]: - new_durations_adjusted[span_to_add[1]] = original_old_durations[ - span_to_repl[1]] - new_span_duration_sum = sum( - new_durations_adjusted[span_to_add[0]:span_to_add[1]]) - old_span_duration_sum = sum( - original_old_durations[span_to_repl[0]:span_to_repl[1]]) - duration_offset = new_span_duration_sum - old_span_duration_sum + if target_lang == "english" or target_lang == "chinese": + new_durs = eval_durs(new_phns, target_lang=target_lang) + else: + assert target_lang == "chinese" or target_lang == "english", \ + "calculate duration_predict is not support for this language..." + + new_durs_adjusted = [d_factor * i for i in new_durs] + + new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]]) + old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]]) + dur_offset = new_span_dur_sum - old_span_dur_sum new_mfa_start = mfa_start[:span_to_repl[0]] new_mfa_end = mfa_end[:span_to_repl[0]] - for i in new_durations_adjusted[span_to_add[0]:span_to_add[1]]: + for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]: if len(new_mfa_end) == 0: new_mfa_start.append(0) new_mfa_end.append(i) else: new_mfa_start.append(new_mfa_end[-1]) new_mfa_end.append(new_mfa_end[-1] + i) - new_mfa_start += [i + duration_offset for i in mfa_start[span_to_repl[1]:]] - new_mfa_end += [i + duration_offset for i in mfa_end[span_to_repl[1]:]] + new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]] + new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]] - # 3. get new wav + # 3. 
get new wav + # 在原始句子后拼接 if span_to_repl[0] >= len(mfa_start): left_idx = len(wav_org) right_idx = left_idx + # 在原始句子中间替换 else: left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs)) right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs)) - new_blank_wav = np.zeros( - (int(np.ceil(new_span_duration_sum * fs)), ), dtype=wav_org.dtype) - new_wav_org = np.concatenate( - [wav_org[:left_idx], new_blank_wav, wav_org[right_idx:]]) + blank_wav = np.zeros( + (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype) + # 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定 + new_wav = np.concatenate( + [wav_org[:left_idx], blank_wav, wav_org[right_idx:]]) # 4. get old and new mel span to be mask # [92, 92] - old_span_bdy = get_masked_mel_bdy(mfa_start, mfa_end, fs, hop_length, - span_to_repl) + + old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy( + mfa_start=mfa_start, + mfa_end=mfa_end, + fs=fs, + hop_length=hop_length, + span_to_repl=span_to_repl) # [92, 174] - new_span_bdy = get_masked_mel_bdy(new_mfa_start, new_mfa_end, fs, - hop_length, span_to_add) - - return new_wav_org, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy - - -def prepare_features(uid: str, - mlm_model: nn.Layer, - processor, - wav_path: str, - prefix: str="./prompt/dev/", - source_lang: str="english", - target_lang: str="english", - old_str: str="", - new_str: str="", - duration_preditor_path: str=None, - sid: str=None, - duration_adjust: bool=True, - start_end_sp: bool=False, - mask_reconstruct: bool=False, - train_args=None): - wav_org, phns_list, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prepare_features_with_duration( - uid=uid, - prefix=prefix, + # new_mfa_start, new_mfa_end 时间级别的开始和结束时间 -> 帧级别 + new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy( + mfa_start=new_mfa_start, + mfa_end=new_mfa_end, + fs=fs, + hop_length=hop_length, + span_to_repl=span_to_add) + + # old_span_bdy, new_span_bdy 是帧级别的范围 + return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy + + +def prep_feats(mlm_model: nn.Layer, + wav_path: str, + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + duration_adjust: bool=True, + start_end_sp: bool=False, + mask_reconstruct: bool=False, + fs: int=24000, + hop_length: int=300, + token_list: List[str]=[]): + wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur( source_lang=source_lang, target_lang=target_lang, mlm_model=mlm_model, old_str=old_str, new_str=new_str, wav_path=wav_path, - duration_preditor_path=duration_preditor_path, - sid=sid, duration_adjust=duration_adjust, start_end_sp=start_end_sp, mask_reconstruct=mask_reconstruct, - train_args=train_args) - speech = wav_org - align_start = np.array(mfa_start) - align_end = np.array(mfa_end) - token_to_id = {item: i for i, item in enumerate(train_args.token_list)} - text = np.array( - list( - map(lambda x: token_to_id.get(x, token_to_id['']), phns_list))) + fs=fs, + hop_length=hop_length) + token_to_id = {item: i for i, item in enumerate(token_list)} + text = np.array( + list(map(lambda x: token_to_id.get(x, token_to_id['']), phns))) span_bdy = np.array(new_span_bdy) + batch = [('1', { - "speech": speech, - "align_start": align_start, - "align_end": align_end, + "speech": wav, + "align_start": mfa_start, + "align_end": mfa_end, "text": text, "span_bdy": span_bdy })] @@ -648,375 +486,125 @@ def prepare_features(uid: str, return batch, old_span_bdy, new_span_bdy -def decode_with_model(uid: str, - mlm_model: 
nn.Layer, - processor, +def decode_with_model(mlm_model: nn.Layer, collate_fn, wav_path: str, - prefix: str="./prompt/dev/", source_lang: str="english", target_lang: str="english", old_str: str="", new_str: str="", - duration_preditor_path: str=None, - sid: str=None, - decoder: bool=False, use_teacher_forcing: bool=False, duration_adjust: bool=True, start_end_sp: bool=False, - train_args=None): - fs, hop_length = train_args.feats_extract_conf[ - 'fs'], train_args.feats_extract_conf['hop_length'] - - batch, old_span_bdy, new_span_bdy = prepare_features( - uid=uid, - prefix=prefix, + fs: int=24000, + hop_length: int=300, + token_list: List[str]=[]): + batch, old_span_bdy, new_span_bdy = prep_feats( source_lang=source_lang, target_lang=target_lang, mlm_model=mlm_model, - processor=processor, wav_path=wav_path, old_str=old_str, new_str=new_str, - duration_preditor_path=duration_preditor_path, - sid=sid, duration_adjust=duration_adjust, start_end_sp=start_end_sp, - train_args=train_args) + fs=fs, + hop_length=hop_length, + token_list=token_list) feats = collate_fn(batch)[1] if 'text_masked_pos' in feats.keys(): feats.pop('text_masked_pos') - for k, v in feats.items(): - feats[k] = paddle.to_tensor(v) - rtn = mlm_model.inference( - **feats, span_bdy=new_span_bdy, use_teacher_forcing=use_teacher_forcing) - output = rtn['feat_gen'] - if 0 in output[0].shape and 0 not in output[-1].shape: - output_feat = paddle.concat( - output[1:-1] + [output[-1].squeeze()], axis=0).cpu() - elif 0 not in output[0].shape and 0 in output[-1].shape: - output_feat = paddle.concat( - [output[0].squeeze()] + output[1:-1], axis=0).cpu() - elif 0 in output[0].shape and 0 in output[-1].shape: - output_feat = paddle.concat(output[1:-1], axis=0).cpu() - else: - output_feat = paddle.concat( - [output[0].squeeze(0)] + output[1:-1] + [output[-1].squeeze(0)], - axis=0).cpu() - - wav_org, _ = librosa.load( - wav_path, sr=train_args.feats_extract_conf['fs']) - return wav_org, None, output_feat, old_span_bdy, new_span_bdy, fs, hop_length - - -class MLMCollateFn: - """Functor class of common_collate_fn()""" - - def __init__(self, - feats_extract, - float_pad_value: Union[float, int]=0.0, - int_pad_value: int=-32768, - not_sequence: Collection[str]=(), - mlm_prob: float=0.8, - mean_phn_span: int=8, - attention_window: int=0, - pad_speech: bool=False, - sega_emb: bool=False, - duration_collect: bool=False, - text_masking: bool=False): - self.mlm_prob = mlm_prob - self.mean_phn_span = mean_phn_span - self.feats_extract = feats_extract - self.float_pad_value = float_pad_value - self.int_pad_value = int_pad_value - self.not_sequence = set(not_sequence) - self.attention_window = attention_window - self.pad_speech = pad_speech - self.sega_emb = sega_emb - self.duration_collect = duration_collect - self.text_masking = text_masking - - def __repr__(self): - return (f"{self.__class__}(float_pad_value={self.float_pad_value}, " - f"int_pad_value={self.float_pad_value})") - - def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]] - ) -> Tuple[List[str], Dict[str, paddle.Tensor]]: - return mlm_collate_fn( - data, - float_pad_value=self.float_pad_value, - int_pad_value=self.int_pad_value, - not_sequence=self.not_sequence, - mlm_prob=self.mlm_prob, - mean_phn_span=self.mean_phn_span, - feats_extract=self.feats_extract, - attention_window=self.attention_window, - pad_speech=self.pad_speech, - sega_emb=self.sega_emb, - duration_collect=self.duration_collect, - text_masking=self.text_masking) - - -def mlm_collate_fn( - data: 
Collection[Tuple[str, Dict[str, np.ndarray]]], - float_pad_value: Union[float, int]=0.0, - int_pad_value: int=-32768, - not_sequence: Collection[str]=(), - mlm_prob: float=0.8, - mean_phn_span: int=8, - feats_extract=None, - attention_window: int=0, - pad_speech: bool=False, - sega_emb: bool=False, - duration_collect: bool=False, - text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]: - uttids = [u for u, _ in data] - data = [d for _, d in data] - - assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching" - assert all(not k.endswith("_lens") - for k in data[0]), f"*_lens is reserved: {list(data[0])}" - - output = {} - for key in data[0]: - # Each models, which accepts these values finally, are responsible - # to repaint the pad_value to the desired value for each tasks. - if data[0][key].dtype.kind == "i": - pad_value = int_pad_value - else: - pad_value = float_pad_value - - array_list = [d[key] for d in data] - - # Assume the first axis is length: - # tensor_list: Batch x (Length, ...) - tensor_list = [paddle.to_tensor(a) for a in array_list] - # tensor: (Batch, Length, ...) - tensor = pad_list(tensor_list, pad_value) - output[key] = tensor - - # lens: (Batch,) - if key not in not_sequence: - lens = paddle.to_tensor( - [d[key].shape[0] for d in data], dtype=paddle.int64) - output[key + "_lens"] = lens - - feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0])) - feats = paddle.to_tensor(feats) - feats_lens = paddle.shape(feats)[0] - feats = paddle.unsqueeze(feats, 0) - if 'text' not in output: - text = paddle.zeros(paddle.shape(feats_lens.unsqueeze(-1))) - 2 - text_lens = paddle.zeros(paddle.shape(feats_lens)) + 1 - max_tlen = 1 - align_start = paddle.zeros(paddle.shape(text)) - align_end = paddle.zeros(paddle.shape(text)) - align_start_lens = paddle.zeros(paddle.shape(feats_lens)) - sega_emb = False - mean_phn_span = 0 - mlm_prob = 0.15 - else: - text = output["text"] - text_lens = output["text_lens"] - align_start = output["align_start"] - align_start_lens = output["align_start_lens"] - align_end = output["align_end"] - align_start = paddle.floor(feats_extract.sr * align_start / - feats_extract.hop_length).int() - align_end = paddle.floor(feats_extract.sr * align_end / - feats_extract.hop_length).int() - max_tlen = max(text_lens) - max_slen = max(feats_lens) - speech_pad = feats[:, :max_slen] - if attention_window > 0 and pad_speech: - speech_pad, max_slen = pad_to_longformer_att_window( - speech_pad, max_slen, max_slen, attention_window) - max_len = max_slen + max_tlen - if attention_window > 0: - text_pad, max_tlen = pad_to_longformer_att_window( - text, max_len, max_tlen, attention_window) - else: - text_pad = text - text_mask = make_non_pad_mask( - text_lens, text_pad, length_dim=1).unsqueeze(-2) - if attention_window > 0: - text_mask = text_mask * 2 - speech_mask = make_non_pad_mask( - feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2) - span_bdy = None - if 'span_bdy' in output.keys(): - span_bdy = output['span_bdy'] - - if text_masking: - masked_pos, text_masked_pos, _ = phones_text_masking( - speech_pad, speech_mask, text_pad, text_mask, align_start, - align_end, align_start_lens, mlm_prob, mean_phn_span, span_bdy) - else: - text_masked_pos = paddle.zeros(paddle.shape(text_pad)) - masked_pos, _ = phones_masking(speech_pad, speech_mask, align_start, - align_end, align_start_lens, mlm_prob, - mean_phn_span, span_bdy) - - output_dict = {} - if duration_collect and 'text' in output: - reordered_idx, speech_seg_pos, 
text_seg_pos, durations, feats_lens = get_seg_pos_reduce_duration( - speech_pad, text_pad, align_start, align_end, align_start_lens, - sega_emb, masked_pos, feats_lens) - speech_mask = make_non_pad_mask( - feats_lens, speech_pad[:, :reordered_idx.shape[1], 0], - length_dim=1).unsqueeze(-2) - output_dict['durations'] = durations - output_dict['reordered_idx'] = reordered_idx - else: - speech_seg_pos, text_seg_pos = get_seg_pos(speech_pad, text_pad, - align_start, align_end, - align_start_lens, sega_emb) - output_dict['speech'] = speech_pad - output_dict['text'] = text_pad - output_dict['masked_pos'] = masked_pos - output_dict['text_masked_pos'] = text_masked_pos - output_dict['speech_mask'] = speech_mask - output_dict['text_mask'] = text_mask - output_dict['speech_seg_pos'] = speech_seg_pos - output_dict['text_seg_pos'] = text_seg_pos - output_dict['speech_lens'] = output["speech_lens"] - output_dict['text_lens'] = text_lens - output = (uttids, output_dict) - return output - - -def build_collate_fn(args: argparse.Namespace, train: bool, epoch=-1): - # -> Callable[ - # [Collection[Tuple[str, Dict[str, np.ndarray]]]], - # Tuple[List[str], Dict[str, Tensor]], - # ]: - - # assert check_argument_types() - # return CommonCollateFn(float_pad_value=0.0, int_pad_value=0) - feats_extract_class = LogMelFBank - if args.feats_extract_conf['win_length'] is None: - args.feats_extract_conf['win_length'] = args.feats_extract_conf['n_fft'] - - args_dic = {} - for k, v in args.feats_extract_conf.items(): - if k == 'fs': - args_dic['sr'] = v - else: - args_dic[k] = v - feats_extract = feats_extract_class(**args_dic) - - sega_emb = True if args.encoder_conf['input_layer'] == 'sega_mlm' else False - if args.encoder_conf['selfattention_layer_type'] == 'longformer': - attention_window = args.encoder_conf['attention_window'] - pad_speech = True if 'pre_speech_layer' in args.encoder_conf and args.encoder_conf[ - 'pre_speech_layer'] > 0 else False - else: - attention_window = 0 - pad_speech = False - if epoch == -1: - mlm_prob_factor = 1 - else: - mlm_prob_factor = 0.8 - if 'duration_predictor_layers' in args.model_conf.keys( - ) and args.model_conf['duration_predictor_layers'] > 0: - duration_collect = True - else: - duration_collect = False - - return MLMCollateFn( - feats_extract, - float_pad_value=0.0, - int_pad_value=0, - mlm_prob=args.model_conf['mlm_prob'] * mlm_prob_factor, - mean_phn_span=args.model_conf['mean_phn_span'], - attention_window=attention_window, - pad_speech=pad_speech, - sega_emb=sega_emb, - duration_collect=duration_collect) - - -def get_mlm_output(uid: str, - wav_path: str, - prefix: str="./prompt/dev/", - model_name: str="conformer", + + output = mlm_model.inference( + text=feats['text'], + speech=feats['speech'], + masked_pos=feats['masked_pos'], + speech_mask=feats['speech_mask'], + text_mask=feats['text_mask'], + speech_seg_pos=feats['speech_seg_pos'], + text_seg_pos=feats['text_seg_pos'], + span_bdy=new_span_bdy, + use_teacher_forcing=use_teacher_forcing) + + # 拼接音频 + output_feat = paddle.concat(x=output, axis=0) + wav_org, _ = librosa.load(wav_path, sr=fs) + return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length + + +def get_mlm_output(wav_path: str, + model_name: str="paddle_checkpoint_en", source_lang: str="english", target_lang: str="english", old_str: str="", new_str: str="", - duration_preditor_path: str=None, - sid: str=None, - decoder: bool=False, use_teacher_forcing: bool=False, duration_adjust: bool=True, start_end_sp: bool=False): - mlm_model, train_args = 
load_model(model_name) + mlm_model, train_conf = load_model(model_name) mlm_model.eval() - processor = None - collate_fn = build_collate_fn(train_args, False) + + collate_fn = build_collate_fn( + sr=train_conf.feats_extract_conf['fs'], + n_fft=train_conf.feats_extract_conf['n_fft'], + hop_length=train_conf.feats_extract_conf['hop_length'], + win_length=train_conf.feats_extract_conf['win_length'], + n_mels=train_conf.feats_extract_conf['n_mels'], + fmin=train_conf.feats_extract_conf['fmin'], + fmax=train_conf.feats_extract_conf['fmax'], + mlm_prob=train_conf['mlm_prob'], + mean_phn_span=train_conf['mean_phn_span'], + train=False, + seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm') return decode_with_model( - uid=uid, - prefix=prefix, source_lang=source_lang, target_lang=target_lang, mlm_model=mlm_model, - processor=processor, collate_fn=collate_fn, wav_path=wav_path, old_str=old_str, new_str=new_str, - duration_preditor_path=duration_preditor_path, - sid=sid, - decoder=decoder, use_teacher_forcing=use_teacher_forcing, duration_adjust=duration_adjust, start_end_sp=start_end_sp, - train_args=train_args) + fs=train_conf.feats_extract_conf['fs'], + hop_length=train_conf.feats_extract_conf['hop_length'], + token_list=train_conf.token_list) def evaluate(uid: str, source_lang: str="english", target_lang: str="english", use_pt_vocoder: bool=False, - prefix: str="./prompt/dev/", - model_name: str="conformer", - old_str: str="", + prefix: os.PathLike="./prompt/dev/", + model_name: str="paddle_checkpoint_en", new_str: str="", prompt_decoding: bool=False, task_name: str=None): - duration_preditor_path = None - spemd = None - full_origin_str, wav_path = read_data(uid=uid, prefix=prefix) + # get origin text and path of origin wav + old_str, wav_path = read_data(uid=uid, prefix=prefix) if task_name == 'edit': new_str = new_str elif task_name == 'synthesize': - new_str = full_origin_str + new_str + new_str = old_str + new_str else: - new_str = full_origin_str + ' '.join( - [ch for ch in new_str if is_chinese(ch)]) + new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)]) print('new_str is ', new_str) - if not old_str: - old_str = full_origin_str - - results_dict, old_span = plot_mel_and_vocode_wav( - uid=uid, - prefix=prefix, + results_dict = get_wav( source_lang=source_lang, target_lang=target_lang, model_name=model_name, wav_path=wav_path, - full_origin_str=full_origin_str, old_str=old_str, new_str=new_str, - use_pt_vocoder=use_pt_vocoder, - duration_preditor_path=duration_preditor_path, - sid=spemd) + use_pt_vocoder=use_pt_vocoder) return results_dict diff --git a/ernie-sat/model_paddle.py b/ernie-sat/mlm.py similarity index 50% rename from ernie-sat/model_paddle.py rename to ernie-sat/mlm.py index f33c49ed233db83e1f7785a68a36bfeaad6a37df..2cf7921f6bfb5b190e47da8a25513efbc25f7dee 100644 --- a/ernie-sat/model_paddle.py +++ b/ernie-sat/mlm.py @@ -1,16 +1,12 @@ import argparse -import logging -import math import os import sys -from pathlib import Path from typing import Dict from typing import List from typing import Optional from typing import Tuple from typing import Union -import numpy as np import paddle import yaml from paddle import nn @@ -20,17 +16,18 @@ for dir_name in os.listdir(pypath): if os.path.isdir(dir_path): sys.path.append(dir_path) -from paddlespeech.s2t.utils.error_rate import ErrorCalculator from paddlespeech.t2s.modules.activation import get_activation from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule from 
paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer from paddlespeech.t2s.modules.masked_fill import masked_fill from paddlespeech.t2s.modules.nets_utils import initialize from paddlespeech.t2s.modules.tacotron2.decoder import Postnet +from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling +from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward @@ -39,65 +36,10 @@ from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredCo from paddlespeech.t2s.modules.transformer.repeat import repeat from paddlespeech.t2s.modules.layer_norm import LayerNorm - -class LegacyRelPositionalEncoding(PositionalEncoding): - """Relative positional encoding module (old version). - - Details can be found in https://github.com/espnet/espnet/pull/2816. - - See : Appendix B in https://arxiv.org/abs/1901.02860 - - Args: - d_model (int): Embedding dimension. - dropout_rate (float): Dropout rate. - max_len (int): Maximum input length. - - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000): - """ - Args: - d_model (int): Embedding dimension. - dropout_rate (float): Dropout rate. - max_len (int, optional): [Maximum input length.]. Defaults to 5000. - """ - super().__init__(d_model, dropout_rate, max_len, reverse=True) - - def extend_pe(self, x): - """Reset the positional encodings.""" - if self.pe is not None: - if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]: - return - pe = paddle.zeros((paddle.shape(x)[1], self.d_model)) - if self.reverse: - position = paddle.arange( - paddle.shape(x)[1] - 1, -1, -1.0, - dtype=paddle.float32).unsqueeze(1) - else: - position = paddle.arange( - 0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1) - div_term = paddle.exp( - paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * - -(math.log(10000.0) / self.d_model)) - pe[:, 0::2] = paddle.sin(position * div_term) - pe[:, 1::2] = paddle.cos(position * div_term) - pe = pe.unsqueeze(0) - self.pe = pe - - def forward(self, x: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Compute positional encoding. - Args: - x (paddle.Tensor): Input tensor (batch, time, `*`). - Returns: - paddle.Tensor: Encoded tensor (batch, time, `*`). - paddle.Tensor: Positional embedding tensor (1, time, `*`). 
- """ - self.extend_pe(x) - x = x * self.xscale - pos_emb = self.pe[:, :paddle.shape(x)[1]] - return self.dropout(x), self.dropout(pos_emb) +from yacs.config import CfgNode +# MLM -> Mask Language Model class mySequential(nn.Sequential): def forward(self, *inputs): for module in self._sub_layers.values(): @@ -108,12 +50,8 @@ class mySequential(nn.Sequential): return inputs -class NewMaskInputLayer(nn.Layer): - __constants__ = ['out_features'] - out_features: int - - def __init__(self, out_features: int, device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} +class MaskInputLayer(nn.Layer): + def __init__(self, out_features: int) -> None: super().__init__() self.mask_feature = paddle.create_parameter( shape=(1, 1, out_features), @@ -121,109 +59,14 @@ class NewMaskInputLayer(nn.Layer): default_initializer=paddle.nn.initializer.Assign( paddle.normal(shape=(1, 1, out_features)))) - def forward(self, input: paddle.Tensor, masked_pos=None) -> paddle.Tensor: + def forward(self, input: paddle.Tensor, + masked_pos: paddle.Tensor=None) -> paddle.Tensor: masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input) masked_input = masked_fill(input, masked_pos, 0) + masked_fill( paddle.expand_as(self.mask_feature, input), ~masked_pos, 0) return masked_input -class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): - """Multi-Head Attention layer with relative position encoding (old version). - Details can be found in https://github.com/espnet/espnet/pull/2816. - Paper: https://arxiv.org/abs/1901.02860 - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - zero_triu (bool): Whether to zero the upper triangular part of attention matrix. - """ - - def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): - """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head, n_feat, dropout_rate) - self.zero_triu = zero_triu - # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False) - # these two learnable bias are used in matrix c and matrix d - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - - self.pos_bias_u = paddle.create_parameter( - shape=(self.h, self.d_k), - dtype='float32', - default_initializer=paddle.nn.initializer.XavierUniform()) - self.pos_bias_v = paddle.create_parameter( - shape=(self.h, self.d_k), - dtype='float32', - default_initializer=paddle.nn.initializer.XavierUniform()) - - def rel_shift(self, x): - """Compute relative positional encoding. - Args: - x(Tensor): Input tensor (batch, head, time1, time2). - - Returns: - Tensor:Output tensor. - """ - b, h, t1, t2 = paddle.shape(x) - zero_pad = paddle.zeros((b, h, t1, 1)) - x_padded = paddle.concat([zero_pad, x], axis=-1) - x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1]) - # only keep the positions from 0 to time2 - x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2]) - - if self.zero_triu: - ones = paddle.ones((t1, t2)) - x = x * paddle.tril(ones, t2 - 1)[None, None, :, :] - - return x - - def forward(self, query, key, value, pos_emb, mask): - """Compute 'Scaled Dot Product Attention' with rel. positional encoding. - - Args: - query(Tensor): Query tensor (#batch, time1, size). - key(Tensor): Key tensor (#batch, time2, size). - value(Tensor): Value tensor (#batch, time2, size). - pos_emb(Tensor): Positional embedding tensor (#batch, time1, size). 
- mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2). - - Returns: - Tensor: Output tensor (#batch, time1, d_model). - """ - q, k, v = self.forward_qkv(query, key, value) - # (batch, time1, head, d_k) - q = paddle.transpose(q, [0, 2, 1, 3]) - - n_batch_pos = paddle.shape(pos_emb)[0] - p = paddle.reshape( - self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k]) - # (batch, head, time1, d_k) - p = paddle.transpose(p, [0, 2, 1, 3]) - # (batch, head, time1, d_k) - q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3]) - # (batch, head, time1, d_k) - q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3]) - - # compute attention score - # first compute matrix a and matrix c - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - # (batch, head, time1, time2) - matrix_ac = paddle.matmul(q_with_bias_u, - paddle.transpose(k, [0, 1, 3, 2])) - - # compute matrix b and matrix d - # (batch, head, time1, time1) - matrix_bd = paddle.matmul(q_with_bias_v, - paddle.transpose(p, [0, 1, 3, 2])) - matrix_bd = self.rel_shift(matrix_bd) - # (batch, head, time1, time2) - scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) - - return self.forward_attention(v, scores, mask) - - class MLMEncoder(nn.Layer): """Conformer encoder module. @@ -253,47 +96,42 @@ class MLMEncoder(nn.Layer): cnn_module_kernel (int): Kernerl size of convolution module. padding_idx (int): Padding idx for input_layer=embed. stochastic_depth_rate (float): Maximum probability to skip the encoder layer. - intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer. - indices start from 1. - if not None, intermediate outputs are returned (which changes return type - signature.) """ def __init__(self, - idim, - vocab_size=0, + idim: int, + vocab_size: int=0, pre_speech_layer: int=0, - attention_dim=256, - attention_heads=4, - linear_units=2048, - num_blocks=6, - dropout_rate=0.1, - positional_dropout_rate=0.1, - attention_dropout_rate=0.0, - input_layer="conv2d", - normalize_before=True, - concat_after=False, - positionwise_layer_type="linear", - positionwise_conv_kernel_size=1, - macaron_style=False, - pos_enc_layer_type="abs_pos", + attention_dim: int=256, + attention_heads: int=4, + linear_units: int=2048, + num_blocks: int=6, + dropout_rate: float=0.1, + positional_dropout_rate: float=0.1, + attention_dropout_rate: float=0.0, + input_layer: str="conv2d", + normalize_before: bool=True, + concat_after: bool=False, + positionwise_layer_type: str="linear", + positionwise_conv_kernel_size: int=1, + macaron_style: bool=False, + pos_enc_layer_type: str="abs_pos", pos_enc_class=None, - selfattention_layer_type="selfattn", - activation_type="swish", - use_cnn_module=False, - zero_triu=False, - cnn_module_kernel=31, - padding_idx=-1, - stochastic_depth_rate=0.0, - intermediate_layers=None, - text_masking=False): + selfattention_layer_type: str="selfattn", + activation_type: str="swish", + use_cnn_module: bool=False, + zero_triu: bool=False, + cnn_module_kernel: int=31, + padding_idx: int=-1, + stochastic_depth_rate: float=0.0, + text_masking: bool=False): """Construct an Encoder object.""" super().__init__() self._output_size = attention_dim self.text_masking = text_masking if self.text_masking: - self.text_masking_layer = NewMaskInputLayer(attention_dim) + self.text_masking_layer = MaskInputLayer(attention_dim) activation = get_activation(activation_type) if pos_enc_layer_type == "abs_pos": pos_enc_class = PositionalEncoding @@ -330,7 +168,7 @@ class 
MLMEncoder(nn.Layer): elif input_layer == "mlm": self.segment_emb = None self.speech_embed = mySequential( - NewMaskInputLayer(idim), + MaskInputLayer(idim), nn.Linear(idim, attention_dim), nn.LayerNorm(attention_dim), nn.ReLU(), @@ -343,7 +181,7 @@ class MLMEncoder(nn.Layer): self.segment_emb = nn.Embedding( 500, attention_dim, padding_idx=padding_idx) self.speech_embed = mySequential( - NewMaskInputLayer(idim), + MaskInputLayer(idim), nn.Linear(idim, attention_dim), nn.LayerNorm(attention_dim), nn.ReLU(), @@ -365,7 +203,6 @@ class MLMEncoder(nn.Layer): # self-attention module definition if selfattention_layer_type == "selfattn": - logging.info("encoder self-attention layer type = self-attention") encoder_selfattn_layer = MultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, ) @@ -375,8 +212,6 @@ class MLMEncoder(nn.Layer): encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, ) elif selfattention_layer_type == "rel_selfattn": - logging.info( - "encoder self-attention layer type = relative self-attention") assert pos_enc_layer_type == "rel_pos" encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, attention_dim, @@ -436,49 +271,38 @@ class MLMEncoder(nn.Layer): if self.normalize_before: self.after_norm = LayerNorm(attention_dim) - self.intermediate_layers = intermediate_layers - def forward(self, - speech_pad, - text_pad, - masked_pos, - speech_mask=None, - text_mask=None, - speech_seg_pos=None, - text_seg_pos=None): + speech: paddle.Tensor, + text: paddle.Tensor, + masked_pos: paddle.Tensor, + speech_mask: paddle.Tensor=None, + text_mask: paddle.Tensor=None, + speech_seg_pos: paddle.Tensor=None, + text_seg_pos: paddle.Tensor=None): """Encode input sequence. 
""" if masked_pos is not None: - speech_pad = self.speech_embed(speech_pad, masked_pos) + speech = self.speech_embed(speech, masked_pos) else: - speech_pad = self.speech_embed(speech_pad) - # pure speech input - if -2 in np.array(text_pad): - text_pad = text_pad + 3 - text_mask = paddle.unsqueeze(bool(text_pad), 1) - text_seg_pos = paddle.zeros_like(text_pad) - text_pad = self.text_embed(text_pad) - text_pad = (text_pad[0] + self.segment_emb(text_seg_pos), - text_pad[1]) - text_seg_pos = None - elif text_pad is not None: - text_pad = self.text_embed(text_pad) + speech = self.speech_embed(speech) + if text is not None: + text = self.text_embed(text) if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb: speech_seg_emb = self.segment_emb(speech_seg_pos) text_seg_emb = self.segment_emb(text_seg_pos) - text_pad = (text_pad[0] + text_seg_emb, text_pad[1]) - speech_pad = (speech_pad[0] + speech_seg_emb, speech_pad[1]) + text = (text[0] + text_seg_emb, text[1]) + speech = (speech[0] + speech_seg_emb, speech[1]) if self.pre_speech_encoders: - speech_pad, _ = self.pre_speech_encoders(speech_pad, speech_mask) + speech, _ = self.pre_speech_encoders(speech, speech_mask) - if text_pad is not None: - xs = paddle.concat([speech_pad[0], text_pad[0]], axis=1) - xs_pos_emb = paddle.concat([speech_pad[1], text_pad[1]], axis=1) + if text is not None: + xs = paddle.concat([speech[0], text[0]], axis=1) + xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1) masks = paddle.concat([speech_mask, text_mask], axis=-1) else: - xs = speech_pad[0] - xs_pos_emb = speech_pad[1] + xs = speech[0] + xs_pos_emb = speech[1] masks = speech_mask xs, masks = self.encoders((xs, xs_pos_emb), masks) @@ -492,7 +316,7 @@ class MLMEncoder(nn.Layer): class MLMDecoder(MLMEncoder): - def forward(self, xs, masks, masked_pos=None, segment_emb=None): + def forward(self, xs: paddle.Tensor, masks: paddle.Tensor): """Encode input sequence. Args: @@ -504,51 +328,19 @@ class MLMDecoder(MLMEncoder): paddle.Tensor: Mask tensor (#batch, time). """ - if not self.training: - masked_pos = None xs = self.embed(xs) - if segment_emb: - xs = (xs[0] + segment_emb, xs[1]) - if self.intermediate_layers is None: - xs, masks = self.encoders(xs, masks) - else: - intermediate_outputs = [] - for layer_idx, encoder_layer in enumerate(self.encoders): - xs, masks = encoder_layer(xs, masks) - - if (self.intermediate_layers is not None and - layer_idx + 1 in self.intermediate_layers): - encoder_output = xs - # intermediate branches also require normalization. 
- if self.normalize_before: - encoder_output = self.after_norm(encoder_output) - intermediate_outputs.append(encoder_output) + xs, masks = self.encoders(xs, masks) + if isinstance(xs, tuple): xs = xs[0] if self.normalize_before: xs = self.after_norm(xs) - if self.intermediate_layers is not None: - return xs, masks, intermediate_outputs return xs, masks -def pad_to_longformer_att_window(text, max_len, max_tlen, attention_window): - round = max_len % attention_window - if round != 0: - max_tlen += (attention_window - round) - n_batch = paddle.shape(text)[0] - text_pad = paddle.zeros( - shape=(n_batch, max_tlen, *paddle.shape(text[0])[1:]), - dtype=text.dtype) - for i in range(n_batch): - text_pad[i, :paddle.shape(text[i])[0]] = text[i] - else: - text_pad = text[:, :max_tlen] - return text_pad, max_tlen - - -class MLMModel(nn.Layer): +# encoder and decoder is nn.Layer, not str +class MLM(nn.Layer): def __init__(self, token_list: Union[Tuple[str, ...], List[str]], odim: int, @@ -557,44 +349,15 @@ class MLMModel(nn.Layer): postnet_layers: int=0, postnet_chans: int=0, postnet_filts: int=0, - ignore_id: int=-1, - lsm_weight: float=0.0, - length_normalized_loss: bool=False, - report_cer: bool=True, - report_wer: bool=True, - sym_space: str="", - sym_blank: str="", - masking_schema: str="span", - mean_phn_span: int=3, - mlm_prob: float=0.25, - dynamic_mlm_prob=False, - decoder_seg_pos=False, - text_masking=False): + text_masking: bool=False): super().__init__() - # note that eos is the same as sos (equivalent ID) self.odim = odim - self.ignore_id = ignore_id self.token_list = token_list.copy() - self.encoder = encoder - self.decoder = decoder self.vocab_size = encoder.text_embed[0]._num_embeddings - if report_cer or report_wer: - self.error_calculator = ErrorCalculator( - token_list, sym_space, sym_blank, report_cer, report_wer) - else: - self.error_calculator = None - - self.mlm_weight = 1.0 - self.mlm_prob = mlm_prob - self.mlm_layer = 12 - self.finetune_wo_mlm = True - self.max_span = 50 - self.min_span = 4 - self.mean_phn_span = mean_phn_span - self.masking_schema = masking_schema + if self.decoder is None or not (hasattr(self.decoder, 'output_layer') and self.decoder.output_layer is not None): @@ -606,15 +369,9 @@ class MLMModel(nn.Layer): self.encoder.text_embed[0]._embedding_dim, self.vocab_size, weight_attr=self.encoder.text_embed[0]._weight_attr) - self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id) else: self.text_sfc = None - self.text_mlm_loss = None - self.decoder_seg_pos = decoder_seg_pos - if lsm_weight > 50: - self.l1_loss_func = nn.MSELoss() - else: - self.l1_loss_func = nn.L1Loss(reduction='none') + self.postnet = (None if postnet_layers == 0 else Postnet( idim=self.encoder._output_size, odim=odim, @@ -624,119 +381,79 @@ class MLMModel(nn.Layer): use_batch_norm=True, dropout_rate=0.5, )) - def collect_feats(self, - speech, - speech_lens, - text, - text_lens, - masked_pos, - speech_mask, - text_mask, - speech_seg_pos, - text_seg_pos, - y_masks=None) -> Dict[str, paddle.Tensor]: - return {"feats": speech, "feats_lens": speech_lens} - - def forward(self, batch, speech_seg_pos, y_masks=None): - # feats: (Batch, Length, Dim) - # -> encoder_out: (Batch, Length2, Dim2) - speech_pad_placeholder = batch['speech_pad'] - if self.decoder is not None: - ys_in = self._add_first_frame_and_remove_last_frame( - batch['speech_pad']) - encoder_out, h_masks = self.encoder(**batch) - if self.decoder is not None: - zs, _ = self.decoder(ys_in, y_masks, encoder_out, - bool(h_masks), - 
self.encoder.segment_emb(speech_seg_pos)) - speech_hidden_states = zs - else: - speech_hidden_states = encoder_out[:, :paddle.shape(batch[ - 'speech_pad'])[1], :] - if self.sfc is not None: - before_outs = paddle.reshape( - self.sfc(speech_hidden_states), - (paddle.shape(speech_hidden_states)[0], -1, self.odim)) - else: - before_outs = speech_hidden_states - if self.postnet is not None: - after_outs = before_outs + paddle.transpose( - self.postnet(paddle.transpose(before_outs, [0, 2, 1])), - (0, 2, 1)) - else: - after_outs = None - return before_outs, after_outs, speech_pad_placeholder, batch[ - 'masked_pos'] - def inference( self, - speech, - text, - masked_pos, - speech_mask, - text_mask, - speech_seg_pos, - text_seg_pos, - span_bdy, - y_masks=None, - speech_lens=None, - text_lens=None, - feats: Optional[paddle.Tensor]=None, - spembs: Optional[paddle.Tensor]=None, - sids: Optional[paddle.Tensor]=None, - lids: Optional[paddle.Tensor]=None, - threshold: float=0.5, - minlenratio: float=0.0, - maxlenratio: float=10.0, + speech: paddle.Tensor, + text: paddle.Tensor, + masked_pos: paddle.Tensor, + speech_mask: paddle.Tensor, + text_mask: paddle.Tensor, + speech_seg_pos: paddle.Tensor, + text_seg_pos: paddle.Tensor, + span_bdy: List[int], use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]: + ''' + Args: + speech (paddle.Tensor): input speech (1, Tmax, D). + text (paddle.Tensor): input text (1, Tmax2). + masked_pos (paddle.Tensor): masked position of input speech (1, Tmax) + speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax). + text_mask (paddle.Tensor): mask of text (1, 1, Tmax2). + speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax). + text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2). 
+ span_bdy (List[int]): masked mel boundary of input speech (2,) + use_teacher_forcing (bool): whether to use teacher forcing + Returns: + List[Tensor]: + eg: + [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])] + ''' - batch = dict( - speech_pad=speech, - text_pad=text, - masked_pos=masked_pos, - speech_mask=speech_mask, - text_mask=text_mask, - speech_seg_pos=speech_seg_pos, - text_seg_pos=text_seg_pos, ) - - # # inference with teacher forcing - # hs, h_masks = self.encoder(**batch) - - outs = [batch['speech_pad'][:, :span_bdy[0]]] z_cache = None if use_teacher_forcing: - before, zs, _, _ = self.forward( - batch, speech_seg_pos, y_masks=y_masks) + before_outs, zs, *_ = self.forward( + speech=speech, + text=text, + masked_pos=masked_pos, + speech_mask=speech_mask, + text_mask=text_mask, + speech_seg_pos=speech_seg_pos, + text_seg_pos=text_seg_pos) if zs is None: - zs = before + zs = before_outs + + speech = speech.squeeze(0) + outs = [speech[:span_bdy[0]]] outs += [zs[0][span_bdy[0]:span_bdy[1]]] - outs += [batch['speech_pad'][:, span_bdy[1]:]] - return dict(feat_gen=outs) + outs += [speech[span_bdy[1]:]] + return outs return None - def _add_first_frame_and_remove_last_frame( - self, ys: paddle.Tensor) -> paddle.Tensor: - ys_in = paddle.concat( - [ - paddle.zeros( - shape=(paddle.shape(ys)[0], 1, paddle.shape(ys)[2]), - dtype=ys.dtype), ys[:, :-1] - ], - axis=1) - return ys_in - -class MLMEncAsDecoderModel(MLMModel): - def forward(self, batch, speech_seg_pos, y_masks=None): +class MLMEncAsDecoder(MLM): + def forward(self, + speech: paddle.Tensor, + text: paddle.Tensor, + masked_pos: paddle.Tensor, + speech_mask: paddle.Tensor, + text_mask: paddle.Tensor, + speech_seg_pos: paddle.Tensor, + text_seg_pos: paddle.Tensor): # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) - speech_pad_placeholder = batch['speech_pad'] - encoder_out, h_masks = self.encoder(**batch) # segment_emb + encoder_out, h_masks = self.encoder( + speech=speech, + text=text, + masked_pos=masked_pos, + speech_mask=speech_mask, + text_mask=text_mask, + speech_seg_pos=speech_seg_pos, + text_seg_pos=text_seg_pos) if self.decoder is not None: zs, _ = self.decoder(encoder_out, h_masks) else: zs = encoder_out - speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :] + speech_hidden_states = zs[:, :paddle.shape(speech)[1], :] if self.sfc is not None: before_outs = paddle.reshape( self.sfc(speech_hidden_states), @@ -749,53 +466,35 @@ class MLMEncAsDecoderModel(MLMModel): [0, 2, 1]) else: after_outs = None - return before_outs, after_outs, speech_pad_placeholder, batch[ - 'masked_pos'] - - -class MLMDualMaksingModel(MLMModel): - def _calc_mlm_loss(self, - before_outs: paddle.Tensor, - after_outs: paddle.Tensor, - text_outs: paddle.Tensor, - batch): - xs_pad = batch['speech_pad'] - text_pad = batch['text_pad'] - masked_pos = batch['masked_pos'] - text_masked_pos = batch['text_masked_pos'] - mlm_loss_pos = masked_pos > 0 - loss = paddle.sum( - self.l1_loss_func( - paddle.reshape(before_outs, (-1, self.odim)), - paddle.reshape(xs_pad, (-1, self.odim))), - axis=-1) - if after_outs is not None: - loss += paddle.sum( - self.l1_loss_func( - paddle.reshape(after_outs, (-1, self.odim)), - paddle.reshape(xs_pad, (-1, self.odim))), - axis=-1) - loss_mlm = paddle.sum((loss * paddle.reshape( - mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10) - - loss_text = paddle.sum((self.text_mlm_loss( - paddle.reshape(text_outs, (-1, self.vocab_size)), - paddle.reshape(text_pad, 
(-1))) * paddle.reshape( - text_masked_pos, (-1)))) / paddle.sum((text_masked_pos) + 1e-10) - return loss_mlm, loss_text - - def forward(self, batch, speech_seg_pos, y_masks=None): + return before_outs, after_outs, None + + +class MLMDualMaksing(MLM): + def forward(self, + speech: paddle.Tensor, + text: paddle.Tensor, + masked_pos: paddle.Tensor, + speech_mask: paddle.Tensor, + text_mask: paddle.Tensor, + speech_seg_pos: paddle.Tensor, + text_seg_pos: paddle.Tensor): # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) - encoder_out, h_masks = self.encoder(**batch) # segment_emb + encoder_out, h_masks = self.encoder( + speech=speech, + text=text, + masked_pos=masked_pos, + speech_mask=speech_mask, + text_mask=text_mask, + speech_seg_pos=speech_seg_pos, + text_seg_pos=text_seg_pos) if self.decoder is not None: zs, _ = self.decoder(encoder_out, h_masks) else: zs = encoder_out - speech_hidden_states = zs[:, :paddle.shape(batch['speech_pad'])[1], :] + speech_hidden_states = zs[:, :paddle.shape(speech)[1], :] if self.text_sfc: - text_hiddent_states = zs[:, paddle.shape(batch['speech_pad'])[ - 1]:, :] + text_hiddent_states = zs[:, paddle.shape(speech)[1]:, :] text_outs = paddle.reshape( self.text_sfc(text_hiddent_states), (paddle.shape(text_hiddent_states)[0], -1, self.vocab_size)) @@ -811,27 +510,25 @@ class MLMDualMaksingModel(MLMModel): [0, 2, 1]) else: after_outs = None - return before_outs, after_outs, text_outs, None #, speech_pad_placeholder, batch['masked_pos'],batch['text_masked_pos'] + return before_outs, after_outs, text_outs def build_model_from_file(config_file, model_file): state_dict = paddle.load(model_file) - model_class = MLMDualMaksingModel if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \ - else MLMEncAsDecoderModel + model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \ + else MLMEncAsDecoder # 构建模型 - args = yaml.safe_load(Path(config_file).open("r", encoding="utf-8")) - args = argparse.Namespace(**args) - - model = build_model(args, model_class) - + with open(config_file) as f: + conf = CfgNode(yaml.safe_load(f)) + model = build_model(conf, model_class) model.set_state_dict(state_dict) - return model, args + return model, conf -def build_model(args: argparse.Namespace, - model_class=MLMEncAsDecoderModel) -> MLMModel: +# select encoder and decoder here +def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM: if isinstance(args.token_list, str): with open(args.token_list, encoding="utf-8") as f: token_list = [line.rstrip() for line in f] @@ -842,9 +539,8 @@ def build_model(args: argparse.Namespace, token_list = list(args.token_list) else: raise RuntimeError("token_list must be str or list") - vocab_size = len(token_list) - logging.info(f"Vocabulary size: {vocab_size }") + vocab_size = len(token_list) odim = 80 pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding @@ -857,17 +553,8 @@ def build_model(args: argparse.Namespace, if conformer_rel_pos_type == "legacy": if conformer_pos_enc_layer_type == "rel_pos": conformer_pos_enc_layer_type = "legacy_rel_pos" - logging.warning( - "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' " - "due to the compatibility. 
If you want to use the new one, " - "please use conformer_pos_enc_layer_type = 'latest'.") if conformer_self_attn_layer_type == "rel_selfattn": conformer_self_attn_layer_type = "legacy_rel_selfattn" - logging.warning( - "Fallback to " - "conformer_self_attn_layer_type = 'legacy_rel_selfattn' " - "due to the compatibility. If you want to use the new one, " - "please use conformer_pos_enc_layer_type = 'latest'.") elif conformer_rel_pos_type == "latest": assert conformer_pos_enc_layer_type != "legacy_rel_pos" assert conformer_self_attn_layer_type != "legacy_rel_selfattn" diff --git a/ernie-sat/mlm_loss.py b/ernie-sat/mlm_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..cd69d1ff9832dacd758c8212aa6f971233c9fc1e --- /dev/null +++ b/ernie-sat/mlm_loss.py @@ -0,0 +1,53 @@ +import paddle +from paddle import nn + + +class MLMLoss(nn.Layer): + def __init__(self, + lsm_weight: float=0.1, + ignore_id: int=-1, + text_masking: bool=False): + super().__init__() + if text_masking: + self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id) + if lsm_weight > 50: + self.l1_loss_func = nn.MSELoss() + else: + self.l1_loss_func = nn.L1Loss(reduction='none') + self.text_masking = text_masking + + def forward(self, + speech: paddle.Tensor, + before_outs: paddle.Tensor, + after_outs: paddle.Tensor, + masked_pos: paddle.Tensor, + text: paddle.Tensor=None, + text_outs: paddle.Tensor=None, + text_masked_pos: paddle.Tensor=None): + + xs_pad = speech + mlm_loss_pos = masked_pos > 0 + loss = paddle.sum( + self.l1_loss_func( + paddle.reshape(before_outs, (-1, self.odim)), + paddle.reshape(xs_pad, (-1, self.odim))), + axis=-1) + if after_outs is not None: + loss += paddle.sum( + self.l1_loss_func( + paddle.reshape(after_outs, (-1, self.odim)), + paddle.reshape(xs_pad, (-1, self.odim))), + axis=-1) + loss_mlm = paddle.sum((loss * paddle.reshape( + mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10) + + if self.text_masking: + loss_text = paddle.sum((self.text_mlm_loss( + paddle.reshape(text_outs, (-1, self.vocab_size)), + paddle.reshape(text, (-1))) * paddle.reshape( + text_masked_pos, + (-1)))) / paddle.sum((text_masked_pos) + 1e-10) + + return loss_mlm, loss_text + + return loss_mlm diff --git a/ernie-sat/paddlespeech/t2s/modules/transformer/attention.py b/ernie-sat/paddlespeech/t2s/modules/transformer/attention.py index cdb95b211ab9a60fad27c64fad6bb4dca86ffb3a..475b3fc32f1116a1abd5426a44b2cdf3f61be48d 100644 --- a/ernie-sat/paddlespeech/t2s/modules/transformer/attention.py +++ b/ernie-sat/paddlespeech/t2s/modules/transformer/attention.py @@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) return self.forward_attention(v, scores, mask) + +class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding (old version). + Details can be found in https://github.com/espnet/espnet/pull/2816. + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. 
+ """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + + self.pos_bias_u = paddle.create_parameter( + shape=(self.h, self.d_k), + dtype='float32', + default_initializer=paddle.nn.initializer.XavierUniform()) + self.pos_bias_v = paddle.create_parameter( + shape=(self.h, self.d_k), + dtype='float32', + default_initializer=paddle.nn.initializer.XavierUniform()) + + def rel_shift(self, x): + """Compute relative positional encoding. + Args: + x(Tensor): Input tensor (batch, head, time1, time2). + + Returns: + Tensor:Output tensor. + """ + b, h, t1, t2 = paddle.shape(x) + zero_pad = paddle.zeros((b, h, t1, 1)) + x_padded = paddle.concat([zero_pad, x], axis=-1) + x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1]) + # only keep the positions from 0 to time2 + x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2]) + + if self.zero_triu: + ones = paddle.ones((t1, t2)) + x = x * paddle.tril(ones, t2 - 1)[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query(Tensor): Query tensor (#batch, time1, size). + key(Tensor): Key tensor (#batch, time2, size). + value(Tensor): Value tensor (#batch, time2, size). + pos_emb(Tensor): Positional embedding tensor (#batch, time1, size). + mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + Tensor: Output tensor (#batch, time1, d_model). 
+ """ + q, k, v = self.forward_qkv(query, key, value) + # (batch, time1, head, d_k) + q = paddle.transpose(q, [0, 2, 1, 3]) + + n_batch_pos = paddle.shape(pos_emb)[0] + p = paddle.reshape( + self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k]) + # (batch, head, time1, d_k) + p = paddle.transpose(p, [0, 2, 1, 3]) + # (batch, head, time1, d_k) + q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3]) + # (batch, head, time1, d_k) + q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3]) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = paddle.matmul(q_with_bias_u, + paddle.transpose(k, [0, 1, 3, 2])) + + # compute matrix b and matrix d + # (batch, head, time1, time1) + matrix_bd = paddle.matmul(q_with_bias_v, + paddle.transpose(p, [0, 1, 3, 2])) + matrix_bd = self.rel_shift(matrix_bd) + # (batch, head, time1, time2) + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + + return self.forward_attention(v, scores, mask) + diff --git a/ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py b/ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py index d9339d20bb79c1732e453b5b17f9ef127b8a687a..62dd2171d0254d05c7abfa413ca96c6191daa56f 100644 --- a/ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py +++ b/ernie-sat/paddlespeech/t2s/modules/transformer/embedding.py @@ -185,3 +185,63 @@ class RelPositionalEncoding(nn.Layer): pe_size = paddle.shape(self.pe) pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ] return self.dropout(x), self.dropout(pos_emb) + + +class LegacyRelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module (old version). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000): + """ + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int, optional): [Maximum input length.]. Defaults to 5000. + """ + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]: + return + pe = paddle.zeros((paddle.shape(x)[1], self.d_model)) + if self.reverse: + position = paddle.arange( + paddle.shape(x)[1] - 1, -1, -1.0, + dtype=paddle.float32).unsqueeze(1) + else: + position = paddle.arange( + 0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * + -(math.log(10000.0) / self.d_model)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe + + def forward(self, x: paddle.Tensor): + """Compute positional encoding. + Args: + x (paddle.Tensor): Input tensor (batch, time, `*`). + Returns: + paddle.Tensor: Encoded tensor (batch, time, `*`). + paddle.Tensor: Positional embedding tensor (1, time, `*`). 
+ """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[:, :paddle.shape(x)[1]] + return self.dropout(x), self.dropout(pos_emb) + + diff --git a/ernie-sat/read_text.py b/ernie-sat/read_text.py index bcf1125ef041338eff9efef254caab25311afd31..b57e60c160817f9f964a34ab25c4caa1fc9543bb 100644 --- a/ernie-sat/read_text.py +++ b/ernie-sat/read_text.py @@ -5,7 +5,7 @@ from typing import List from typing import Union -def read_2column_text(path: Union[Path, str]) -> Dict[str, str]: +def read_2col_text(path: Union[Path, str]) -> Dict[str, str]: """Read a text file having 2 column as dict object. Examples: @@ -13,7 +13,7 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]: key1 /some/path/a.wav key2 /some/path/b.wav - >>> read_2column_text('wav.scp') + >>> read_2col_text('wav.scp') {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'} """ diff --git a/ernie-sat/sedit_arg_parser.py b/ernie-sat/sedit_arg_parser.py index 01d0b47ef7831c973fe54f70c21c8f9600b53d52..7c06b649842f02c366ca96704df70878177b9b70 100644 --- a/ernie-sat/sedit_arg_parser.py +++ b/ernie-sat/sedit_arg_parser.py @@ -65,12 +65,6 @@ def parse_args(): help="mean and standard deviation used to normalize spectrogram when training voc." ) # other - parser.add_argument( - '--lang', - type=str, - default='en', - help='Choose model language. zh or en') - parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") # parser.add_argument("--test_metadata", type=str, help="test metadata.") diff --git a/ernie-sat/utils.py b/ernie-sat/utils.py index 1cd4f0083afda00d0c0af27dfad7e21cdb96a903..1b74dc014ab1b4f45015a4eabef909e7097a6b9d 100644 --- a/ernie-sat/utils.py +++ b/ernie-sat/utils.py @@ -1,5 +1,4 @@ import os -from typing import List from typing import Optional import numpy as np @@ -55,16 +54,14 @@ def build_vocoder_from_file( raise ValueError(f"{vocoder_file} is not supported format.") -def get_voc_out(mel, target_lang: str="chinese"): +def get_voc_out(mel): # vocoder args = parse_args() - assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..." 
- # print("current vocoder: ", args.voc) with open(args.voc_config) as f: voc_config = CfgNode(yaml.safe_load(f)) - voc_inference = voc_inference = get_voc_inference( + voc_inference = get_voc_inference( voc=args.voc, voc_config=voc_config, voc_ckpt=args.voc_ckpt, @@ -167,19 +164,23 @@ def get_voc_inference( return voc_inference -def evaluate_durations(phns: List[str], - target_lang: str="chinese", - fs: int=24000, - hop_length: int=300): +def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300): args = parse_args() if target_lang == 'english': - args.lang = 'en' + args.am = "fastspeech2_ljspeech" + args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml" + args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz" + args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy" + args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt" elif target_lang == 'chinese': - args.lang = 'zh' + args.am = "fastspeech2_csmsc" + args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml" + args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz" + args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy" + args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt" - # args = parser.parse_args(args=[]) if args.ngpu == 0: paddle.set_device("cpu") elif args.ngpu > 0: @@ -187,8 +188,6 @@ def evaluate_durations(phns: List[str], else: print("ngpu should >= 0 !") - assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..." - # Init body. with open(args.am_config) as f: am_config = CfgNode(yaml.safe_load(f)) @@ -203,22 +202,19 @@ def evaluate_durations(phns: List[str], speaker_dict=args.speaker_dict, return_am=True) - torch_phns = phns vocab_phones = {} with open(args.phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] for tone, id in phn_id: vocab_phones[tone] = int(id) vocab_size = len(vocab_phones) - phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns] + phonemes = [phn if phn in vocab_phones else "sp" for phn in phns] phone_ids = [vocab_phones[item] for item in phonemes] - phone_ids_new = phone_ids - phone_ids_new.append(vocab_size - 1) - phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64)) - normalized_mel, d_outs, p_outs, e_outs = am.inference( - phone_ids_new, spk_id=None, spk_emb=None) + phone_ids.append(vocab_size - 1) + phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64)) + _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None) pre_d_outs = d_outs - phoneme_durations_new = pre_d_outs * hop_length / fs - phoneme_durations_new = phoneme_durations_new.tolist()[:-1] - return phoneme_durations_new + phu_durs_new = pre_d_outs * hop_length / fs + phu_durs_new = phu_durs_new.tolist()[:-1] + return phu_durs_new