提交 79658a5f 编写于 作者: 小湉湉's avatar 小湉湉

add ernie sat inference, test=tts

上级 803800f9
# ERNIE SAT with AISHELL3 dataset
# Mixed Chinese and English TTS with AISHELL3 and VCTK datasets
# ERNIE SAT with AISHELL3 and VCTK dataset
ERNIE-SAT 是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。
## 模型框架
ERNIE-SAT 中我们提出了两项创新:
- 在预训练过程中将中英双语对应的音素作为输入,实现了跨语言、个性化的软音素映射
- 采用语言和语音的联合掩码学习实现了语言和语音的对齐
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3lOXKJXE-1655380879339)(.meta/framework.png)]
## 使用说明
### 1.安装飞桨与环境依赖
- 本项目的代码基于 Paddle(version>=2.0)
- 本项目开放提供加载 torch 版本的 vocoder 的功能
- torch version>=1.8
- 安装 htk: 在[官方地址](https://htk.eng.cam.ac.uk/)注册完成后,即可进行下载较新版本的 htk (例如 3.4.1)。同时提供[历史版本 htk 下载地址](https://htk.eng.cam.ac.uk/ftp/software/)
- 1.注册账号,下载 htk
- 2.解压 htk 文件,**放入项目根目录的 tools 文件夹中, 以 htk 文件夹名称放入**
- 3.**注意**: 如果您下载的是 3.4.1 或者更高版本, 需要进入 HTKLib/HRec.c 文件中, **修改 1626 行和 1650 行**, 即把**以下两行的 dur<=0 都修改为 dur<0**,如下所示:
```bash
以htk3.4.1版本举例:
(1)第1626行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0");
修改为: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0");
(2)1650行: if (dur<=0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<=0 ");
修改为: if (dur<0 && labid != splabid) HError(8522,"LatFromPaths: Align have dur<0 ");
```
- 4.**编译**: 详情参见解压后的 htk 中的 README 文件(如果未编译, 则无法正常运行)
- 安装 ParallelWaveGAN: 参见[官方地址](https://github.com/kan-bayashi/ParallelWaveGAN):按照该官方链接的安装流程,直接在**项目的根目录下** git clone ParallelWaveGAN 项目并且安装相关依赖即可。
- 安装其他依赖: **sox, libsndfile**
### 2.预训练模型
预训练模型 ERNIE-SAT 的模型如下所示:
- [ERNIE-SAT_ZH](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/model-ernie-sat-base-en_zh.tar.gz)
创建 pretrained_model 文件夹,下载上述 ERNIE-SAT 预训练模型并将其解压:
```bash
mkdir pretrained_model
cd pretrained_model
tar -zxvf model-ernie-sat-base-en.tar.gz
tar -zxvf model-ernie-sat-base-zh.tar.gz
tar -zxvf model-ernie-sat-base-en_zh.tar.gz
```
### 3.下载
1. 本项目使用 parallel wavegan 作为声码器(vocoder):
- [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)
创建 download 文件夹,下载上述预训练的声码器(vocoder)模型并将其解压:
```bash
mkdir download
cd download
unzip pwg_aishell3_ckpt_0.5.zip
```
2. 本项目使用 [FastSpeech2](https://arxiv.org/abs/2006.04558) 作为音素(phoneme)的持续时间预测器:
- [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) 中文场景下使用
- [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) 英文场景下使用
下载上述预训练的 fastspeech2 模型并将其解压:
```bash
cd download
unzip fastspeech2_conformer_baker_ckpt_0.5.zip
unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
```
3. 本项目使用 HTK 获取输入音频和文本的对齐信息:
- [aligner.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ernie_sat/old/aligner.zip)
下载上述文件到 tools 文件夹并将其解压:
```bash
cd tools
unzip aligner.zip
```
### 4.推理
本项目当前开源了语音编辑、个性化语音合成、跨语言语音合成的推理代码,后续会逐步开源。
注:当前英文场下的合成语音采用的声码器默认为 vctk_parallel_wavegan.v1.long, 可在[该链接](https://github.com/kan-bayashi/ParallelWaveGAN)中找到; 若 use_pt_vocoder 参数设置为 False,则英文场景下使用 paddle 版本的声码器。
我们提供特定音频文件, 以及其对应的文本、音素相关文件:
- prompt_wav: 提供的音频文件
- prompt/dev: 基于上述特定音频对应的文本、音素相关文件
```text
prompt_wav
├── p299_096.wav # 样例语音文件1
├── p243_313.wav # 样例语音文件2
└── ...
```
```text
prompt/dev
├── text # 样例语音对应文本
├── wav.scp # 样例语音路径
├── mfa_text # 样例语音对应音素
├── mfa_start # 样例语音中各个音素的开始时间
└── mfa_end # 样例语音中各个音素的结束时间
```
1. `--am` 声学模型格式符合 {model_name}_{dataset}
2. `--am_config`, `--am_checkpoint`, `--am_stat``--phones_dict` 是声学模型的参数,对应于 fastspeech2 预训练模型中的 4 个文件。
3. `--voc` 声码器(vocoder)格式是否符合 {model_name}_{dataset}
4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` 是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。
5. `--lang` 对应模型的语言可以是 `zh``en`
6. `--ngpu` 要使用的 GPU 数,如果 ngpu==0,则使用 cpu。
7. `--model_name` 模型名称
8. `--uid` 特定提示(prompt)语音的 id
9. `--new_str` 输入的文本(本次开源暂时先设置特定的文本)
10. `--prefix` 特定音频对应的文本、音素相关文件的地址
11. `--source_lang` , 源语言
12. `--target_lang` , 目标语言
13. `--output_name` , 合成语音名称
14. `--task_name` , 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务
运行以下脚本即可进行实验
```shell
./run_sedit_en.sh # 语音编辑任务(英文)
./run_gen_en.sh # 个性化语音合成任务(英文)
./run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆)
```
""" Usage:
align.py wavfile trsfile outwordfile outphonefile
"""
import os
import sys
PHONEME = 'tools/aligner/english_envir/english2phoneme/phoneme'
MODEL_DIR_EN = 'tools/aligner/english'
MODEL_DIR_ZH = 'tools/aligner/mandarin'
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'
def get_unk_phns(word_str: str):
tmpbase = '/tmp/tp.'
f = open(tmpbase + 'temp.words', 'w')
f.write(word_str)
f.close()
os.system(PHONEME + ' ' + tmpbase + 'temp.words' + ' ' + tmpbase +
'temp.phons')
f = open(tmpbase + 'temp.phons', 'r')
lines2 = f.readline().strip().split()
f.close()
phns = []
for phn in lines2:
phons = phn.replace('\n', '').replace(' ', '')
seq = []
j = 0
while (j < len(phons)):
if (phons[j] > 'Z'):
if (phons[j] == 'j'):
seq.append('JH')
elif (phons[j] == 'h'):
seq.append('HH')
else:
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
seq.append(p)
elif (p == 'AX'):
seq.append('AH0')
else:
seq.append(p + '1')
j += 2
phns.extend(seq)
return phns
def words2phns(line: str):
'''
Args:
line (str): input text.
eg: for that reason cover is impossible to be given.
Returns:
List[str]: phones of input text.
eg:
['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
'G', 'IH1', 'V', 'AH0', 'N']
Dict(str, str): key - idx_word
value - phones
eg:
{'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],
'3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
'6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
'''
dictfile = MODEL_DIR_EN + '/dict'
line = line.strip()
words = []
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
wrd2phns[str(index) + "_" + wrd.upper()] = get_unk_phns(wrd)
phns.extend(get_unk_phns(wrd))
else:
wrd2phns[str(index) +
"_" + wrd.upper()] = word2phns_dict[wrd.upper()].split()
phns.extend(word2phns_dict[wrd.upper()].split())
return phns, wrd2phns
def words2phns_zh(line: str):
dictfile = MODEL_DIR_ZH + '/dict'
line = line.strip()
words = []
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
word2phns_dict = {}
with open(dictfile, 'r') as fid:
for line in fid:
word = line.split()[0]
ds.add(word)
if word not in word2phns_dict.keys():
word2phns_dict[word] = " ".join(line.split()[1:])
phns = []
wrd2phns = {}
for index, wrd in enumerate(words):
if wrd == '[MASK]':
wrd2phns[str(index) + "_" + wrd] = [wrd]
phns.append(wrd)
elif (wrd.upper() not in ds):
print("出现非法词错误,请输入正确的文本...")
else:
wrd2phns[str(index) + "_" + wrd] = word2phns_dict[wrd].split()
phns.extend(word2phns_dict[wrd].split())
return phns, wrd2phns
def prep_txt_zh(line: str, tmpbase: str, dictfile: str):
words = []
line = line.strip()
for pun in [
',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
u'。', u':', u';', u'!', u'?', u'(', u')'
]:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
with open(dictfile, 'r') as fid:
for line in fid:
ds.add(line.split()[0])
unk_words = set([])
with open(tmpbase + '.txt', 'w') as fwid:
for wrd in words:
if (wrd not in ds):
unk_words.add(wrd)
fwid.write(wrd + ' ')
fwid.write('\n')
return unk_words
def prep_txt_en(line: str, tmpbase, dictfile):
words = []
line = line.strip()
for pun in [',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---']:
line = line.replace(pun, ' ')
for wrd in line.split():
if (wrd[-1] == '-'):
wrd = wrd[:-1]
if (wrd[0] == "'"):
wrd = wrd[1:]
if wrd:
words.append(wrd)
ds = set([])
with open(dictfile, 'r') as fid:
for line in fid:
ds.add(line.split()[0])
unk_words = set([])
with open(tmpbase + '.txt', 'w') as fwid:
for wrd in words:
if (wrd.upper() not in ds):
unk_words.add(wrd.upper())
fwid.write(wrd + ' ')
fwid.write('\n')
#generate pronounciations for unknows words using 'letter to sound'
with open(tmpbase + '_unk.words', 'w') as fwid:
for unk in unk_words:
fwid.write(unk + '\n')
try:
os.system(PHONEME + ' ' + tmpbase + '_unk.words' + ' ' + tmpbase +
'_unk.phons')
except Exception:
print('english2phoneme error!')
sys.exit(1)
#add unknown words to the standard dictionary, generate a tmp dictionary for alignment
fw = open(tmpbase + '.dict', 'w')
with open(dictfile, 'r') as fid:
for line in fid:
fw.write(line)
f = open(tmpbase + '_unk.words', 'r')
lines1 = f.readlines()
f.close()
f = open(tmpbase + '_unk.phons', 'r')
lines2 = f.readlines()
f.close()
for i in range(len(lines1)):
wrd = lines1[i].replace('\n', '')
phons = lines2[i].replace('\n', '').replace(' ', '')
seq = []
j = 0
while (j < len(phons)):
if (phons[j] > 'Z'):
if (phons[j] == 'j'):
seq.append('JH')
elif (phons[j] == 'h'):
seq.append('HH')
else:
seq.append(phons[j].upper())
j += 1
else:
p = phons[j:j + 2]
if (p == 'WH'):
seq.append('W')
elif (p in ['TH', 'SH', 'HH', 'DH', 'CH', 'ZH', 'NG']):
seq.append(p)
elif (p == 'AX'):
seq.append('AH0')
else:
seq.append(p + '1')
j += 2
fw.write(wrd + ' ')
for s in seq:
fw.write(' ' + s)
fw.write('\n')
fw.close()
def prep_mlf(txt: str, tmpbase: str):
with open(tmpbase + '.mlf', 'w') as fwid:
fwid.write('#!MLF!#\n')
fwid.write('"' + tmpbase + '.lab"\n')
fwid.write('sp\n')
wrds = txt.split()
for wrd in wrds:
fwid.write(wrd.upper() + '\n')
fwid.write('sp\n')
fwid.write('.\n')
def _get_user():
return os.path.expanduser('~').split("/")[-1]
def alignment(wav_path: str, text: str):
'''
intervals: List[phn, start, end]
'''
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 ' + tmpbase + '.wav remix -')
except Exception:
print('sox error!')
return None
#prepare clean_transcript file
try:
prep_txt_en(line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_EN + '/dict')
except Exception:
print('prep_txt error!')
return None
#prepare mlf file
try:
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except Exception:
print('prep_mlf error!')
return None
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR_EN + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except Exception:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR_EN + '/16000/macros -H ' + MODEL_DIR_EN
+ '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + tmpbase +
'.dict ' + MODEL_DIR_EN + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except Exception:
print('HVite error!')
return None
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
words = txt.strip().split()
words.reverse()
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
intervals = []
word2phns = {}
current_word = ''
index = 0
while (i < len(lines)):
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
intervals.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return intervals, word2phns
def alignment_zh(wav_path: str, text: str):
tmpbase = '/tmp/' + _get_user() + '_' + str(os.getpid())
#prepare wav and trs files
try:
os.system('sox ' + wav_path + ' -r 16000 -b 16 ' + tmpbase +
'.wav remix -')
except Exception:
print('sox error!')
return None
#prepare clean_transcript file
try:
unk_words = prep_txt_zh(
line=text, tmpbase=tmpbase, dictfile=MODEL_DIR_ZH + '/dict')
if unk_words:
print('Error! Please add the following words to dictionary:')
for unk in unk_words:
print("非法words: ", unk)
except Exception:
print('prep_txt error!')
return None
#prepare mlf file
try:
with open(tmpbase + '.txt', 'r') as fid:
txt = fid.readline()
prep_mlf(txt, tmpbase)
except Exception:
print('prep_mlf error!')
return None
#prepare scp
try:
os.system(HCOPY + ' -C ' + MODEL_DIR_ZH + '/16000/config ' + tmpbase +
'.wav' + ' ' + tmpbase + '.plp')
except Exception:
print('HCopy error!')
return None
#run alignment
try:
os.system(HVITE + ' -a -m -t 10000.0 10000.0 100000.0 -I ' + tmpbase +
'.mlf -H ' + MODEL_DIR_ZH + '/16000/macros -H ' + MODEL_DIR_ZH
+ '/16000/hmmdefs -i ' + tmpbase + '.aligned ' + MODEL_DIR_ZH
+ '/dict ' + MODEL_DIR_ZH + '/monophones ' + tmpbase +
'.plp 2>&1 > /dev/null')
except Exception:
print('HVite error!')
return None
with open(tmpbase + '.txt', 'r') as fid:
words = fid.readline().strip().split()
words = txt.strip().split()
words.reverse()
with open(tmpbase + '.aligned', 'r') as fid:
lines = fid.readlines()
i = 2
intervals = []
word2phns = {}
current_word = ''
index = 0
while (i < len(lines)):
splited_line = lines[i].strip().split()
if (len(splited_line) >= 4) and (splited_line[0] != splited_line[1]):
phn = splited_line[2]
pst = (int(splited_line[0]) / 1000 + 125) / 10000
pen = (int(splited_line[1]) / 1000 + 125) / 10000
intervals.append([phn, pst, pen])
# splited_line[-1]!='sp'
if len(splited_line) == 5:
current_word = str(index) + '_' + splited_line[-1]
word2phns[current_word] = phn
index += 1
elif len(splited_line) == 4:
word2phns[current_word] += ' ' + phn
i += 1
return intervals, word2phns
此差异已折叠。
import argparse
def parse_args():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(
description="Synthesize with acoustic model & vocoder")
# acoustic model
parser.add_argument(
'--am',
type=str,
default='fastspeech2_csmsc',
choices=[
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
'tacotron2_ljspeech', 'tacotron2_aishell3'
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
'--am_config',
type=str,
default=None,
help='Config of acoustic model. Use deault config when it is None.')
parser.add_argument(
'--am_ckpt',
type=str,
default=None,
help='Checkpoint file of acoustic model.')
parser.add_argument(
"--am_stat",
type=str,
default=None,
help="mean and standard deviation used to normalize spectrogram when training acoustic model."
)
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--tones_dict", type=str, default=None, help="tone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
# vocoder
parser.add_argument(
'--voc',
type=str,
default='pwgan_aishell3',
choices=[
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
'style_melgan_csmsc'
],
help='Choose vocoder type of tts task.')
parser.add_argument(
'--voc_config',
type=str,
default=None,
help='Config of voc. Use deault config when it is None.')
parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument(
"--voc_stat",
type=str,
default=None,
help="mean and standard deviation used to normalize spectrogram when training voc."
)
# other
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument("--model_name", type=str, help="model name")
parser.add_argument("--uid", type=str, help="uid")
parser.add_argument("--new_str", type=str, help="new string")
parser.add_argument("--prefix", type=str, help="prefix")
parser.add_argument(
"--source_lang", type=str, default="english", help="source language")
parser.add_argument(
"--target_lang", type=str, default="english", help="target language")
parser.add_argument("--output_name", type=str, help="output name")
parser.add_argument("--task_name", type=str, help="task name")
# pre
args = parser.parse_args()
return args
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union
import numpy as np
import paddle
import yaml
from sedit_arg_parser import parse_args
from yacs.config import CfgNode
from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
"""Read a text file having 2 column as dict object.
Examples:
wav.scp:
key1 /some/path/a.wav
key2 /some/path/b.wav
>>> read_2col_text('wav.scp')
{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
"""
data = {}
with Path(path).open("r", encoding="utf-8") as f:
for linenum, line in enumerate(f, 1):
sps = line.rstrip().split(maxsplit=1)
if len(sps) == 1:
k, v = sps[0], ""
else:
k, v = sps
if k in data:
raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
data[k] = v
return data
def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
) -> Dict[str, List[Union[float, int]]]:
"""Read a text file indicating sequences of number
Examples:
key1 1 2 3
key2 34 5 6
>>> d = load_num_sequence_text('text')
>>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
"""
if loader_type == "text_int":
delimiter = " "
dtype = int
elif loader_type == "text_float":
delimiter = " "
dtype = float
elif loader_type == "csv_int":
delimiter = ","
dtype = int
elif loader_type == "csv_float":
delimiter = ","
dtype = float
else:
raise ValueError(f"Not supported loader_type={loader_type}")
# path looks like:
# utta 1,0
# uttb 3,4,5
# -> return {'utta': np.ndarray([1, 0]),
# 'uttb': np.ndarray([3, 4, 5])}
d = read_2column_text(path)
# Using for-loop instead of dict-comprehension for debuggability
retval = {}
for k, v in d.items():
try:
retval[k] = [dtype(i) for i in v.split(delimiter)]
except TypeError:
print(f'Error happened with path="{path}", id="{k}", value="{v}"')
raise
return retval
def is_chinese(ch):
if u'\u4e00' <= ch <= u'\u9fff':
return True
else:
return False
def get_voc_out(mel):
# vocoder
args = parse_args()
with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f))
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
with paddle.no_grad():
wav = voc_inference(mel)
return np.squeeze(wav)
def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
args = parse_args()
if target_lang == 'english':
args.am = "fastspeech2_ljspeech"
args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
elif target_lang == 'chinese':
args.am = "fastspeech2_csmsc"
args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
paddle.set_device("gpu")
else:
print("ngpu should >= 0 !")
# Init body.
with open(args.am_config) as f:
am_config = CfgNode(yaml.safe_load(f))
am_inference, am = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict,
return_am=True)
vocab_phones = {}
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
for tone, id in phn_id:
vocab_phones[tone] = int(id)
vocab_size = len(vocab_phones)
phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
phone_ids = [vocab_phones[item] for item in phonemes]
phone_ids.append(vocab_size - 1)
phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
_, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
pre_d_outs = d_outs
phu_durs_new = pre_d_outs * hop_length / fs
phu_durs_new = phu_durs_new.tolist()[:-1]
return phu_durs_new
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=ernie_sat
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
\ No newline at end of file
p243_new For that reason cover should not be given.
Prompt_003_new This was not the show for me.
p299_096 We are trying to establish a date.
p243_new ../../prompt_wav/p243_313.wav
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p299_096 ../../prompt_wav/p299_096.wav
#!/bin/bash
set -e
source path.sh
# en --> zh 的 语音合成
# 根据 Prompt_003_new 作为提示语音: This was not the show for me. 来合成: '今天天气很好'
# 注: 输入的 new_str 需为中文汉字, 否则会通过预处理只保留中文汉字, 即合成预处理后的中文语音。
python local/inference.py \
--task_name=cross-lingual_clone \
--model_name=paddle_checkpoint_dual_mask_enzh \
--uid=Prompt_003_new \
--new_str='今天天气很好.' \
--prefix='./prompt/dev/' \
--source_lang=english \
--target_lang=chinese \
--output_name=pred_clone.wav \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_csmsc \
--am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
--am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
--am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
#!/bin/bash
set -e
source path.sh
# 纯英文的语音合成
# 样例为根据 p299_096 对应的语音作为提示语音: This was not the show for me. 来合成: 'I enjoy my life.'
python local/inference.py \
--task_name=synthesize \
--model_name=paddle_checkpoint_en \
--uid=p299_096 \
--new_str='I enjoy my life, do you?' \
--prefix='./prompt/dev/' \
--source_lang=english \
--target_lang=english \
--output_name=pred_gen.wav \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
#!/bin/bash
set -e
source path.sh
# 纯英文的语音编辑
# 样例为把 p243_new 对应的原始语音: For that reason cover should not be given.编辑成 'for that reason cover is impossible to be given.' 对应的语音
# NOTE: 语音编辑任务暂支持句子中 1 个位置的替换或者插入文本操作
python local/inference.py \
--task_name=edit \
--model_name=paddle_checkpoint_en \
--uid=p243_new \
--new_str='for that reason cover is impossible to be given.' \
--prefix='./prompt/dev/' \
--source_lang=english \
--target_lang=english \
--output_name=pred_edit.wav \
--voc=pwgan_aishell3 \
--voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
--am=fastspeech2_ljspeech \
--am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
--am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
--am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
--phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
#!/bin/bash
rm -rf *.wav
./run_sedit_en.sh # 语音编辑任务(英文)
./run_gen_en.sh # 个性化语音合成任务(英文)
./run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的语音克隆)
\ No newline at end of file
# ERNIE SAT with VCTK dataset
...@@ -11,10 +11,21 @@ ...@@ -11,10 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np import numpy as np
import paddle import paddle
from paddlespeech.t2s.datasets.batch import batch_sequences from paddlespeech.t2s.datasets.batch import batch_sequences
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import get_seg_pos
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list
from paddlespeech.t2s.modules.nets_utils import phones_masking
from paddlespeech.t2s.modules.nets_utils import phones_text_masking
def tacotron2_single_spk_batch_fn(examples): def tacotron2_single_spk_batch_fn(examples):
...@@ -335,3 +346,182 @@ def vits_single_spk_batch_fn(examples): ...@@ -335,3 +346,182 @@ def vits_single_spk_batch_fn(examples):
"speech": speech "speech": speech
} }
return batch return batch
# for ERNIE SAT
class MLMCollateFn:
"""Functor class of common_collate_fn()"""
def __init__(
self,
feats_extract,
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
text_masking: bool=False,
attention_window: int=0,
not_sequence: Collection[str]=(), ):
self.mlm_prob = mlm_prob
self.mean_phn_span = mean_phn_span
self.feats_extract = feats_extract
self.not_sequence = set(not_sequence)
self.attention_window = attention_window
self.seg_emb = seg_emb
self.text_masking = text_masking
def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
return mlm_collate_fn(
data,
feats_extract=self.feats_extract,
mlm_prob=self.mlm_prob,
mean_phn_span=self.mean_phn_span,
seg_emb=self.seg_emb,
text_masking=self.text_masking,
attention_window=self.attention_window,
not_sequence=self.not_sequence)
def mlm_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
feats_extract=None,
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
text_masking: bool=False,
attention_window: int=0,
pad_value: int=0,
not_sequence: Collection[str]=(),
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(not k.endswith("_lens")
for k in data[0]), f"*_lens is reserved: {list(data[0])}"
output = {}
for key in data[0]:
array_list = [d[key] for d in data]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list = [paddle.to_tensor(a) for a in array_list]
# tensor: (Batch, Length, ...)
tensor = pad_list(tensor_list, pad_value)
output[key] = tensor
# lens: (Batch,)
if key not in not_sequence:
lens = paddle.to_tensor(
[d[key].shape[0] for d in data], dtype=paddle.int64)
output[key + "_lens"] = lens
feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
feats = paddle.to_tensor(feats)
feats_lens = paddle.shape(feats)[0]
feats = paddle.unsqueeze(feats, 0)
text = output["text"]
text_lens = output["text_lens"]
align_start = output["align_start"]
align_start_lens = output["align_start_lens"]
align_end = output["align_end"]
max_tlen = max(text_lens)
max_slen = max(feats_lens)
speech_pad = feats[:, :max_slen]
text_pad = text
text_mask = make_non_pad_mask(
text_lens, text_pad, length_dim=1).unsqueeze(-2)
speech_mask = make_non_pad_mask(
feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
span_bdy = None
if 'span_bdy' in output.keys():
span_bdy = output['span_bdy']
# dual_mask 的是混合中英时候同时 mask 语音和文本
# ernie sat 在实现跨语言的时候都 mask 了
if text_masking:
masked_pos, text_masked_pos = phones_text_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
text_pad=text_pad,
text_mask=text_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
# 训练纯中文和纯英文的 -> a3t 没有对 phoneme 做 mask, 只对语音 mask 了
# a3t 和 ernie sat 的区别主要在于做 mask 的时候
else:
masked_pos = phones_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
text_masked_pos = paddle.zeros(paddle.shape(text_pad))
output_dict = {}
speech_seg_pos, text_seg_pos = get_seg_pos(
speech_pad=speech_pad,
text_pad=text_pad,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
seg_emb=seg_emb)
output_dict['speech'] = speech_pad
output_dict['text'] = text_pad
output_dict['masked_pos'] = masked_pos
output_dict['text_masked_pos'] = text_masked_pos
output_dict['speech_mask'] = speech_mask
output_dict['text_mask'] = text_mask
output_dict['speech_seg_pos'] = speech_seg_pos
output_dict['text_seg_pos'] = text_seg_pos
output = (uttids, output_dict)
return output
def build_mlm_collate_fn(
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
n_mels: int=80,
fmin: int=80,
fmax: int=7600,
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
epoch: int=-1, ):
feats_extract_class = LogMelFBank
feats_extract = feats_extract_class(
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mels=n_mels,
fmin=fmin,
fmax=fmax)
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
return MLMCollateFn(
feats_extract=feats_extract,
mlm_prob=mlm_prob * mlm_prob_factor,
mean_phn_span=mean_phn_span,
seg_emb=seg_emb)
...@@ -147,14 +147,14 @@ def get_frontend(lang: str='zh', ...@@ -147,14 +147,14 @@ def get_frontend(lang: str='zh',
# dygraph # dygraph
def get_am_inference( def get_am_inference(am: str='fastspeech2_csmsc',
am: str='fastspeech2_csmsc', am_config: CfgNode=None,
am_config: CfgNode=None, am_ckpt: Optional[os.PathLike]=None,
am_ckpt: Optional[os.PathLike]=None, am_stat: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None, phones_dict: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None, tones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None, speaker_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None, ): return_am: bool=False):
with open(phones_dict, "r") as f: with open(phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id) vocab_size = len(phn_id)
...@@ -203,7 +203,10 @@ def get_am_inference( ...@@ -203,7 +203,10 @@ def get_am_inference(
am_inference = am_inference_class(am_normalizer, am) am_inference = am_inference_class(am_normalizer, am)
am_inference.eval() am_inference.eval()
print("acoustic model done!") print("acoustic model done!")
return am_inference if return_am:
return am_inference, am
else:
return am_inference
def get_voc_inference( def get_voc_inference(
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .ernie_sat import *
from .fastspeech2 import * from .fastspeech2 import *
from .hifigan import * from .hifigan import *
from .melgan import * from .melgan import *
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mlm import *
此差异已折叠。
...@@ -1007,3 +1007,55 @@ class KLDivergenceLoss(nn.Layer): ...@@ -1007,3 +1007,55 @@ class KLDivergenceLoss(nn.Layer):
loss = kl / paddle.sum(z_mask) loss = kl / paddle.sum(z_mask)
return loss return loss
# loss for ERNIE SAT
class MLMLoss(nn.Layer):
def __init__(self,
lsm_weight: float=0.1,
ignore_id: int=-1,
text_masking: bool=False):
super().__init__()
if text_masking:
self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
if lsm_weight > 50:
self.l1_loss_func = nn.MSELoss()
else:
self.l1_loss_func = nn.L1Loss(reduction='none')
self.text_masking = text_masking
def forward(self,
speech: paddle.Tensor,
before_outs: paddle.Tensor,
after_outs: paddle.Tensor,
masked_pos: paddle.Tensor,
text: paddle.Tensor=None,
text_outs: paddle.Tensor=None,
text_masked_pos: paddle.Tensor=None):
xs_pad = speech
mlm_loss_pos = masked_pos > 0
loss = paddle.sum(
self.l1_loss_func(
paddle.reshape(before_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
if after_outs is not None:
loss += paddle.sum(
self.l1_loss_func(
paddle.reshape(after_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
loss_mlm = paddle.sum((loss * paddle.reshape(
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
if self.text_masking:
loss_text = paddle.sum((self.text_mlm_loss(
paddle.reshape(text_outs, (-1, self.vocab_size)),
paddle.reshape(text, (-1))) * paddle.reshape(
text_masked_pos,
(-1)))) / paddle.sum((text_masked_pos) + 1e-10)
return loss_mlm, loss_text
return loss_mlm
...@@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -220,3 +220,99 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask) return self.forward_attention(v, scores, mask)
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
self.pos_bias_v = paddle.create_parameter(
shape=(self.h, self.d_k),
dtype='float32',
default_initializer=paddle.nn.initializer.XavierUniform())
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x(Tensor): Input tensor (batch, head, time1, time2).
Returns:
Tensor:Output tensor.
"""
b, h, t1, t2 = paddle.shape(x)
zero_pad = paddle.zeros((b, h, t1, 1))
x_padded = paddle.concat([zero_pad, x], axis=-1)
x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
# only keep the positions from 0 to time2
x = paddle.reshape(x_padded[:, :, 1:], [b, h, t1, t2])
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query(Tensor): Query tensor (#batch, time1, size).
key(Tensor): Key tensor (#batch, time2, size).
value(Tensor): Value tensor (#batch, time2, size).
pos_emb(Tensor): Positional embedding tensor (#batch, time1, size).
mask(Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
Returns:
Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
# (batch, time1, head, d_k)
q = paddle.transpose(q, [0, 2, 1, 3])
n_batch_pos = paddle.shape(pos_emb)[0]
p = paddle.reshape(
self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
# (batch, head, time1, d_k)
p = paddle.transpose(p, [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_u = paddle.transpose((q + self.pos_bias_u), [0, 2, 1, 3])
# (batch, head, time1, d_k)
q_with_bias_v = paddle.transpose((q + self.pos_bias_v), [0, 2, 1, 3])
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = paddle.matmul(q_with_bias_u,
paddle.transpose(k, [0, 1, 3, 2]))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = paddle.matmul(q_with_bias_v,
paddle.transpose(p, [0, 1, 3, 2]))
matrix_bd = self.rel_shift(matrix_bd)
# (batch, head, time1, time2)
scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
...@@ -185,3 +185,61 @@ class RelPositionalEncoding(nn.Layer): ...@@ -185,3 +185,61 @@ class RelPositionalEncoding(nn.Layer):
pe_size = paddle.shape(self.pe) pe_size = paddle.shape(self.pe)
pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ] pos_emb = self.pe[:, pe_size[1] // 2 - T + 1:pe_size[1] // 2 + T, ]
return self.dropout(x), self.dropout(pos_emb) return self.dropout(x), self.dropout(pos_emb)
class LegacyRelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
"""
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int, optional): [Maximum input length.]. Defaults to 5000.
"""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
return
pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
if self.reverse:
position = paddle.arange(
paddle.shape(x)[1] - 1, -1, -1.0,
dtype=paddle.float32).unsqueeze(1)
else:
position = paddle.arange(
0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe
def forward(self, x: paddle.Tensor):
"""Compute positional encoding.
Args:
x (paddle.Tensor): Input tensor (batch, time, `*`).
Returns:
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[:, :paddle.shape(x)[1]]
return self.dropout(x), self.dropout(pos_emb)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册