未验证 提交 1af9bd47 编写于 作者: H HuangLiangJie 提交者: GitHub

[TTS]Cantonese FastSpeech2 e2e infer, test=tts (#2927)

上级 004a4d60
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0
# pwgan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_e2e.py \
--am=fastspeech2_canton \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=pwgan_aishell3 \
--voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
--lang=canton \
--text=${BIN_DIR}/../sentences_canton.txt \
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--speaker_dict=dump/speaker_id_map.txt \
--spk_id=0 \
--inference_dir=${train_output_path}/inference
fi
# hifigan
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "in hifigan syn_e2e"
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_e2e.py \
--am=fastspeech2_canton \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=hifigan_aishell3 \
--voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \
--voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
--voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
--lang=canton \
--text=${BIN_DIR}/../sentences_canton.txt \
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--speaker_dict=dump/speaker_id_map.txt \
--spk_id=0 \
--inference_dir=${train_output_path}/inference
fi
......@@ -9,7 +9,8 @@ stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_112793.pdz
ckpt_name=snapshot_iter_280000.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
......
001 白云山爬过一次嘅,好远啊,爬上去都成两个钟
002 睇书咯,番屋企,而家好多人好少睇书噶喎
003 因为如果唔考试嘅话,工资好低噶
004 冇固定噶,你中意休边日就边日噶
005 即系太迟嘅话咧,落班太迟嘅话就喺出边食啲咯
006 是非有公理,慎言莫冒犯别人
007 遇上冷风雨,休太认真
......@@ -33,6 +33,7 @@ from paddlespeech.t2s.datasets.am_batch_fn import *
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.canton_frontend import CantonFrontend
from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
......@@ -111,7 +112,7 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
if line.strip() != "":
items = re.split(r"\s+", line.strip(), 1)
utt_id = items[0]
if lang == 'zh':
if lang in {'zh', 'canton'}:
sentence = "".join(items[1:])
elif lang == 'en':
sentence = " ".join(items[1:])
......@@ -132,8 +133,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
converters = {}
if am_name == 'fastspeech2':
fields = ["utt_id", "text"]
if am_dataset in {"aishell3", "vctk",
"mix"} and speaker_dict is not None:
if am_dataset in {"aishell3", "vctk", "mix",
"canton"} and speaker_dict is not None:
print("multiple speaker fastspeech2!")
fields += ["spk_id"]
elif voice_cloning:
......@@ -177,8 +178,8 @@ def get_dev_dataloader(dev_metadata: List[Dict[str, Any]],
converters = {}
if am_name == 'fastspeech2':
fields = ["utt_id", "text"]
if am_dataset in {"aishell3", "vctk",
"mix"} and speaker_dict is not None:
if am_dataset in {"aishell3", "vctk", "mix",
"canton"} and speaker_dict is not None:
print("multiple speaker fastspeech2!")
collate_fn = fastspeech2_multi_spk_batch_fn_static
fields += ["spk_id"]
......@@ -266,6 +267,8 @@ def get_frontend(lang: str='zh',
phone_vocab_path=phones_dict,
tone_vocab_path=tones_dict,
use_rhy=use_rhy)
elif lang == 'canton':
frontend = CantonFrontend(phone_vocab_path=phones_dict)
elif lang == 'en':
frontend = English(phone_vocab_path=phones_dict)
elif lang == 'mix':
......@@ -302,6 +305,10 @@ def run_frontend(frontend: object,
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
outs.update({'tone_ids': tone_ids})
elif lang == 'canton':
input_ids = frontend.get_input_ids(
text, merge_sentences=merge_sentences, to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"]
elif lang == 'en':
input_ids = frontend.get_input_ids(
text, merge_sentences=merge_sentences, to_tensor=to_tensor)
......@@ -311,7 +318,7 @@ def run_frontend(frontend: object,
text, merge_sentences=merge_sentences, to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en', 'mix'}!")
print("lang should in {'zh', 'en', 'mix', 'canton'}!")
outs.update({'phone_ids': phone_ids})
return outs
......@@ -411,8 +418,8 @@ def am_to_static(am_inference,
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2':
if am_dataset in {"aishell3", "vctk",
"mix"} and speaker_dict is not None:
if am_dataset in {"aishell3", "vctk", "mix",
"canton"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
......@@ -424,8 +431,8 @@ def am_to_static(am_inference,
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
elif am_name == 'speedyspeech':
if am_dataset in {"aishell3", "vctk",
"mix"} and speaker_dict is not None:
if am_dataset in {"aishell3", "vctk", "mix",
"canton"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
......@@ -575,7 +582,7 @@ def get_am_output(
get_tone_ids = False
if am_name == 'speedyspeech':
get_tone_ids = True
if am_dataset in {"aishell3", "vctk", "mix"} and speaker_dict:
if am_dataset in {"aishell3", "vctk", "mix", "canton"} and speaker_dict:
get_spk_id = True
spk_id = np.array([spk_id])
......
......@@ -136,7 +136,8 @@ def parse_args():
choices=[
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix'
'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix',
'fastspeech2_canton'
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
......
......@@ -119,7 +119,7 @@ def evaluate(args):
# acoustic model
if am_name == 'fastspeech2':
# multi speaker
if am_dataset in {"aishell3", "vctk", "mix"}:
if am_dataset in {"aishell3", "vctk", "mix", "canton"}:
spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(part_phone_ids, spk_id)
else:
......@@ -167,7 +167,8 @@ def parse_args():
choices=[
'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix'
'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix',
'fastspeech2_canton'
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
import numpy as np
import paddle
import ToJyutping
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
INITIALS = [
'p', 'b', 't', 'd', 'ts', 'dz', 'k', 'g', 'kw', 'gw', 'f', 'h', 'l', 'm',
'ng', 'n', 's', 'y', 'w', 'c', 'z', 'j'
]
INITIALS += ['sp', 'spl', 'spn', 'sil']
def get_lines(cantons: List[str]):
phones = []
for canton in cantons:
for consonant in INITIALS:
if canton.startswith(consonant):
c, v = canton[:len(consonant)], canton[len(consonant):]
phones = phones + [c, v]
return phones
class CantonFrontend():
def __init__(self, phone_vocab_path: str):
self.text_normalizer = TextNormalizer()
self.punc = ":,;。?!“”‘’':,;.?!"
self.vocab_phones = {}
if phone_vocab_path:
with open(phone_vocab_path, 'rt', encoding='utf-8') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
self.vocab_phones[phn] = int(id)
# if merge_sentences, merge all sentences into one phone sequence
def _g2p(self, sentences: List[str],
merge_sentences: bool=True) -> List[List[str]]:
phones_list = []
for sentence in sentences:
phones_str = ToJyutping.get_jyutping_text(sentence)
phones_split = get_lines(phones_str.split(' '))
phones_list.append(phones_split)
return phones_list
def _p2id(self, phonemes: List[str]) -> np.ndarray:
# replace unk phone with sp
phonemes = [
phn if phn in self.vocab_phones else "sp" for phn in phonemes
]
phone_ids = [self.vocab_phones[item] for item in phonemes]
return np.array(phone_ids, np.int64)
def get_phonemes(self,
sentence: str,
merge_sentences: bool=True,
print_info: bool=False) -> List[List[str]]:
sentences = self.text_normalizer.normalize(sentence)
phonemes = self._g2p(sentences, merge_sentences=merge_sentences)
if print_info:
print("----------------------------")
print("text norm results:")
print(sentences)
print("----------------------------")
print("g2p results:")
print(phonemes)
print("----------------------------")
return phonemes
def get_input_ids(self,
sentence: str,
merge_sentences: bool=True,
print_info: bool=False,
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
phonemes = self.get_phonemes(
sentence, merge_sentences=merge_sentences, print_info=print_info)
result = {}
temp_phone_ids = []
for phones in phonemes:
if phones:
phone_ids = self._p2id(phones)
# if use paddle.to_tensor() in onnxruntime, the first time will be too low
if to_tensor:
phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids)
if temp_phone_ids:
result["phone_ids"] = temp_phone_ids
return result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册