未验证 提交 30986264 编写于 作者: K Kennycao123 提交者: GitHub

Merge pull request #827 from yt605155624/format

[ernie sat]add docstring
...@@ -39,9 +39,9 @@ ERNIE-SAT 中我们提出了两项创新: ...@@ -39,9 +39,9 @@ ERNIE-SAT 中我们提出了两项创新:
### 2.预训练模型 ### 2.预训练模型
预训练模型 ERNIE-SAT 的模型如下所示: 预训练模型 ERNIE-SAT 的模型如下所示:
- [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz) - [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz) - [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz) - [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz)
创建 pretrained_model 文件夹,下载上述 ERNIE-SAT 预训练模型并将其解压: 创建 pretrained_model 文件夹,下载上述 ERNIE-SAT 预训练模型并将其解压:
...@@ -108,7 +108,7 @@ prompt/dev ...@@ -108,7 +108,7 @@ prompt/dev
3. `--voc` 声码器(vocoder)格式是否符合 {model_name}_{dataset} 3. `--voc` 声码器(vocoder)格式是否符合 {model_name}_{dataset}
4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` 是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。 4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` 是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。
5. `--lang` 对应模型的语言可以是 `zh``en` 5. `--lang` 对应模型的语言可以是 `zh``en`
6. `--ngpu` 要使用的GPU数,如果 ngpu==0,则使用 cpu。 6. `--ngpu` 要使用的 GPU 数,如果 ngpu==0,则使用 cpu。
7. ` --model_name` 模型名称 7. ` --model_name` 模型名称
8. ` --uid` 特定提示(prompt)语音的 id 8. ` --uid` 特定提示(prompt)语音的 id
9. ` --new_str` 输入的文本(本次开源暂时先设置特定的文本) 9. ` --new_str` 输入的文本(本次开源暂时先设置特定的文本)
...@@ -125,4 +125,3 @@ sh run_sedit_en.sh # 语音编辑任务(英文) ...@@ -125,4 +125,3 @@ sh run_sedit_en.sh # 语音编辑任务(英文)
sh run_gen_en.sh # 个性化语音合成任务(英文) sh run_gen_en.sh # 个性化语音合成任务(英文)
sh run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆) sh run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆)
``` ```
#!/usr/bin/env python
""" Usage: """ Usage:
align.py wavfile trsfile outwordfile outphonefile align.py wavfile trsfile outwordfile outphonefile
""" """
import multiprocessing as mp
import os import os
import sys import sys
from tqdm import tqdm
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme' PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english' MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin' MODEL_DIR_ZH = 'tools/aligner/mandarin'
...@@ -15,6 +11,142 @@ HVITE = 'tools/htk/HTKTools/HVite' ...@@ -15,6 +11,142 @@ HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy' HCOPY = 'tools/htk/HTKTools/HCopy'
def get_unk_phns(word_str: str):
tmpbase = '/tmp/tp.'
f = open(tmpbase + 'temp.words', 'w')
f.write(word_str)
f.close()
os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
'temp.phons')
f = open(tmpbase + 'temp.phons', 'r')
lines2 = f.readline().strip().split()
f.close()
phns = []
for phn in lines2:
phons = phn.replace('\n', '').replace(' ', '')
seq = []
j = 0
while (j < len(phons)):
if (phons[j] > 'Z'):
if (phons[j] == 'j'):
seq.append('JH')
elif (phons[j] == 'h'):
seq.append('HH')
else:
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
seq.append(p)
elif (p == 'AX'):
seq.append('AH0')
else:
seq.append(p + '1')
j += 2
phns.extend(seq)
return phns
def words2phns(line: str):
'''
Args:
line (str): input text.
eg: for that reason cover is impossible to be given.
Returns:
List[str]: phones of input text.
eg:
['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
'G', 'IH1', 'V', 'AH0', 'N']
Dict(str, str): key - idx_word
value - phones
eg:
{'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],
'3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
'6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
'''
dictfile = MODEL_DIR_EN + '/dict'
line = line.strip()
words = []
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
phns.extend(get_unk_phns(wrd))
else:
wrd2phns[str(index) +
"_" + wrd.upper()] = word2phns_dict[wrd.upper()].split()
phns.extend(word2phns_dict[wrd.upper()].split())
return phns, wrd2phns
def words2phns_zh(line: str):
dictfile = MODEL_DIR_ZH + '/dict'
line = line.strip()
words = []
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
print("出现非法词错误,请输入正确的文本...")
else:
wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
phns.extend(word2phns_dict[wrd].split())
return phns, wrd2phns
def prep_txt_zh(line: str, tmpbase: str, dictfile: str): def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
words = [] words = []
...@@ -82,7 +214,7 @@ def prep_txt_en(line: str, tmpbase, dictfile): ...@@ -82,7 +214,7 @@ def prep_txt_en(line: str, tmpbase, dictfile):
try: try:
os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase + os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
'_unk.phons') '_unk.phons')
except: except Exception:
print('english2phoneme error!') print('english2phoneme error!')
sys.exit(1) sys.exit(1)
...@@ -148,19 +280,22 @@ def _get_user(): ...@@ -148,19 +280,22 @@ def _get_user():
def alignment(wav_path: str, text: str): def alignment(wav_path: str, text: str):
'''
intervals: List[phn, start, end]
'''
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files #prepare wav and trs files
try: try:
os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -') os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
except: except Exception:
print('sox error!') print('sox error!')
return None return None
#prepare clean_transcript file #prepare clean_transcript file
try: try:
prep_txt_en(text, tmpbase, MODEL_DIR_EN + '/dict') prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict')
except: except Exception:
print('prep_txt error!') print('prep_txt error!')
return None return None
...@@ -169,7 +304,7 @@ def alignment(wav_path: str, text: str): ...@@ -169,7 +304,7 @@ def alignment(wav_path: str, text: str):
with open(tmpbase + '.txt', 'r') as fid: with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline() txt = fid.readline()
prep_mlf(txt, tmpbase) prep_mlf(txt, tmpbase)
except: except Exception:
print('prep_mlf error!') print('prep_mlf error!')
return None return None
...@@ -177,7 +312,7 @@ def alignment(wav_path: str, text: str): ...@@ -177,7 +312,7 @@ def alignment(wav_path: str, text: str):
try: try:
os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase + os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp') '.wav' + ' ' + tmpbase + '.plp')
except: except Exception:
print('HCopy error!') print('HCopy error!')
return None return None
...@@ -188,7 +323,7 @@ def alignment(wav_path: str, text: str): ...@@ -188,7 +323,7 @@ def alignment(wav_path: str, text: str):
+ '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase + + '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
'.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase + '.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null') '.plp 2>&1 > /dev/null')
except: except Exception:
print('HVite error!') print('HVite error!')
return None return None
...@@ -200,7 +335,7 @@ def alignment(wav_path: str, text: str): ...@@ -200,7 +335,7 @@ def alignment(wav_path: str, text: str):
with open(tmpbase + '.aligned', 'r') as fid: with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines() lines = fid.readlines()
i = 2 i = 2
times2 = [] intervals = []
word2phns = {} word2phns = {}
current_word = '' current_word = ''
index = 0 index = 0
...@@ -210,7 +345,7 @@ def alignment(wav_path: str, text: str): ...@@ -210,7 +345,7 @@ def alignment(wav_path: str, text: str):
phn = splited_line[2] phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000 pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000 pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen]) intervals.append([phn, pst, pen])
# splited_line[-1]!='sp' # splited_line[-1]!='sp'
if len(splited_line) == 5: if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1] current_word = str(index) + '_' + splited_line[-1]
...@@ -219,10 +354,10 @@ def alignment(wav_path: str, text: str): ...@@ -219,10 +354,10 @@ def alignment(wav_path: str, text: str):
elif len(splited_line) == 4: elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn word2phns[current_word] += ' ' + phn
i += 1 i += 1
return times2, word2phns return intervals, word2phns
def alignment_zh(wav_path, text_string): def alignment_zh(wav_path: str, text: str):
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid()) tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files #prepare wav and trs files
...@@ -230,18 +365,19 @@ def alignment_zh(wav_path, text_string): ...@@ -230,18 +365,19 @@ def alignment_zh(wav_path, text_string):
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase + os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
'.wav remix -') '.wav remix -')
except: except Exception:
print('sox error!') print('sox error!')
return None return None
#prepare clean_transcript file #prepare clean_transcript file
try: try:
unk_words = prep_txt_zh(text_string, tmpbase, MODEL_DIR_ZH + '/dict') unk_words = prep_txt_zh(
line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict')
if unk_words: if unk_words:
print('Error! Please add the following words to dictionary:') print('Error! Please add the following words to dictionary:')
for unk in unk_words: for unk in unk_words:
print("非法words: ", unk) print("非法words: ", unk)
except: except Exception:
print('prep_txt error!') print('prep_txt error!')
return None return None
...@@ -250,7 +386,7 @@ def alignment_zh(wav_path, text_string): ...@@ -250,7 +386,7 @@ def alignment_zh(wav_path, text_string):
with open(tmpbase + '.txt', 'r') as fid: with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline() txt = fid.readline()
prep_mlf(txt, tmpbase) prep_mlf(txt, tmpbase)
except: except Exception:
print('prep_mlf error!') print('prep_mlf error!')
return None return None
...@@ -258,7 +394,7 @@ def alignment_zh(wav_path, text_string): ...@@ -258,7 +394,7 @@ def alignment_zh(wav_path, text_string):
try: try:
os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase + os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp') '.wav' + ' ' + tmpbase + '.plp')
except: except Exception:
print('HCopy error!') print('HCopy error!')
return None return None
...@@ -270,7 +406,7 @@ def alignment_zh(wav_path, text_string): ...@@ -270,7 +406,7 @@ def alignment_zh(wav_path, text_string):
+ '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase + + '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null') '.plp 2>&1 > /dev/null')
except: except Exception:
print('HVite error!') print('HVite error!')
return None return None
...@@ -283,7 +419,7 @@ def alignment_zh(wav_path, text_string): ...@@ -283,7 +419,7 @@ def alignment_zh(wav_path, text_string):
lines = fid.readlines() lines = fid.readlines()
i = 2 i = 2
times2 = [] intervals = []
word2phns = {} word2phns = {}
current_word = '' current_word = ''
index = 0 index = 0
...@@ -293,7 +429,7 @@ def alignment_zh(wav_path, text_string): ...@@ -293,7 +429,7 @@ def alignment_zh(wav_path, text_string):
phn = splited_line[2] phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000 pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000 pen = (int(splited_line[1]) / 1000 + 125) / 10000
times2.append([phn, pst, pen]) intervals.append([phn, pst, pen])
# splited_line[-1]!='sp' # splited_line[-1]!='sp'
if len(splited_line) == 5: if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1] current_word = str(index) + '_' + splited_line[-1]
...@@ -302,4 +438,4 @@ def alignment_zh(wav_path, text_string): ...@@ -302,4 +438,4 @@ def alignment_zh(wav_path, text_string):
elif len(splited_line) == 4: elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn word2phns[current_word] += ' ' + phn
i += 1 i += 1
return times2, word2phns return intervals, word2phns
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import paddle
from dataset import get_seg_pos
from dataset import phones_masking
from dataset import phones_text_masking
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list
class MLMCollateFn:
"""Functor class of common_collate_fn()"""
def __init__(self,
feats_extract,
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
attention_window: int=0,
pad_speech: bool=False,
seg_emb: bool=False,
text_masking: bool=False):
self.mlm_prob = mlm_prob
self.mean_phn_span = mean_phn_span
self.feats_extract = feats_extract
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
self.attention_window = attention_window
self.pad_speech = pad_speech
self.seg_emb = seg_emb
self.text_masking = text_masking
def __repr__(self):
return (f"{self.__class__}(float_pad_value={self.float_pad_value}, "
f"int_pad_value={self.float_pad_value})")
def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
return mlm_collate_fn(
data,
float_pad_value=self.float_pad_value,
int_pad_value=self.int_pad_value,
not_sequence=self.not_sequence,
mlm_prob=self.mlm_prob,
mean_phn_span=self.mean_phn_span,
feats_extract=self.feats_extract,
attention_window=self.attention_window,
pad_speech=self.pad_speech,
seg_emb=self.seg_emb,
text_masking=self.text_masking)
def mlm_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
float_pad_value: Union[float, int]=0.0,
int_pad_value: int=-32768,
not_sequence: Collection[str]=(),
mlm_prob: float=0.8,
mean_phn_span: int=8,
feats_extract=None,
attention_window: int=0,
pad_speech: bool=False,
seg_emb: bool=False,
text_masking: bool=False) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(not k.endswith("_lens")
for k in data[0]), f"*_lens is reserved: {list(data[0])}"
output = {}
for key in data[0]:
# Each models, which accepts these values finally, are responsible
# to repaint the pad_value to the desired value for each tasks.
if data[0][key].dtype.kind == "i":
pad_value = int_pad_value
else:
pad_value = float_pad_value
array_list = [d[key] for d in data]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list = [paddle.to_tensor(a) for a in array_list]
# tensor: (Batch, Length, ...)
tensor = pad_list(tensor_list, pad_value)
output[key] = tensor
# lens: (Batch,)
if key not in not_sequence:
lens = paddle.to_tensor(
[d[key].shape[0] for d in data], dtype=paddle.int64)
output[key + "_lens"] = lens
feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
feats = paddle.to_tensor(feats)
feats_lens = paddle.shape(feats)[0]
feats = paddle.unsqueeze(feats, 0)
text = output["text"]
text_lens = output["text_lens"]
align_start = output["align_start"]
align_start_lens = output["align_start_lens"]
align_end = output["align_end"]
max_tlen = max(text_lens)
max_slen = max(feats_lens)
speech_pad = feats[:, :max_slen]
text_pad = text
text_mask = make_non_pad_mask(
text_lens, text_pad, length_dim=1).unsqueeze(-2)
speech_mask = make_non_pad_mask(
feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
span_bdy = None
if 'span_bdy' in output.keys():
span_bdy = output['span_bdy']
# dual_mask 的是混合中英时候同时 mask 语音和文本
# ernie sat 在实现跨语言的时候都 mask 了
if text_masking:
masked_pos, text_masked_pos = phones_text_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
text_pad=text_pad,
text_mask=text_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
# 训练纯中文和纯英文的 -> a3t 没有对 phoneme 做 mask, 只对语音 mask 了
# a3t 和 ernie sat 的区别主要在于做 mask 的时候
else:
masked_pos = phones_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
text_masked_pos = paddle.zeros(paddle.shape(text_pad))
output_dict = {}
speech_seg_pos, text_seg_pos = get_seg_pos(
speech_pad=speech_pad,
text_pad=text_pad,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
seg_emb=seg_emb)
output_dict['speech'] = speech_pad
output_dict['text'] = text_pad
output_dict['masked_pos'] = masked_pos
output_dict['text_masked_pos'] = text_masked_pos
output_dict['speech_mask'] = speech_mask
output_dict['text_mask'] = text_mask
output_dict['speech_seg_pos'] = speech_seg_pos
output_dict['text_seg_pos'] = text_seg_pos
output = (uttids, output_dict)
return output
def build_collate_fn(
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
n_mels: int=80,
fmin: int=80,
fmax: int=7600,
mlm_prob: float=0.8,
mean_phn_span: int=8,
train: bool=False,
seg_emb: bool=False,
epoch: int=-1, ):
feats_extract_class = LogMelFBank
feats_extract = feats_extract_class(
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mels=n_mels,
fmin=fmin,
fmax=fmax)
pad_speech = False
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
return MLMCollateFn(
feats_extract=feats_extract,
float_pad_value=0.0,
int_pad_value=0,
mlm_prob=mlm_prob * mlm_prob_factor,
mean_phn_span=mean_phn_span,
pad_speech=pad_speech,
seg_emb=seg_emb)
...@@ -4,6 +4,68 @@ import numpy as np ...@@ -4,6 +4,68 @@ import numpy as np
import paddle import paddle
# mask phones
def phones_masking(xs_pad: paddle.Tensor,
src_mask: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
mlm_prob: float=0.8,
mean_phn_span: int=8,
span_bdy: paddle.Tensor=None):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
mlm_prob (float):
mean_phn_span (int):
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
'''
bz, sent_len, _ = paddle.shape(xs_pad)
masked_pos = paddle.zeros((bz, sent_len))
if mlm_prob == 1.0:
masked_pos += 1
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_idxs = random_spans_noise_mask(
length=length, mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
masked_pos[:, masked_phn_idxs] = 1
else:
for idx in range(bz):
# for inference
if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1
# for training
else:
length = align_start_lens[idx]
if length < 2:
continue
masked_phn_idxs = random_spans_noise_mask(
length=length,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
masked_start = align_start[idx][masked_phn_idxs].tolist()
masked_end = align_end[idx][masked_phn_idxs].tolist()
for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
masked_pos = paddle.cast(masked_pos, 'bool')
return masked_pos
# mask speech and phones
def phones_text_masking(xs_pad: paddle.Tensor, def phones_text_masking(xs_pad: paddle.Tensor,
src_mask: paddle.Tensor, src_mask: paddle.Tensor,
text_pad: paddle.Tensor, text_pad: paddle.Tensor,
...@@ -11,37 +73,56 @@ def phones_text_masking(xs_pad: paddle.Tensor, ...@@ -11,37 +73,56 @@ def phones_text_masking(xs_pad: paddle.Tensor,
align_start: paddle.Tensor, align_start: paddle.Tensor,
align_end: paddle.Tensor, align_end: paddle.Tensor,
align_start_lens: paddle.Tensor, align_start_lens: paddle.Tensor,
mlm_prob: float, mlm_prob: float=0.8,
mean_phn_span: float, mean_phn_span: int=8,
span_bdy: paddle.Tensor=None): span_bdy: paddle.Tensor=None):
'''
Args:
xs_pad (paddle.Tensor): input speech (B, Tmax, D).
src_mask (paddle.Tensor): mask of speech (B, 1, Tmax).
text_pad (paddle.Tensor): input text (B, Tmax2).
text_mask (paddle.Tensor): mask of text (B, 1, Tmax2).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
mlm_prob (float):
mean_phn_span (int):
span_bdy (paddle.Tensor): masked mel boundary of input speech (B, 2).
Returns:
paddle.Tensor[bool]: masked position of input speech (B, Tmax).
paddle.Tensor[bool]: masked position of input text (B, Tmax2).
'''
bz, sent_len, _ = paddle.shape(xs_pad) bz, sent_len, _ = paddle.shape(xs_pad)
masked_pos = paddle.zeros((bz, sent_len)) masked_pos = paddle.zeros((bz, sent_len))
_, text_len = paddle.shape(text_pad) _, text_len = paddle.shape(text_pad)
text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5) text_mask_num_lower = math.ceil(text_len * (1 - mlm_prob) * 0.5)
text_masked_pos = paddle.zeros((bz, text_len)) text_masked_pos = paddle.zeros((bz, text_len))
y_masks = None
if mlm_prob == 1.0: if mlm_prob == 1.0:
masked_pos += 1 masked_pos += 1
# y_masks = tril_masks
elif mean_phn_span == 0: elif mean_phn_span == 0:
# only speech # only speech
length = sent_len length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50) mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_idxs = random_spans_noise_mask(length, mlm_prob, masked_phn_idxs = random_spans_noise_mask(
mean_phn_span).nonzero() length=length, mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
masked_pos[:, masked_phn_idxs] = 1 masked_pos[:, masked_phn_idxs] = 1
else: else:
for idx in range(bz): for idx in range(bz):
# for inference
if span_bdy is not None: if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]): for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1 masked_pos[idx, s:e] = 1
# for training
else: else:
length = align_start_lens[idx] length = align_start_lens[idx]
if length < 2: if length < 2:
continue continue
masked_phn_idxs = random_spans_noise_mask( masked_phn_idxs = random_spans_noise_mask(
length, mlm_prob, mean_phn_span).nonzero() length=length,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span).nonzero()
unmasked_phn_idxs = list( unmasked_phn_idxs = list(
set(range(length)) - set(masked_phn_idxs[0].tolist())) set(range(length)) - set(masked_phn_idxs[0].tolist()))
np.random.shuffle(unmasked_phn_idxs) np.random.shuffle(unmasked_phn_idxs)
...@@ -58,60 +139,76 @@ def phones_text_masking(xs_pad: paddle.Tensor, ...@@ -58,60 +139,76 @@ def phones_text_masking(xs_pad: paddle.Tensor,
masked_pos = paddle.cast(masked_pos, 'bool') masked_pos = paddle.cast(masked_pos, 'bool')
text_masked_pos = paddle.cast(text_masked_pos, 'bool') text_masked_pos = paddle.cast(text_masked_pos, 'bool')
return masked_pos, text_masked_pos, y_masks return masked_pos, text_masked_pos
def get_seg_pos_reduce_duration( def get_seg_pos(speech_pad: paddle.Tensor,
speech_pad: paddle.Tensor,
text_pad: paddle.Tensor, text_pad: paddle.Tensor,
align_start: paddle.Tensor, align_start: paddle.Tensor,
align_end: paddle.Tensor, align_end: paddle.Tensor,
align_start_lens: paddle.Tensor, align_start_lens: paddle.Tensor,
sega_emb: bool, seg_emb: bool=False):
masked_pos: paddle.Tensor, '''
feats_lens: paddle.Tensor, ): Args:
speech_pad (paddle.Tensor): input speech (B, Tmax, D).
text_pad (paddle.Tensor): input text (B, Tmax2).
align_start (paddle.Tensor): frame level phone alignment start (B, Tmax2).
align_end (paddle.Tensor): frame level phone alignment end (B, Tmax2).
align_start_lens (paddle.Tensor): length of align_start (B, ).
seg_emb (bool): whether to use segment embedding.
Returns:
paddle.Tensor[int]: n-th phone of each mel, 0<=n<=Tmax2 (B, Tmax).
eg:
Tensor(shape=[1, 328], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 4 , 4 ,
5 , 5 , 5 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 6 , 7 , 7 , 7 , 7 , 7 , 7 , 7 ,
7 , 8 , 8 , 8 , 8 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 9 , 10, 10, 10, 10, 10,
10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13,
13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15,
15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17,
17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20,
20, 20, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 23, 23,
23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25,
25, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29,
29, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 32,
32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 35, 35,
35, 35, 35, 35, 35, 35, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38,
38, 38, 0 , 0 ]])
paddle.Tensor[int]: n-th phone of each phone, 0<=n<=Tmax2 (B, Tmax2).
eg:
Tensor(shape=[1, 38], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[[1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38]])
'''
bz, speech_len, _ = paddle.shape(speech_pad) bz, speech_len, _ = paddle.shape(speech_pad)
text_seg_pos = paddle.zeros(paddle.shape(text_pad)) _, text_len = paddle.shape(text_pad)
speech_seg_pos = paddle.zeros((bz, speech_len), dtype=text_pad.dtype)
reordered_idx = paddle.zeros((bz, speech_len), dtype=align_start_lens.dtype) text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')
durations = paddle.ones((bz, speech_len), dtype=align_start_lens.dtype) if not seg_emb:
max_reduced_length = 0 return speech_seg_pos, text_seg_pos
if not sega_emb:
return speech_pad, masked_pos, speech_seg_pos, text_seg_pos, durations
for idx in range(bz): for idx in range(bz):
first_idx = []
last_idx = []
align_length = align_start_lens[idx] align_length = align_start_lens[idx]
for j in range(align_length): for j in range(align_length):
s, e = align_start[idx][j], align_end[idx][j] s, e = align_start[idx][j], align_end[idx][j]
if j == 0: speech_seg_pos[idx, s:e] = j + 1
if paddle.sum(masked_pos[idx][0:s]) == 0: text_seg_pos[idx, j] = j + 1
first_idx.extend(range(0, s))
else:
first_idx.extend([0])
last_idx.extend(range(1, s))
if paddle.sum(masked_pos[idx][s:e]) == 0:
first_idx.extend(range(s, e))
else:
first_idx.extend([s])
last_idx.extend(range(s + 1, e))
durations[idx][s] = e - s
speech_seg_pos[idx][s:e] = j + 1
text_seg_pos[idx][j] = j + 1
max_reduced_length = max(
len(first_idx) + feats_lens[idx] - e, max_reduced_length)
first_idx.extend(range(e, speech_len))
reordered_idx[idx] = paddle.to_tensor(
(first_idx + last_idx), dtype=align_start_lens.dtype)
feats_lens[idx] = len(first_idx)
reordered_idx = reordered_idx[:, :max_reduced_length]
return reordered_idx, speech_seg_pos, text_seg_pos, durations, feats_lens return speech_seg_pos, text_seg_pos
def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float): # randomly select the range of speech and text to mask during training
def random_spans_noise_mask(length: int,
mlm_prob: float=0.8,
mean_phn_span: float=8):
"""This function is copy of `random_spans_helper """This function is copy of `random_spans_helper
<https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ . <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens. Noise mask consisting of random spans of noise tokens.
...@@ -126,7 +223,7 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float): ...@@ -126,7 +223,7 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
noise_density: a float - approximate density of output mask noise_density: a float - approximate density of output mask
mean_noise_span_length: a number mean_noise_span_length: a number
Returns: Returns:
a boolean tensor with shape [length] np.ndarray: a boolean tensor with shape [length]
""" """
orig_length = length orig_length = length
...@@ -171,87 +268,3 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float): ...@@ -171,87 +268,3 @@ def random_spans_noise_mask(length: int, mlm_prob: float, mean_phn_span: float):
is_noise = np.equal(span_num % 2, 1) is_noise = np.equal(span_num % 2, 1)
return is_noise[:orig_length] return is_noise[:orig_length]
def pad_to_longformer_att_window(text: paddle.Tensor,
max_len: int,
max_tlen: int,
attention_window: int=0):
round = max_len % attention_window
if round != 0:
max_tlen += (attention_window - round)
n_batch = paddle.shape(text)[0]
text_pad = paddle.zeros(
(n_batch, max_tlen, *paddle.shape(text[0])[1:]), dtype=text.dtype)
for i in range(n_batch):
text_pad[i, :paddle.shape(text[i])[0]] = text[i]
else:
text_pad = text[:, :max_tlen]
return text_pad, max_tlen
def phones_masking(xs_pad: paddle.Tensor,
src_mask: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
mlm_prob: float,
mean_phn_span: int,
span_bdy: paddle.Tensor=None):
bz, sent_len, _ = paddle.shape(xs_pad)
masked_pos = paddle.zeros((bz, sent_len))
y_masks = None
if mlm_prob == 1.0:
masked_pos += 1
elif mean_phn_span == 0:
# only speech
length = sent_len
mean_phn_span = min(length * mlm_prob // 3, 50)
masked_phn_idxs = random_spans_noise_mask(length, mlm_prob,
mean_phn_span).nonzero()
masked_pos[:, masked_phn_idxs] = 1
else:
for idx in range(bz):
if span_bdy is not None:
for s, e in zip(span_bdy[idx][::2], span_bdy[idx][1::2]):
masked_pos[idx, s:e] = 1
else:
length = align_start_lens[idx]
if length < 2:
continue
masked_phn_idxs = random_spans_noise_mask(
length, mlm_prob, mean_phn_span).nonzero()
masked_start = align_start[idx][masked_phn_idxs].tolist()
masked_end = align_end[idx][masked_phn_idxs].tolist()
for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask
masked_pos = paddle.cast(masked_pos, 'bool')
return masked_pos, y_masks
def get_seg_pos(speech_pad: paddle.Tensor,
text_pad: paddle.Tensor,
align_start: paddle.Tensor,
align_end: paddle.Tensor,
align_start_lens: paddle.Tensor,
sega_emb: bool):
bz, speech_len, _ = paddle.shape(speech_pad)
_, text_len = paddle.shape(text_pad)
text_seg_pos = paddle.zeros((bz, text_len), dtype='int64')
speech_seg_pos = paddle.zeros((bz, speech_len), dtype='int64')
if not sega_emb:
return speech_seg_pos, text_seg_pos
for idx in range(bz):
align_length = align_start_lens[idx]
for j in range(align_length):
s, e = align_start[idx][j], align_end[idx][j]
speech_seg_pos[idx, s:e] = j + 1
text_seg_pos[idx, j] = j + 1
return speech_seg_pos, text_seg_pos
此差异已折叠。
import paddle
from paddle import nn
class MLMLoss(nn.Layer):
def __init__(self,
lsm_weight: float=0.1,
ignore_id: int=-1,
text_masking: bool=False):
super().__init__()
if text_masking:
self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
if lsm_weight > 50:
self.l1_loss_func = nn.MSELoss()
else:
self.l1_loss_func = nn.L1Loss(reduction='none')
self.text_masking = text_masking
def forward(self,
speech: paddle.Tensor,
before_outs: paddle.Tensor,
after_outs: paddle.Tensor,
masked_pos: paddle.Tensor,
text: paddle.Tensor=None,
text_outs: paddle.Tensor=None,
text_masked_pos: paddle.Tensor=None):
xs_pad = speech
mlm_loss_pos = masked_pos > 0
loss = paddle.sum(
self.l1_loss_func(
paddle.reshape(before_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
if after_outs is not None:
loss += paddle.sum(
self.l1_loss_func(
paddle.reshape(after_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
loss_mlm = paddle.sum((loss * paddle.reshape(
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
if self.text_masking:
loss_text = paddle.sum((self.text_mlm_loss(
paddle.reshape(text_outs, (-1, self.vocab_size)),
paddle.reshape(text, (-1))) * paddle.reshape(
text_masked_pos,
(-1)))) / paddle.sum((text_masked_pos) + 1e-10)
return loss_mlm, loss_text
return loss_mlm
...@@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask) return self.forward_attention(v, scores, mask)
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
self.pos_bias_v = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x(Tensor): Input tensor (batch, head, time1, time2).
Returns:
Tensor:Output tensor.
"""
b, h, t1, t2 = paddle.shape(x)
zero_pad = paddle.zeros((b, h, t1, 1))
x_padded = paddle.concat([zero_pad, x], axis=-1)
x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
# only keep the positions from 0 to time2
x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query(Tensor): Query tensor (#batch, time1, size).
key(Tensor): Key tensor (#batch, time2, size).
value(Tensor): Value tensor (#batch, time2, size).
pos_emb(Tensor): Positional embedding tensor (#batch, time1, size).
mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
Returns:
Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
# (batch, time1, head, d_k)
q = paddle.transpose(q, [0, 2, 1, 3])
n_batch_pos = paddle.shape(pos_emb)[0]
p = paddle.reshape(
self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
# (batch, head, time1, d_k)
p = paddle.transpose(p, [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3])
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = paddle.matmul(q_with_bias_u,
paddle.transpose(k, [0, 1, 3, 2]))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = paddle.matmul(q_with_bias_v,
paddle.transpose(p, [0, 1, 3, 2]))
matrix_bd = self.rel_shift(matrix_bd)
# (batch, head, time1, time2)
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
...@@ -185,3 +185,63 @@ class RelPositionalEncoding(nn.Layer): ...@@ -185,3 +185,63 @@ class RelPositionalEncoding(nn.Layer):
pe_size = paddle.shape(self.pe) pe_size = paddle.shape(self.pe)
pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ] pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ]
return self.dropout(x), self.dropout(pos_emb) return self.dropout(x), self.dropout(pos_emb)
class LegacyRelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
"""
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int, optional): [Maximum input length.]. Defaults to 5000.
"""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
return
pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
if self.reverse:
position = paddle.arange(
paddle.shape(x)[1] - 1, -1, -1.0,
dtype=paddle.float32).unsqueeze(1)
else:
position = paddle.arange(
0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe
def forward(self, x: paddle.Tensor):
"""Compute positional encoding.
Args:
x (paddle.Tensor): Input tensor (batch, time, `*`).
Returns:
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[:, :paddle.shape(x)[1]]
return self.dropout(x), self.dropout(pos_emb)
...@@ -5,7 +5,7 @@ from typing import List ...@@ -5,7 +5,7 @@ from typing import List
from typing import Union from typing import Union
def read_2column_text(path: Union[Path, str]) -> Dict[str, str]: def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
"""Read a text file having 2 column as dict object. """Read a text file having 2 column as dict object.
Examples: Examples:
...@@ -13,7 +13,7 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]: ...@@ -13,7 +13,7 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
key1 /some/path/a.wav key1 /some/path/a.wav
key2 /some/path/b.wav key2 /some/path/b.wav
>>> read_2column_text('wav.scp') >>> read_2col_text('wav.scp')
{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'} {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
""" """
......
...@@ -65,12 +65,6 @@ def parse_args(): ...@@ -65,12 +65,6 @@ def parse_args():
help="mean and standard deviation used to normalize spectrogram when training voc." help="mean and standard deviation used to normalize spectrogram when training voc."
) )
# other # other
parser.add_argument(
'--lang',
type=str,
default='en',
help='Choose model language. zh or en')
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
# parser.add_argument("--test_metadata", type=str, help="test metadata.") # parser.add_argument("--test_metadata", type=str, help="test metadata.")
......
import os import os
from typing import List
from typing import Optional from typing import Optional
import numpy as np import numpy as np
...@@ -55,16 +54,14 @@ def build_vocoder_from_file( ...@@ -55,16 +54,14 @@ def build_vocoder_from_file(
raise ValueError(f"{vocoder_file} is not supported format.") raise ValueError(f"{vocoder_file} is not supported format.")
def get_voc_out(mel, target_lang: str="chinese"): def get_voc_out(mel):
# vocoder # vocoder
args = parse_args() args = parse_args()
assert target_lang == "chinese" or target_lang == "english", "In get_voc_out function, target_lang is illegal..."
# print("current vocoder: ", args.voc) # print("current vocoder: ", args.voc)
with open(args.voc_config) as f: with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f)) voc_config = CfgNode(yaml.safe_load(f))
voc_inference = voc_inference = get_voc_inference( voc_inference = get_voc_inference(
voc=args.voc, voc=args.voc,
voc_config=voc_config, voc_config=voc_config,
voc_ckpt=args.voc_ckpt, voc_ckpt=args.voc_ckpt,
...@@ -167,19 +164,23 @@ def get_voc_inference( ...@@ -167,19 +164,23 @@ def get_voc_inference(
return voc_inference return voc_inference
def evaluate_durations(phns: List[str], def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
target_lang: str="chinese",
fs: int=24000,
hop_length: int=300):
args = parse_args() args = parse_args()
if target_lang == 'english': if target_lang == 'english':
args.lang = 'en' args.am = "fastspeech2_ljspeech"
args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_lang == 'chinese': elif target_lang == 'chinese':
args.lang = 'zh' args.am = "fastspeech2_csmsc"
args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
# args = parser.parse_args(args=[])
if args.ngpu == 0: if args.ngpu == 0:
paddle.set_device("cpu") paddle.set_device("cpu")
elif args.ngpu > 0: elif args.ngpu > 0:
...@@ -187,8 +188,6 @@ def evaluate_durations(phns: List[str], ...@@ -187,8 +188,6 @@ def evaluate_durations(phns: List[str],
else: else:
print("ngpu should >= 0 !") print("ngpu should >= 0 !")
assert target_lang == "chinese" or target_lang == "english", "In evaluate_durations function, target_lang is illegal..."
# Init body. # Init body.
with open(args.am_config) as f: with open(args.am_config) as f:
am_config = CfgNode(yaml.safe_load(f)) am_config = CfgNode(yaml.safe_load(f))
...@@ -203,22 +202,19 @@ def evaluate_durations(phns: List[str], ...@@ -203,22 +202,19 @@ def evaluate_durations(phns: List[str],
speaker_dict=args.speaker_dict, speaker_dict=args.speaker_dict,
return_am=True) return_am=True)
torch_phns = phns
vocab_phones = {} vocab_phones = {}
with open(args.phones_dict, "r") as f: with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
for tone, id in phn_id: for tone, id in phn_id:
vocab_phones[tone] = int(id) vocab_phones[tone] = int(id)
vocab_size = len(vocab_phones) vocab_size = len(vocab_phones)
phonemes = [phn if phn in vocab_phones else "sp" for phn in torch_phns] phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
phone_ids = [vocab_phones[item] for item in phonemes] phone_ids = [vocab_phones[item] for item in phonemes]
phone_ids_new = phone_ids phone_ids.append(vocab_size - 1)
phone_ids_new.append(vocab_size - 1) phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
phone_ids_new = paddle.to_tensor(np.array(phone_ids_new, np.int64)) _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
normalized_mel, d_outs, p_outs, e_outs = am.inference(
phone_ids_new, spk_id=None, spk_emb=None)
pre_d_outs = d_outs pre_d_outs = d_outs
phoneme_durations_new = pre_d_outs * hop_length / fs phu_durs_new = pre_d_outs * hop_length / fs
phoneme_durations_new = phoneme_durations_new.tolist()[:-1] phu_durs_new = phu_durs_new.tolist()[:-1]
return phoneme_durations_new return phu_durs_new
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册