diff --git a/.gitignore b/.gitignore index 639472001a719aca5cb93e851ef1f628fc3cae9b..7328b329420c880e71b0f99879e5aa64de885632 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,12 @@ tools/Miniconda3-latest-Linux-x86_64.sh tools/activate_python.sh tools/miniconda.sh tools/CRF++-0.58/ +tools/liblbfgs-1.10/ +tools/srilm/ +tools/env.sh +tools/openfst-1.8.1/ +tools/libsndfile/ +tools/python-soundfile/ speechx/fc_patch/ diff --git a/examples/csmsc/tts2/local/inference.sh b/examples/csmsc/tts2/local/inference.sh index d78c3eb3273362fcbae7288005e7ce848e059379..ed92136cde13d70034de93e7e583ef79bfa1bdda 100755 --- a/examples/csmsc/tts2/local/inference.sh +++ b/examples/csmsc/tts2/local/inference.sh @@ -30,21 +30,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --tones_dict=dump/tone_id_map.txt fi -# style melgan -# style melgan's Dygraph to Static Graph is not ready now -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - python3 ${BIN_DIR}/../inference.py \ - --inference_dir=${train_output_path}/inference \ - --am=speedyspeech_csmsc \ - --voc=style_melgan_csmsc \ - --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/pd_infer_out \ - --phones_dict=dump/phone_id_map.txt \ - --tones_dict=dump/tone_id_map.txt -fi - # hifigan -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ --am=speedyspeech_csmsc \ diff --git a/examples/csmsc/tts3/README.md b/examples/csmsc/tts3/README.md index bc672f66f1eea154436323fcef236456c3cd17b5..c734199b46d9da2dc37482f9a4f75d7375bdaa8e 100644 --- a/examples/csmsc/tts3/README.md +++ b/examples/csmsc/tts3/README.md @@ -231,14 +231,19 @@ Pretrained FastSpeech2 model with no silence in the edge of audios: The static model can be downloaded here: - [fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip) - [fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip) +- [fastspeech2_cnndecoder_csmsc_static_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_static_1.0.0.zip) +- [fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0.zip) The ONNX model can be downloaded here: - [fastspeech2_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip) +- [fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip) +- [fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip) Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss :-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------: default| 2(gpu) x 76000|1.0991|0.59132|0.035815|0.31915|0.15287| conformer| 2(gpu) x 76000|1.0675|0.56103|0.035869|0.31553|0.15509| +cnndecoder| 1(gpu) x 153000|1.1153|0.61475|0.03380|0.30414|0.14707| FastSpeech2 checkpoint contains files listed below. ```text diff --git a/examples/csmsc/tts3/local/inference.sh b/examples/csmsc/tts3/local/inference.sh index 9322cfd697912100663457f2b9bcada543e27733..7052b347dd04e51e0582d0c214df7ba7c984a6b2 100755 --- a/examples/csmsc/tts3/local/inference.sh +++ b/examples/csmsc/tts3/local/inference.sh @@ -5,6 +5,7 @@ train_output_path=$1 stage=0 stop_stage=0 +# pwgan if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ @@ -27,20 +28,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --phones_dict=dump/phone_id_map.txt fi -# style melgan -# style melgan's Dygraph to Static Graph is not ready now -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - python3 ${BIN_DIR}/../inference.py \ - --inference_dir=${train_output_path}/inference \ - --am=fastspeech2_csmsc \ - --voc=style_melgan_csmsc \ - --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/pd_infer_out \ - --phones_dict=dump/phone_id_map.txt -fi # hifigan -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ --am=fastspeech2_csmsc \ @@ -51,7 +41,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi # wavernn -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ --am=fastspeech2_csmsc \ diff --git a/examples/csmsc/tts3/local/inference_streaming.sh b/examples/csmsc/tts3/local/inference_streaming.sh new file mode 100755 index 0000000000000000000000000000000000000000..719f46c620aee6e8eb3496a7aaa4e40abf7e9817 --- /dev/null +++ b/examples/csmsc/tts3/local/inference_streaming.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +train_output_path=$1 + +stage=0 +stop_stage=0 + +# pwgan +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/../inference_streaming.py \ + --inference_dir=${train_output_path}/inference_streaming \ + --am=fastspeech2_csmsc \ + --am_stat=dump/train/speech_stats.npy \ + --voc=pwgan_csmsc \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/pd_infer_out_streaming \ + --phones_dict=dump/phone_id_map.txt \ + --am_streaming=True +fi + +# for more GAN Vocoders +# multi band melgan +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + python3 ${BIN_DIR}/../inference_streaming.py \ + --inference_dir=${train_output_path}/inference_streaming \ + --am=fastspeech2_csmsc \ + --am_stat=dump/train/speech_stats.npy \ + --voc=mb_melgan_csmsc \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/pd_infer_out_streaming \ + --phones_dict=dump/phone_id_map.txt \ + --am_streaming=True +fi + +# hifigan +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + python3 ${BIN_DIR}/../inference_streaming.py \ + --inference_dir=${train_output_path}/inference_streaming \ + --am=fastspeech2_csmsc \ + --am_stat=dump/train/speech_stats.npy \ + --voc=hifigan_csmsc \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/pd_infer_out_streaming \ + --phones_dict=dump/phone_id_map.txt \ + --am_streaming=True +fi + diff --git a/examples/csmsc/tts3/local/ort_predict_streaming.sh b/examples/csmsc/tts3/local/ort_predict_streaming.sh new file mode 100755 index 0000000000000000000000000000000000000000..502ec912a236ee361cf7adb9aa3adaffe11fba3d --- /dev/null +++ b/examples/csmsc/tts3/local/ort_predict_streaming.sh @@ -0,0 +1,19 @@ +train_output_path=$1 + +stage=0 +stop_stage=0 + +# e2e, synthesize from text +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/../ort_predict_streaming.py \ + --inference_dir=${train_output_path}/inference_onnx_streaming \ + --am=fastspeech2_csmsc \ + --am_stat=dump/train/speech_stats.npy \ + --voc=hifigan_csmsc \ + --output_dir=${train_output_path}/onnx_infer_out_streaming \ + --text=${BIN_DIR}/../csmsc_test.txt \ + --phones_dict=dump/phone_id_map.txt \ + --device=cpu \ + --cpu_threads=2 \ + --am_streaming=True +fi diff --git a/examples/csmsc/tts3/local/synthesize_streaming.sh b/examples/csmsc/tts3/local/synthesize_streaming.sh index 7606c23857fd76d958f8b4757345badf4fb1b9c8..b135db76d42e90a922b2d07d2b664c7c24690a0f 100755 --- a/examples/csmsc/tts3/local/synthesize_streaming.sh +++ b/examples/csmsc/tts3/local/synthesize_streaming.sh @@ -88,5 +88,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --text=${BIN_DIR}/../sentences.txt \ --output_dir=${train_output_path}/test_e2e_streaming \ --phones_dict=dump/phone_id_map.txt \ - --am_streaming=True + --am_streaming=True \ + --inference_dir=${train_output_path}/inference_streaming fi diff --git a/examples/csmsc/tts3/run_cnndecoder.sh b/examples/csmsc/tts3/run_cnndecoder.sh index 5cccef01610515d18799e2af5d490f22266c0cb3..61cd02a932016387639426fc0c1efafc2f4c135a 100755 --- a/examples/csmsc/tts3/run_cnndecoder.sh +++ b/examples/csmsc/tts3/run_cnndecoder.sh @@ -31,18 +31,75 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi +# synthesize_e2e non-streaming if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # synthesize_e2e, vocoder is pwgan CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi +# inference non-streaming if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # inference with static model CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1 fi +# synthesize_e2e streaming if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # synthesize_e2e, vocoder is pwgan CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_streaming.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi +# inference streaming +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # inference with static model + CUDA_VISIBLE_DEVICES=${gpus} ./local/inference_streaming.sh ${train_output_path} || exit -1 +fi + +# paddle2onnx non streaming +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + # install paddle2onnx + version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '0.9.4' ]]; then + pip install paddle2onnx==0.9.4 + fi + ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx fastspeech2_csmsc + ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_csmsc +fi + + +# onnxruntime non streaming +# inference with onnxruntime, use fastspeech2 + hifigan by default +if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then + # install onnxruntime + version=$(echo `pip list |grep "onnxruntime"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '1.10.0' ]]; then + pip install onnxruntime==1.10.0 + fi + ./local/ort_predict.sh ${train_output_path} +fi + +# paddle2onnx streaming +if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then + # install paddle2onnx + version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '0.9.4' ]]; then + pip install paddle2onnx==0.9.4 + fi + # streaming acoustic model + ./local/paddle2onnx.sh ${train_output_path} inference_streaming inference_onnx_streaming fastspeech2_csmsc_am_encoder_infer + ./local/paddle2onnx.sh ${train_output_path} inference_streaming inference_onnx_streaming fastspeech2_csmsc_am_decoder + ./local/paddle2onnx.sh ${train_output_path} inference_streaming inference_onnx_streaming fastspeech2_csmsc_am_postnet + # vocoder + ./local/paddle2onnx.sh ${train_output_path} inference_streaming inference_onnx_streaming hifigan_csmsc +fi + +# onnxruntime streaming +if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then + # install onnxruntime + version=$(echo `pip list |grep "onnxruntime"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '1.10.0' ]]; then + pip install onnxruntime==1.10.0 + fi + ./local/ort_predict_streaming.sh ${train_output_path} +fi + diff --git a/examples/other/ngram_lm/s0/local/build_zh_lm.sh b/examples/other/ngram_lm/s0/local/build_zh_lm.sh index 73eb165ec7d773eac0d58725c53b397daa9f635f..b031371f512c171f82eba17a31f738407d65e7ce 100644 --- a/examples/other/ngram_lm/s0/local/build_zh_lm.sh +++ b/examples/other/ngram_lm/s0/local/build_zh_lm.sh @@ -27,7 +27,7 @@ arpa=$3 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then # text tn & wordseg preprocess echo "process text." - python3 ${MAIN_ROOT}/utils/zh_tn.py ${type} ${text} ${text}.${type}.tn + python3 ${MAIN_ROOT}/utils/zh_tn.py --token_type ${type} ${text} ${text}.${type}.tn fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ];then diff --git a/examples/other/ngram_lm/s0/local/download_lm_zh.sh b/examples/other/ngram_lm/s0/local/download_lm_zh.sh index f9e2261fdd42fcab6e5d5643c32d48d4abe43c90..050749ce1ba674d07fde74b0c85dc26f8cc490a3 100755 --- a/examples/other/ngram_lm/s0/local/download_lm_zh.sh +++ b/examples/other/ngram_lm/s0/local/download_lm_zh.sh @@ -10,6 +10,11 @@ MD5="29e02312deb2e59b3c8686c7966d4fe3" TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm +if [ -e $TARGET ];then + echo "already have lm" + exit 0; +fi + echo "Download language model ..." download $URL $MD5 $TARGET if [ $? -ne 0 ]; then diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index f1e46ca1bcd622ddad22c986443f736fe2fcdaed..49dd7b3567c6feed3e0fc5873e68ce5c386ff465 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -29,9 +29,10 @@ from ..download import get_path_from_url from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import @@ -39,110 +40,13 @@ from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". - # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "conformer_wenetspeech-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', - 'md5': - '76cb19ed857e6623856b7cd7ebbfeda4', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/conformer/checkpoints/wenetspeech', - }, - "transformer_librispeech-en-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', - 'md5': - '2c667da24922aad391eacafe37bc1660', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/transformer/checkpoints/avg_10', - }, - "deepspeech2offline_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', - 'md5': - '932c3593d62fe5c741b59b31318aa314', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "deepspeech2online_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz', - 'md5': - '23e16c69730a1cb5d735c98c83c21e16', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2_online/checkpoints/avg_1', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "conformer2online_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', - 'md5': - '4814e52e0fc2fd48899373f95c84b0c9', - 'cfg_path': - 'config.yaml', - 'ckpt_path': - 'exp/deepspeech2_online/checkpoints/avg_30', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "deepspeech2offline_librispeech-en-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz', - 'md5': - 'f5666c81ad015c8de03aac2bc92e5762', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm', - 'lm_md5': - '099a601759d467cd0a8523ff939819c5' - }, -} - -model_alias = { - "deepspeech2offline": - "paddlespeech.s2t.models.ds2:DeepSpeech2Model", - "deepspeech2online": - "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", - "conformer": - "paddlespeech.s2t.models.u2:U2Model", - "conformer_online": - "paddlespeech.s2t.models.u2:U2Model", - "transformer": - "paddlespeech.s2t.models.u2:U2Model", - "wenetspeech": - "paddlespeech.s2t.models.u2:U2Model", -} - - @cli_register( name='paddlespeech.asr', description='Speech to text infer command.') class ASRExecutor(BaseExecutor): def __init__(self): - super(ASRExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models self.parser = argparse.ArgumentParser( prog='paddlespeech.asr', add_help=True) @@ -152,7 +56,9 @@ class ASRExecutor(BaseExecutor): '--model', type=str, default='conformer_wenetspeech', - choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()], + choices=[ + tag[:tag.index('-')] for tag in self.pretrained_models.keys() + ], help='Choose model type of asr task.') self.parser.add_argument( '--lang', @@ -208,23 +114,6 @@ class ASRExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - - return decompressed_path - def _init_from_path(self, model_type: str='wenetspeech', lang: str='zh', @@ -245,10 +134,11 @@ class ASRExecutor(BaseExecutor): tag = model_type + '-' + lang + '-' + sample_rate_str res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]['cfg_path']) + self.cfg_path = os.path.join( + res_path, self.pretrained_models[tag]['cfg_path']) self.ckpt_path = os.path.join( - res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams") + res_path, + self.pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) else: @@ -273,8 +163,8 @@ class ASRExecutor(BaseExecutor): self.collate_fn_test = SpeechCollator.from_config(self.config) self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.vocab) - lm_url = pretrained_models[tag]['lm_url'] - lm_md5 = pretrained_models[tag]['lm_md5'] + lm_url = self.pretrained_models[tag]['lm_url'] + lm_md5 = self.pretrained_models[tag]['lm_md5'] self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) @@ -291,7 +181,7 @@ class ASRExecutor(BaseExecutor): raise Exception("wrong type") model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} - model_class = dynamic_import(model_name, model_alias) + model_class = dynamic_import(model_name, self.model_alias) model_conf = self.config model = model_class.from_config(model_conf) self.model = model diff --git a/paddlespeech/cli/asr/pretrained_models.py b/paddlespeech/cli/asr/pretrained_models.py new file mode 100644 index 0000000000000000000000000000000000000000..a16c4750d3a6a5aa641192838a9c99ec64a72df2 --- /dev/null +++ b/paddlespeech/cli/asr/pretrained_models.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". + # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". + # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: + # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + "conformer_wenetspeech-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', + 'md5': + '76cb19ed857e6623856b7cd7ebbfeda4', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/conformer/checkpoints/wenetspeech', + }, + "transformer_librispeech-en-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + '2c667da24922aad391eacafe37bc1660', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/transformer/checkpoints/avg_10', + }, + "deepspeech2offline_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', + 'md5': + '932c3593d62fe5c741b59b31318aa314', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + "deepspeech2online_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz', + 'md5': + '23e16c69730a1cb5d735c98c83c21e16', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + "deepspeech2offline_librispeech-en-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + 'f5666c81ad015c8de03aac2bc92e5762', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm', + 'lm_md5': + '099a601759d467cd0a8523ff939819c5' + }, +} + +model_alias = { + "deepspeech2offline": + "paddlespeech.s2t.models.ds2:DeepSpeech2Model", + "deepspeech2online": + "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", + "conformer": + "paddlespeech.s2t.models.u2:U2Model", + "transformer": + "paddlespeech.s2t.models.u2:U2Model", + "wenetspeech": + "paddlespeech.s2t.models.u2:U2Model", +} diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index f56d8a579c5d85a9376748b482897483e5886115..1f637a8fee69179f0ca5a538ce9c360d1c6cc935 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -25,55 +25,23 @@ import yaml from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress -from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddleaudio import load from paddleaudio.features import LogMelSpectrogram from paddlespeech.s2t.utils.dynamic_import import dynamic_import __all__ = ['CLSExecutor'] -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". - # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "panns_cnn6-32k": { - 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz', - 'md5': '4cf09194a95df024fd12f84712cf0f9c', - 'cfg_path': 'panns.yaml', - 'ckpt_path': 'cnn6.pdparams', - 'label_file': 'audioset_labels.txt', - }, - "panns_cnn10-32k": { - 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz', - 'md5': 'cb8427b22176cc2116367d14847f5413', - 'cfg_path': 'panns.yaml', - 'ckpt_path': 'cnn10.pdparams', - 'label_file': 'audioset_labels.txt', - }, - "panns_cnn14-32k": { - 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz', - 'md5': 'e3b9b5614a1595001161d0ab95edee97', - 'cfg_path': 'panns.yaml', - 'ckpt_path': 'cnn14.pdparams', - 'label_file': 'audioset_labels.txt', - }, -} - -model_alias = { - "panns_cnn6": "paddlespeech.cls.models.panns:CNN6", - "panns_cnn10": "paddlespeech.cls.models.panns:CNN10", - "panns_cnn14": "paddlespeech.cls.models.panns:CNN14", -} - @cli_register( name='paddlespeech.cls', description='Audio classification infer command.') class CLSExecutor(BaseExecutor): def __init__(self): - super(CLSExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models self.parser = argparse.ArgumentParser( prog='paddlespeech.cls', add_help=True) @@ -83,7 +51,9 @@ class CLSExecutor(BaseExecutor): '--model', type=str, default='panns_cnn14', - choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()], + choices=[ + tag[:tag.index('-')] for tag in self.pretrained_models.keys() + ], help='Choose model type of cls task.') self.parser.add_argument( '--config', @@ -121,23 +91,6 @@ class CLSExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - - return decompressed_path - def _init_from_path(self, model_type: str='panns_cnn14', cfg_path: Optional[os.PathLike]=None, @@ -153,12 +106,12 @@ class CLSExecutor(BaseExecutor): if label_file is None or ckpt_path is None: tag = model_type + '-' + '32k' # panns_cnn14-32k self.res_path = self._get_pretrained_path(tag) - self.cfg_path = os.path.join(self.res_path, - pretrained_models[tag]['cfg_path']) - self.label_file = os.path.join(self.res_path, - pretrained_models[tag]['label_file']) - self.ckpt_path = os.path.join(self.res_path, - pretrained_models[tag]['ckpt_path']) + self.cfg_path = os.path.join( + self.res_path, self.pretrained_models[tag]['cfg_path']) + self.label_file = os.path.join( + self.res_path, self.pretrained_models[tag]['label_file']) + self.ckpt_path = os.path.join( + self.res_path, self.pretrained_models[tag]['ckpt_path']) else: self.cfg_path = os.path.abspath(cfg_path) self.label_file = os.path.abspath(label_file) @@ -175,7 +128,7 @@ class CLSExecutor(BaseExecutor): self._label_list.append(line.strip()) # model - model_class = dynamic_import(model_type, model_alias) + model_class = dynamic_import(model_type, self.model_alias) model_dict = paddle.load(self.ckpt_path) self.model = model_class(extract_embedding=False) self.model.set_state_dict(model_dict) diff --git a/paddlespeech/cli/cls/pretrained_models.py b/paddlespeech/cli/cls/pretrained_models.py new file mode 100644 index 0000000000000000000000000000000000000000..1d66850aa7fa55733c8a0680889906894e235126 --- /dev/null +++ b/paddlespeech/cli/cls/pretrained_models.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". + # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". + # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: + # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + "panns_cnn6-32k": { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz', + 'md5': '4cf09194a95df024fd12f84712cf0f9c', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn6.pdparams', + 'label_file': 'audioset_labels.txt', + }, + "panns_cnn10-32k": { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz', + 'md5': 'cb8427b22176cc2116367d14847f5413', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn10.pdparams', + 'label_file': 'audioset_labels.txt', + }, + "panns_cnn14-32k": { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz', + 'md5': 'e3b9b5614a1595001161d0ab95edee97', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn14.pdparams', + 'label_file': 'audioset_labels.txt', + }, +} + +model_alias = { + "panns_cnn6": "paddlespeech.cls.models.panns:CNN6", + "panns_cnn10": "paddlespeech.cls.models.panns:CNN10", + "panns_cnn14": "paddlespeech.cls.models.panns:CNN14", +} diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 064939a85da7a87ac0de8d68c8729d78a5c2125c..df0b6783823b7ac0e23b373f0bb898e0d103be64 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -25,6 +25,8 @@ from typing import Union import paddle from .log import logger +from .utils import download_and_decompress +from .utils import MODEL_HOME class BaseExecutor(ABC): @@ -35,19 +37,8 @@ class BaseExecutor(ABC): def __init__(self): self._inputs = OrderedDict() self._outputs = OrderedDict() - - @abstractmethod - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - - Args: - tag (str): A tag of pretrained model. - - Returns: - os.PathLike: The path on which resources of pretrained model locate. - """ - pass + self.pretrained_models = OrderedDict() + self.model_alias = OrderedDict() @abstractmethod def _init_from_path(self, *args, **kwargs): @@ -227,3 +218,20 @@ class BaseExecutor(ABC): ] for l in loggers: l.disabled = True + + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. + """ + support_models = list(self.pretrained_models.keys()) + assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(self.pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py index e64fc57d1bf2574a016e2655aa021a919e59ab98..29d95f79914e09d7dbf6e258d7ceaba9ae2228c6 100644 --- a/paddlespeech/cli/st/infer.py +++ b/paddlespeech/cli/st/infer.py @@ -32,40 +32,24 @@ from ..utils import cli_register from ..utils import download_and_decompress from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import kaldi_bins +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ["STExecutor"] -pretrained_models = { - "fat_st_ted-en-zh": { - "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", - "md5": - "d62063f35a16d91210a71081bd2dd557", - "cfg_path": - "model.yaml", - "ckpt_path": - "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", - } -} - -model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"} - -kaldi_bins = { - "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", - "md5": - "c0682303b3f3393dbf6ed4c4e35a53eb", -} - @cli_register( name="paddlespeech.st", description="Speech translation infer command.") class STExecutor(BaseExecutor): def __init__(self): - super(STExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models + self.kaldi_bins = kaldi_bins self.parser = argparse.ArgumentParser( prog="paddlespeech.st", add_help=True) @@ -75,7 +59,9 @@ class STExecutor(BaseExecutor): "--model", type=str, default="fat_st_ted", - choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()], + choices=[ + tag[:tag.index('-')] for tag in self.pretrained_models.keys() + ], help="Choose model type of st task.") self.parser.add_argument( "--src_lang", @@ -119,28 +105,11 @@ class STExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - "Use pretrained model stored in: {}".format(decompressed_path)) - - return decompressed_path - def _set_kaldi_bins(self) -> os.PathLike: """ Download and returns kaldi_bins resources path of current task. """ - decompressed_path = download_and_decompress(kaldi_bins, MODEL_HOME) + decompressed_path = download_and_decompress(self.kaldi_bins, MODEL_HOME) decompressed_path = os.path.abspath(decompressed_path) logger.info("Kaldi_bins stored in: {}".format(decompressed_path)) if "LD_LIBRARY_PATH" in os.environ: @@ -197,7 +166,7 @@ class STExecutor(BaseExecutor): model_conf = self.config model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} - model_class = dynamic_import(model_name, model_alias) + model_class = dynamic_import(model_name, self.model_alias) self.model = model_class.from_config(model_conf) self.model.eval() diff --git a/paddlespeech/cli/st/pretrained_models.py b/paddlespeech/cli/st/pretrained_models.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7410d253f34109424e49ea0d2622e12ce93ea5 --- /dev/null +++ b/paddlespeech/cli/st/pretrained_models.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + "fat_st_ted-en-zh": { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", + "md5": + "d62063f35a16d91210a71081bd2dd557", + "cfg_path": + "model.yaml", + "ckpt_path": + "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", + } +} + +model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"} + +kaldi_bins = { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", + "md5": + "c0682303b3f3393dbf6ed4c4e35a53eb", +} diff --git a/paddlespeech/cli/stats/infer.py b/paddlespeech/cli/stats/infer.py index 4ef50449c37e08c1a3c5f9b8894a5b4141e1c33f..7cf4f2368cbced90bac54cb61bdc1bd8fc3d07f8 100644 --- a/paddlespeech/cli/stats/infer.py +++ b/paddlespeech/cli/stats/infer.py @@ -16,7 +16,6 @@ from typing import List from prettytable import PrettyTable -from ..log import logger from ..utils import cli_register from ..utils import stats_wrapper @@ -27,7 +26,8 @@ model_name_format = { 'cls': 'Model-Sample Rate', 'st': 'Model-Source language-Target language', 'text': 'Model-Task-Language', - 'tts': 'Model-Language' + 'tts': 'Model-Language', + 'vector': 'Model-Sample Rate' } @@ -36,18 +36,18 @@ model_name_format = { description='Get speech tasks support models list.') class StatsExecutor(): def __init__(self): - super(StatsExecutor, self).__init__() + super().__init__() self.parser = argparse.ArgumentParser( prog='paddlespeech.stats', add_help=True) + self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector'] self.parser.add_argument( '--task', type=str, default='asr', - choices=['asr', 'cls', 'st', 'text', 'tts'], + choices=self.task_choices, help='Choose speech task.', required=True) - self.task_choices = ['asr', 'cls', 'st', 'text', 'tts'] def show_support_models(self, pretrained_models: dict): fields = model_name_format[self.task].split("-") @@ -61,73 +61,15 @@ class StatsExecutor(): Command line entry. """ parser_args = self.parser.parse_args(argv) - self.task = parser_args.task - if self.task not in self.task_choices: - logger.error( - "Please input correct speech task, choices = ['asr', 'cls', 'st', 'text', 'tts']" - ) + has_exceptions = False + try: + self(parser_args.task) + except Exception as e: + has_exceptions = True + if has_exceptions: return False - - elif self.task == 'asr': - try: - from ..asr.infer import pretrained_models - logger.info( - "Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error("Failed to get the list of ASR pretrained models.") - return False - - elif self.task == 'cls': - try: - from ..cls.infer import pretrained_models - logger.info( - "Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error("Failed to get the list of CLS pretrained models.") - return False - - elif self.task == 'st': - try: - from ..st.infer import pretrained_models - logger.info( - "Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error("Failed to get the list of ST pretrained models.") - return False - - elif self.task == 'text': - try: - from ..text.infer import pretrained_models - logger.info( - "Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error( - "Failed to get the list of TEXT pretrained models.") - return False - - elif self.task == 'tts': - try: - from ..tts.infer import pretrained_models - logger.info( - "Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error("Failed to get the list of TTS pretrained models.") - return False + else: + return True @stats_wrapper def __call__( @@ -138,13 +80,12 @@ class StatsExecutor(): """ self.task = task if self.task not in self.task_choices: - print( - "Please input correct speech task, choices = ['asr', 'cls', 'st', 'text', 'tts']" - ) + print("Please input correct speech task, choices = " + str( + self.task_choices)) elif self.task == 'asr': try: - from ..asr.infer import pretrained_models + from ..asr.pretrained_models import pretrained_models print( "Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API" ) @@ -154,7 +95,7 @@ class StatsExecutor(): elif self.task == 'cls': try: - from ..cls.infer import pretrained_models + from ..cls.pretrained_models import pretrained_models print( "Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API" ) @@ -164,7 +105,7 @@ class StatsExecutor(): elif self.task == 'st': try: - from ..st.infer import pretrained_models + from ..st.pretrained_models import pretrained_models print( "Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API" ) @@ -174,7 +115,7 @@ class StatsExecutor(): elif self.task == 'text': try: - from ..text.infer import pretrained_models + from ..text.pretrained_models import pretrained_models print( "Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API" ) @@ -184,10 +125,22 @@ class StatsExecutor(): elif self.task == 'tts': try: - from ..tts.infer import pretrained_models + from ..tts.pretrained_models import pretrained_models print( "Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API" ) self.show_support_models(pretrained_models) except BaseException: print("Failed to get the list of TTS pretrained models.") + + elif self.task == 'vector': + try: + from ..vector.pretrained_models import pretrained_models + print( + "Here is the list of Speaker Recognition pretrained models released by PaddleSpeech that can be used by command line and python API" + ) + self.show_support_models(pretrained_models) + except BaseException: + print( + "Failed to get the list of Speaker Recognition pretrained models." + ) diff --git a/paddlespeech/cli/text/infer.py b/paddlespeech/cli/text/infer.py index dcf306c69f850bbcf13c08b81c1bb906141c71ea..69e62e4b448ca24054bf8e07b22b704f2e78eb32 100644 --- a/paddlespeech/cli/text/infer.py +++ b/paddlespeech/cli/text/infer.py @@ -25,58 +25,21 @@ from ...s2t.utils.dynamic_import import dynamic_import from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress -from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models +from .pretrained_models import tokenizer_alias __all__ = ['TextExecutor'] -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". - # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "ernie_linear_p7_wudao-punc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz', - 'md5': - '12283e2ddde1797c5d1e57036b512746', - 'cfg_path': - 'ckpt/model_config.json', - 'ckpt_path': - 'ckpt/model_state.pdparams', - 'vocab_file': - 'punc_vocab.txt', - }, - "ernie_linear_p3_wudao-punc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz', - 'md5': - '448eb2fdf85b6a997e7e652e80c51dd2', - 'cfg_path': - 'ckpt/model_config.json', - 'ckpt_path': - 'ckpt/model_state.pdparams', - 'vocab_file': - 'punc_vocab.txt', - }, -} - -model_alias = { - "ernie_linear_p7": "paddlespeech.text.models:ErnieLinear", - "ernie_linear_p3": "paddlespeech.text.models:ErnieLinear", -} - -tokenizer_alias = { - "ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer", - "ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer", -} - @cli_register(name='paddlespeech.text', description='Text infer command.') class TextExecutor(BaseExecutor): def __init__(self): - super(TextExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models + self.tokenizer_alias = tokenizer_alias self.parser = argparse.ArgumentParser( prog='paddlespeech.text', add_help=True) @@ -92,7 +55,9 @@ class TextExecutor(BaseExecutor): '--model', type=str, default='ernie_linear_p7_wudao', - choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()], + choices=[ + tag[:tag.index('-')] for tag in self.pretrained_models.keys() + ], help='Choose model type of text task.') self.parser.add_argument( '--lang', @@ -131,23 +96,6 @@ class TextExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - - return decompressed_path - def _init_from_path(self, task: str='punc', model_type: str='ernie_linear_p7_wudao', @@ -167,12 +115,12 @@ class TextExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None or vocab_file is None: tag = '-'.join([model_type, task, lang]) self.res_path = self._get_pretrained_path(tag) - self.cfg_path = os.path.join(self.res_path, - pretrained_models[tag]['cfg_path']) - self.ckpt_path = os.path.join(self.res_path, - pretrained_models[tag]['ckpt_path']) - self.vocab_file = os.path.join(self.res_path, - pretrained_models[tag]['vocab_file']) + self.cfg_path = os.path.join( + self.res_path, self.pretrained_models[tag]['cfg_path']) + self.ckpt_path = os.path.join( + self.res_path, self.pretrained_models[tag]['ckpt_path']) + self.vocab_file = os.path.join( + self.res_path, self.pretrained_models[tag]['vocab_file']) else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path) @@ -187,8 +135,8 @@ class TextExecutor(BaseExecutor): self._punc_list.append(line.strip()) # model - model_class = dynamic_import(model_name, model_alias) - tokenizer_class = dynamic_import(model_name, tokenizer_alias) + model_class = dynamic_import(model_name, self.model_alias) + tokenizer_class = dynamic_import(model_name, self.tokenizer_alias) self.model = model_class( cfg_path=self.cfg_path, ckpt_path=self.ckpt_path) self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0') diff --git a/paddlespeech/cli/text/pretrained_models.py b/paddlespeech/cli/text/pretrained_models.py new file mode 100644 index 0000000000000000000000000000000000000000..817d3caa3cdc634a202703d4885796b21eee4f56 --- /dev/null +++ b/paddlespeech/cli/text/pretrained_models.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". + # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". + # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: + # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + "ernie_linear_p7_wudao-punc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz', + 'md5': + '12283e2ddde1797c5d1e57036b512746', + 'cfg_path': + 'ckpt/model_config.json', + 'ckpt_path': + 'ckpt/model_state.pdparams', + 'vocab_file': + 'punc_vocab.txt', + }, + "ernie_linear_p3_wudao-punc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz', + 'md5': + '448eb2fdf85b6a997e7e652e80c51dd2', + 'cfg_path': + 'ckpt/model_config.json', + 'ckpt_path': + 'ckpt/model_state.pdparams', + 'vocab_file': + 'punc_vocab.txt', + }, +} + +model_alias = { + "ernie_linear_p7": "paddlespeech.text.models:ErnieLinear", + "ernie_linear_p3": "paddlespeech.text.models:ErnieLinear", +} + +tokenizer_alias = { + "ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer", + "ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer", +} diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 1c3fb29f413b8f8a5fff24ac86f013f59912e802..1c7199306c7bf5aae39781b510a6145bd42b9894 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -29,9 +29,9 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress -from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend @@ -39,299 +39,14 @@ from paddlespeech.t2s.modules.normalizer import ZScore __all__ = ['TTSExecutor'] -pretrained_models = { - # speedyspeech - "speedyspeech_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip', - 'md5': - '6f6fa967b408454b6662c8c00c0027cb', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_30600.pdz', - 'speech_stats': - 'feats_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'tones_dict': - 'tone_id_map.txt', - }, - - # fastspeech2 - "fastspeech2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', - 'md5': - '637d28a5e53aa60275612ba4393d5f22', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_76000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip', - 'md5': - 'ffed800c93deaf16ca9b3af89bfcd747', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_100000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip', - 'md5': - 'f4dd4a5f49a4552b77981f544ab3392e', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_96400.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'speaker_dict': - 'speaker_id_map.txt', - }, - "fastspeech2_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip', - 'md5': - '743e5024ca1e17a88c5c271db9779ba4', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_66200.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'speaker_dict': - 'speaker_id_map.txt', - }, - # tacotron2 - "tacotron2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip', - 'md5': - '0df4b6f0bcbe0d73c5ed6df8867ab91a', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_30600.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "tacotron2_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip', - 'md5': - '6a5eddd81ae0e81d16959b97481135f3', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_60300.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - - # pwgan - "pwgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip', - 'md5': - '2e481633325b5bdf0a3823c714d2c117', - 'config': - 'pwg_default.yaml', - 'ckpt': - 'pwg_snapshot_iter_400000.pdz', - 'speech_stats': - 'pwg_stats.npy', - }, - "pwgan_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip', - 'md5': - '53610ba9708fd3008ccaf8e99dacbaf0', - 'config': - 'pwg_default.yaml', - 'ckpt': - 'pwg_snapshot_iter_400000.pdz', - 'speech_stats': - 'pwg_stats.npy', - }, - "pwgan_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip', - 'md5': - 'd7598fa41ad362d62f85ffc0f07e3d84', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "pwgan_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip', - 'md5': - 'b3da1defcde3e578be71eb284cb89f2c', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # mb_melgan - "mb_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'ee5f0604e20091f0d495b6ec4618b90d', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # style_melgan - "style_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - '5de2d5348f396de0c966926b8c462755', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # hifigan - "hifigan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'dd40a3d88dfcf64513fba2f0f961ada6', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip', - 'md5': - '70e9131695decbca06a65fe51ed38a72', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip', - 'md5': - '3bb49bc75032ed12f79c00c8cc79a09a', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip', - 'md5': - '7da8f88359bca2457e705d924cf27bd4', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - - # wavernn - "wavernn_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip', - 'md5': - 'ee37b752f09bcba8f2af3b777ca38e13', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_400000.pdz', - 'speech_stats': - 'feats_stats.npy', - } -} - -model_alias = { - # acoustic model - "speedyspeech": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", - "speedyspeech_inference": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", - "fastspeech2": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2", - "fastspeech2_inference": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", - "tacotron2": - "paddlespeech.t2s.models.tacotron2:Tacotron2", - "tacotron2_inference": - "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", - # voc - "pwgan": - "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", - "pwgan_inference": - "paddlespeech.t2s.models.parallel_wavegan:PWGInference", - "mb_melgan": - "paddlespeech.t2s.models.melgan:MelGANGenerator", - "mb_melgan_inference": - "paddlespeech.t2s.models.melgan:MelGANInference", - "style_melgan": - "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", - "style_melgan_inference": - "paddlespeech.t2s.models.melgan:StyleMelGANInference", - "hifigan": - "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", - "hifigan_inference": - "paddlespeech.t2s.models.hifigan:HiFiGANInference", - "wavernn": - "paddlespeech.t2s.models.wavernn:WaveRNN", - "wavernn_inference": - "paddlespeech.t2s.models.wavernn:WaveRNNInference", -} - @cli_register( name='paddlespeech.tts', description='Text to Speech infer command.') class TTSExecutor(BaseExecutor): def __init__(self): super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models self.parser = argparse.ArgumentParser( prog='paddlespeech.tts', add_help=True) @@ -449,22 +164,6 @@ class TTSExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - return decompressed_path - def _init_from_path( self, am: str='fastspeech2_csmsc', @@ -490,16 +189,15 @@ class TTSExecutor(BaseExecutor): if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: am_res_path = self._get_pretrained_path(am_tag) self.am_res_path = am_res_path - self.am_config = os.path.join(am_res_path, - pretrained_models[am_tag]['config']) + self.am_config = os.path.join( + am_res_path, self.pretrained_models[am_tag]['config']) self.am_ckpt = os.path.join(am_res_path, - pretrained_models[am_tag]['ckpt']) + self.pretrained_models[am_tag]['ckpt']) self.am_stat = os.path.join( - am_res_path, pretrained_models[am_tag]['speech_stats']) + am_res_path, self.pretrained_models[am_tag]['speech_stats']) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, pretrained_models[am_tag]['phones_dict']) - print("self.phones_dict:", self.phones_dict) + am_res_path, self.pretrained_models[am_tag]['phones_dict']) logger.info(am_res_path) logger.info(self.am_config) logger.info(self.am_ckpt) @@ -509,21 +207,20 @@ class TTSExecutor(BaseExecutor): self.am_stat = os.path.abspath(am_stat) self.phones_dict = os.path.abspath(phones_dict) self.am_res_path = os.path.dirname(os.path.abspath(self.am_config)) - print("self.phones_dict:", self.phones_dict) # for speedyspeech self.tones_dict = None - if 'tones_dict' in pretrained_models[am_tag]: + if 'tones_dict' in self.pretrained_models[am_tag]: self.tones_dict = os.path.join( - am_res_path, pretrained_models[am_tag]['tones_dict']) + am_res_path, self.pretrained_models[am_tag]['tones_dict']) if tones_dict: self.tones_dict = tones_dict # for multi speaker fastspeech2 self.speaker_dict = None - if 'speaker_dict' in pretrained_models[am_tag]: + if 'speaker_dict' in self.pretrained_models[am_tag]: self.speaker_dict = os.path.join( - am_res_path, pretrained_models[am_tag]['speaker_dict']) + am_res_path, self.pretrained_models[am_tag]['speaker_dict']) if speaker_dict: self.speaker_dict = speaker_dict @@ -532,12 +229,12 @@ class TTSExecutor(BaseExecutor): if voc_ckpt is None or voc_config is None or voc_stat is None: voc_res_path = self._get_pretrained_path(voc_tag) self.voc_res_path = voc_res_path - self.voc_config = os.path.join(voc_res_path, - pretrained_models[voc_tag]['config']) - self.voc_ckpt = os.path.join(voc_res_path, - pretrained_models[voc_tag]['ckpt']) + self.voc_config = os.path.join( + voc_res_path, self.pretrained_models[voc_tag]['config']) + self.voc_ckpt = os.path.join( + voc_res_path, self.pretrained_models[voc_tag]['ckpt']) self.voc_stat = os.path.join( - voc_res_path, pretrained_models[voc_tag]['speech_stats']) + voc_res_path, self.pretrained_models[voc_tag]['speech_stats']) logger.info(voc_res_path) logger.info(self.voc_config) logger.info(self.voc_ckpt) @@ -588,8 +285,9 @@ class TTSExecutor(BaseExecutor): # model: {model_name}_{dataset} am_name = am[:am.rindex('_')] - am_class = dynamic_import(am_name, model_alias) - am_inference_class = dynamic_import(am_name + '_inference', model_alias) + am_class = dynamic_import(am_name, self.model_alias) + am_inference_class = dynamic_import(am_name + '_inference', + self.model_alias) if am_name == 'fastspeech2': am = am_class( @@ -618,9 +316,9 @@ class TTSExecutor(BaseExecutor): # vocoder # model: {model_name}_{dataset} voc_name = voc[:voc.rindex('_')] - voc_class = dynamic_import(voc_name, model_alias) + voc_class = dynamic_import(voc_name, self.model_alias) voc_inference_class = dynamic_import(voc_name + '_inference', - model_alias) + self.model_alias) if voc_name != 'wavernn': voc = voc_class(**self.voc_config["generator_params"]) voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) @@ -735,7 +433,6 @@ class TTSExecutor(BaseExecutor): am_ckpt = args.am_ckpt am_stat = args.am_stat phones_dict = args.phones_dict - print("phones_dict:", phones_dict) tones_dict = args.tones_dict speaker_dict = args.speaker_dict voc = args.voc diff --git a/paddlespeech/cli/tts/pretrained_models.py b/paddlespeech/cli/tts/pretrained_models.py new file mode 100644 index 0000000000000000000000000000000000000000..65254a9353fc997038d84368d3918f055d2ccee0 --- /dev/null +++ b/paddlespeech/cli/tts/pretrained_models.py @@ -0,0 +1,300 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + # speedyspeech + "speedyspeech_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip', + 'md5': + '6f6fa967b408454b6662c8c00c0027cb', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_30600.pdz', + 'speech_stats': + 'feats_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'tones_dict': + 'tone_id_map.txt', + }, + + # fastspeech2 + "fastspeech2_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', + 'md5': + '637d28a5e53aa60275612ba4393d5f22', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_76000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "fastspeech2_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip', + 'md5': + 'ffed800c93deaf16ca9b3af89bfcd747', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_100000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "fastspeech2_aishell3-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip', + 'md5': + 'f4dd4a5f49a4552b77981f544ab3392e', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_96400.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + "fastspeech2_vctk-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip', + 'md5': + '743e5024ca1e17a88c5c271db9779ba4', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_66200.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + # tacotron2 + "tacotron2_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip', + 'md5': + '0df4b6f0bcbe0d73c5ed6df8867ab91a', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_30600.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "tacotron2_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip', + 'md5': + '6a5eddd81ae0e81d16959b97481135f3', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_60300.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + + # pwgan + "pwgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip', + 'md5': + '2e481633325b5bdf0a3823c714d2c117', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + "pwgan_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip', + 'md5': + '53610ba9708fd3008ccaf8e99dacbaf0', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + "pwgan_aishell3-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip', + 'md5': + 'd7598fa41ad362d62f85ffc0f07e3d84', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "pwgan_vctk-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip', + 'md5': + 'b3da1defcde3e578be71eb284cb89f2c', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + # mb_melgan + "mb_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'ee5f0604e20091f0d495b6ec4618b90d', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + # style_melgan + "style_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + '5de2d5348f396de0c966926b8c462755', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + # hifigan + "hifigan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'dd40a3d88dfcf64513fba2f0f961ada6', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "hifigan_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip', + 'md5': + '70e9131695decbca06a65fe51ed38a72', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "hifigan_aishell3-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip', + 'md5': + '3bb49bc75032ed12f79c00c8cc79a09a', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "hifigan_vctk-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip', + 'md5': + '7da8f88359bca2457e705d924cf27bd4', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + + # wavernn + "wavernn_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip', + 'md5': + 'ee37b752f09bcba8f2af3b777ca38e13', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_400000.pdz', + 'speech_stats': + 'feats_stats.npy', + } +} + +model_alias = { + # acoustic model + "speedyspeech": + "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", + "speedyspeech_inference": + "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", + "fastspeech2": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2", + "fastspeech2_inference": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", + "tacotron2": + "paddlespeech.t2s.models.tacotron2:Tacotron2", + "tacotron2_inference": + "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", + # voc + "pwgan": + "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", + "pwgan_inference": + "paddlespeech.t2s.models.parallel_wavegan:PWGInference", + "mb_melgan": + "paddlespeech.t2s.models.melgan:MelGANGenerator", + "mb_melgan_inference": + "paddlespeech.t2s.models.melgan:MelGANInference", + "style_melgan": + "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", + "style_melgan_inference": + "paddlespeech.t2s.models.melgan:StyleMelGANInference", + "hifigan": + "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", + "hifigan_inference": + "paddlespeech.t2s.models.hifigan:HiFiGANInference", + "wavernn": + "paddlespeech.t2s.models.wavernn:WaveRNN", + "wavernn_inference": + "paddlespeech.t2s.models.wavernn:WaveRNNInference", +} diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 68e832ac74d4dda805a4185ab09a72f2eb7d6413..1dff6edb42d7e0519375156403d40180647b9b13 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -27,45 +27,24 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress -from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddleaudio.backends import load as load_audio from paddleaudio.compliance.librosa import melspectrogram from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.modules.sid_model import SpeakerIdetification -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]". - # e.g. "ecapatdnn_voxceleb12-16k". - # Command line and python api use "{model_name}[-{dataset}]" as --model, usage: - # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav" - "ecapatdnn_voxceleb12-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz', - 'md5': - 'cc33023c54ab346cd318408f43fcaf95', - 'cfg_path': - 'conf/model.yaml', # the yaml config path - 'ckpt_path': - 'model/model', # the format is ${dir}/{model_name}, - # so the first 'model' is dir, the second 'model' is the name - # this means we have a model stored as model/model.pdparams - }, -} - -model_alias = { - "ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn", -} - @cli_register( name="paddlespeech.vector", description="Speech to vector embedding infer command.") class VectorExecutor(BaseExecutor): def __init__(self): - super(VectorExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models self.parser = argparse.ArgumentParser( prog="paddlespeech.vector", add_help=True) @@ -128,8 +107,8 @@ class VectorExecutor(BaseExecutor): Returns: bool: - False: some audio occurs error - True: all audio process success + False: some audio occurs error + True: all audio process success """ # stage 0: parse the args and get the required args parser_args = self.parser.parse_args(argv) @@ -289,32 +268,6 @@ class VectorExecutor(BaseExecutor): return res - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """get the neural network path from the pretrained model list - we stored all the pretained mode in the variable `pretrained_models` - - Args: - tag (str): model tag in the pretrained model list - - Returns: - os.PathLike: the downloaded pretrained model path in the disk - """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, \ - 'The model "{}" you want to use has not been supported,'\ - 'please choose other models.\n' \ - 'The support models includes\n\t\t{}'.format(tag, "\n\t\t".join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - - return decompressed_path - def _init_from_path(self, model_type: str='ecapatdnn_voxceleb12', sample_rate: int=16000, @@ -350,10 +303,11 @@ class VectorExecutor(BaseExecutor): res_path = self._get_pretrained_path(tag) self.res_path = res_path - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]['cfg_path']) + self.cfg_path = os.path.join( + res_path, self.pretrained_models[tag]['cfg_path']) self.ckpt_path = os.path.join( - res_path, pretrained_models[tag]['ckpt_path'] + '.pdparams') + res_path, + self.pretrained_models[tag]['ckpt_path'] + '.pdparams') else: # get the model from disk self.cfg_path = os.path.abspath(cfg_path) @@ -373,7 +327,7 @@ class VectorExecutor(BaseExecutor): logger.info("start to dynamic import the model class") model_name = model_type[:model_type.rindex('_')] logger.info(f"model name {model_name}") - model_class = dynamic_import(model_name, model_alias) + model_class = dynamic_import(model_name, self.model_alias) model_conf = self.config.model backbone = model_class(**model_conf) model = SpeakerIdetification( diff --git a/paddlespeech/cli/vector/pretrained_models.py b/paddlespeech/cli/vector/pretrained_models.py new file mode 100644 index 0000000000000000000000000000000000000000..686a22d8fb025007738a986a18a8eea236c1b3a9 --- /dev/null +++ b/paddlespeech/cli/vector/pretrained_models.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + # The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]". + # e.g. "ecapatdnn_voxceleb12-16k". + # Command line and python api use "{model_name}[-{dataset}]" as --model, usage: + # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav" + "ecapatdnn_voxceleb12-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz', + 'md5': + 'cc33023c54ab346cd318408f43fcaf95', + 'cfg_path': + 'conf/model.yaml', # the yaml config path + 'ckpt_path': + 'model/model', # the format is ${dir}/{model_name}, + # so the first 'model' is dir, the second 'model' is the name + # this means we have a model stored as model/model.pdparams + }, +} + +model_alias = { + "ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn", +} diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index c5b64ac726f5fd15e162fe61b73c247ee54eeb88..3e7c11f2209a76cbe68099ab60f57c773bfcf05c 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -14,92 +14,17 @@ import argparse from pathlib import Path -import numpy import soundfile as sf -from paddle import inference from timer import timer +from paddlespeech.t2s.exps.syn_utils import get_am_output from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_predictor from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_voc_output from paddlespeech.t2s.utils import str2bool -def get_predictor(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc - model_name = full_name[:full_name.rindex('_')] - config = inference.Config( - str(Path(args.inference_dir) / (full_name + ".pdmodel")), - str(Path(args.inference_dir) / (full_name + ".pdiparams"))) - if args.device == "gpu": - config.enable_use_gpu(100, 0) - elif args.device == "cpu": - config.disable_gpu() - config.enable_memory_optim() - predictor = inference.create_predictor(config) - return predictor - - -def get_am_output(args, am_predictor, frontend, merge_sentences, input): - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] - am_input_names = am_predictor.get_input_names() - get_tone_ids = False - get_spk_id = False - if am_name == 'speedyspeech': - get_tone_ids = True - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: - get_spk_id = True - spk_id = numpy.array([args.spk_id]) - if args.lang == 'zh': - input_ids = frontend.get_input_ids( - input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - elif args.lang == 'en': - input_ids = frontend.get_input_ids( - input, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - else: - print("lang should in {'zh', 'en'}!") - - if get_tone_ids: - tone_ids = input_ids["tone_ids"] - tones = tone_ids[0].numpy() - tones_handle = am_predictor.get_input_handle(am_input_names[1]) - tones_handle.reshape(tones.shape) - tones_handle.copy_from_cpu(tones) - if get_spk_id: - spk_id_handle = am_predictor.get_input_handle(am_input_names[1]) - spk_id_handle.reshape(spk_id.shape) - spk_id_handle.copy_from_cpu(spk_id) - phones = phone_ids[0].numpy() - phones_handle = am_predictor.get_input_handle(am_input_names[0]) - phones_handle.reshape(phones.shape) - phones_handle.copy_from_cpu(phones) - - am_predictor.run() - am_output_names = am_predictor.get_output_names() - am_output_handle = am_predictor.get_output_handle(am_output_names[0]) - am_output_data = am_output_handle.copy_to_cpu() - return am_output_data - - -def get_voc_output(args, voc_predictor, input): - voc_input_names = voc_predictor.get_input_names() - mel_handle = voc_predictor.get_input_handle(voc_input_names[0]) - mel_handle.reshape(input.shape) - mel_handle.copy_from_cpu(input) - - voc_predictor.run() - voc_output_names = voc_predictor.get_output_names() - voc_output_handle = voc_predictor.get_output_handle(voc_output_names[0]) - wav = voc_output_handle.copy_to_cpu() - return wav - - def parse_args(): parser = argparse.ArgumentParser( description="Paddle Infernce with acoustic model & vocoder.") @@ -204,7 +129,7 @@ def main(): merge_sentences=merge_sentences, input=sentence) wav = get_voc_output( - args, voc_predictor=voc_predictor, input=am_output_data) + voc_predictor=voc_predictor, input=am_output_data) speed = wav.size / t.elapse rtf = fs / speed print( @@ -224,7 +149,7 @@ def main(): merge_sentences=merge_sentences, input=sentence) wav = get_voc_output( - args, voc_predictor=voc_predictor, input=am_output_data) + voc_predictor=voc_predictor, input=am_output_data) N += wav.size T += t.elapse diff --git a/paddlespeech/t2s/exps/inference_streaming.py b/paddlespeech/t2s/exps/inference_streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..0e58056c27a75e71316c9949d2908232001fe55f --- /dev/null +++ b/paddlespeech/t2s/exps/inference_streaming.py @@ -0,0 +1,224 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import numpy as np +import soundfile as sf +from timer import timer + +from paddlespeech.t2s.exps.syn_utils import denorm +from paddlespeech.t2s.exps.syn_utils import get_am_sublayer_output +from paddlespeech.t2s.exps.syn_utils import get_chunks +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_predictor +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output +from paddlespeech.t2s.exps.syn_utils import get_streaming_am_predictor +from paddlespeech.t2s.exps.syn_utils import get_voc_output +from paddlespeech.t2s.utils import str2bool + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Paddle Infernce with acoustic model & vocoder.") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=['fastspeech2_csmsc'], + help='Choose acoustic model type of tts task.') + parser.add_argument( + "--am_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training acoustic model." + ) + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--tones_dict", type=str, default=None, help="tone vocabulary file.") + parser.add_argument( + "--speaker_dict", type=str, default=None, help="speaker id map file.") + parser.add_argument( + '--spk_id', + type=int, + default=0, + help='spk id for multi speaker acoustic model') + # voc + parser.add_argument( + '--voc', + type=str, + default='pwgan_csmsc', + choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc'], + help='Choose vocoder type of tts task.') + # other + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line") + parser.add_argument( + "--inference_dir", type=str, help="dir to save inference models") + parser.add_argument("--output_dir", type=str, help="output dir") + # inference + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu"], + help="Device selected for inference.", ) + # streaming related + parser.add_argument( + "--am_streaming", + type=str2bool, + default=False, + help="whether use streaming acoustic model") + parser.add_argument( + "--chunk_size", type=int, default=42, help="chunk size of am streaming") + parser.add_argument( + "--pad_size", type=int, default=12, help="pad size of am streaming") + + args, _ = parser.parse_known_args() + return args + + +# only inference for models trained with csmsc now +def main(): + args = parse_args() + # frontend + frontend = get_frontend(args) + + # am_predictor + am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor = get_streaming_am_predictor( + args) + am_mu, am_std = np.load(args.am_stat) + # model: {model_name}_{dataset} + am_dataset = args.am[args.am.rindex('_') + 1:] + + # voc_predictor + voc_predictor = get_predictor(args, filed='voc') + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + sentences = get_sentences(args) + + merge_sentences = True + + fs = 24000 if am_dataset != 'ljspeech' else 22050 + # warmup + for utt_id, sentence in sentences[:3]: + with timer() as t: + normalized_mel = get_streaming_am_output( + args, + am_encoder_infer_predictor=am_encoder_infer_predictor, + am_decoder_predictor=am_decoder_predictor, + am_postnet_predictor=am_postnet_predictor, + frontend=frontend, + merge_sentences=merge_sentences, + input=sentence) + mel = denorm(normalized_mel, am_mu, am_std) + wav = get_voc_output(voc_predictor=voc_predictor, input=mel) + speed = wav.size / t.elapse + rtf = fs / speed + print( + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + + print("warm up done!") + + N = 0 + T = 0 + chunk_size = args.chunk_size + pad_size = args.pad_size + get_tone_ids = False + for utt_id, sentence in sentences: + with timer() as t: + # frontend + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + else: + print("lang should be 'zh' here!") + phones = phone_ids[0].numpy() + # acoustic model + orig_hs = get_am_sublayer_output( + am_encoder_infer_predictor, input=phones) + + if args.am_streaming: + hss = get_chunks(orig_hs, chunk_size, pad_size) + chunk_num = len(hss) + mel_list = [] + for i, hs in enumerate(hss): + am_decoder_output = get_am_sublayer_output( + am_decoder_predictor, input=hs) + am_postnet_output = get_am_sublayer_output( + am_postnet_predictor, + input=np.transpose(am_decoder_output, (0, 2, 1))) + am_output_data = am_decoder_output + np.transpose( + am_postnet_output, (0, 2, 1)) + normalized_mel = am_output_data[0] + + sub_mel = denorm(normalized_mel, am_mu, am_std) + # clip output part of pad + if i == 0: + sub_mel = sub_mel[:-pad_size] + elif i == chunk_num - 1: + # 最后一块的右侧一定没有 pad 够 + sub_mel = sub_mel[pad_size:] + else: + # 倒数几块的右侧也可能没有 pad 够 + sub_mel = sub_mel[pad_size:(chunk_size + pad_size) - + sub_mel.shape[0]] + mel_list.append(sub_mel) + mel = np.concatenate(mel_list, axis=0) + + else: + am_decoder_output = get_am_sublayer_output( + am_decoder_predictor, input=orig_hs) + + am_postnet_output = get_am_sublayer_output( + am_postnet_predictor, + input=np.transpose(am_decoder_output, (0, 2, 1))) + am_output_data = am_decoder_output + np.transpose( + am_postnet_output, (0, 2, 1)) + normalized_mel = am_output_data[0] + mel = denorm(normalized_mel, am_mu, am_std) + # vocoder + wav = get_voc_output(voc_predictor=voc_predictor, input=mel) + + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = fs / speed + + sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000) + print( + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }") + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py index 2ca9b5be7cff83f8167b2b91c0985749610f0302..d1f03710b69e298d985a8f6c521b4b6d8416db9b 100644 --- a/paddlespeech/t2s/exps/ort_predict.py +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -16,39 +16,14 @@ from pathlib import Path import jsonlines import numpy as np -import onnxruntime as ort import soundfile as sf from timer import timer +from paddlespeech.t2s.exps.syn_utils import get_sess from paddlespeech.t2s.exps.syn_utils import get_test_dataset from paddlespeech.t2s.utils import str2bool -def get_sess(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc - model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - - if args.device == "gpu": - # fastspeech2/mb_melgan can't use trt now! - if args.use_trt: - providers = ['TensorrtExecutionProvider'] - else: - providers = ['CUDAExecutionProvider'] - elif args.device == "cpu": - providers = ['CPUExecutionProvider'] - sess_options.intra_op_num_threads = args.cpu_threads - sess = ort.InferenceSession( - model_dir, providers=providers, sess_options=sess_options) - return sess - - def ort_predict(args): # construct dataset for evaluation with jsonlines.open(args.test_metadata, 'r') as reader: @@ -131,7 +106,7 @@ def parse_args(): '--voc', type=str, default='hifigan_csmsc', - choices=['hifigan_csmsc', 'mb_melgan_csmsc'], + choices=['hifigan_csmsc', 'mb_melgan_csmsc', 'pwgan_csmsc'], help='Choose vocoder type of tts task.') # other parser.add_argument( diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index c62b7ecd87c3e967b96f647c0066f4d03c6231a9..366a2902702b3a97efdb95f8a71de5b977b80e55 100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -15,40 +15,15 @@ import argparse from pathlib import Path import numpy as np -import onnxruntime as ort import soundfile as sf from timer import timer from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_sess from paddlespeech.t2s.utils import str2bool -def get_sess(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc - model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - - if args.device == "gpu": - # fastspeech2/mb_melgan can't use trt now! - if args.use_trt: - providers = ['TensorrtExecutionProvider'] - else: - providers = ['CUDAExecutionProvider'] - elif args.device == "cpu": - providers = ['CPUExecutionProvider'] - sess_options.intra_op_num_threads = args.cpu_threads - sess = ort.InferenceSession( - model_dir, providers=providers, sess_options=sess_options) - return sess - - def ort_predict(args): # frontend @@ -156,7 +131,7 @@ def parse_args(): '--voc', type=str, default='hifigan_csmsc', - choices=['hifigan_csmsc', 'mb_melgan_csmsc'], + choices=['hifigan_csmsc', 'mb_melgan_csmsc', 'pwgan_csmsc'], help='Choose vocoder type of tts task.') # other parser.add_argument( diff --git a/paddlespeech/t2s/exps/ort_predict_streaming.py b/paddlespeech/t2s/exps/ort_predict_streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..1b486d19df94e661170f24e326e911b28c1e10e6 --- /dev/null +++ b/paddlespeech/t2s/exps/ort_predict_streaming.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import numpy as np +import soundfile as sf +from timer import timer + +from paddlespeech.t2s.exps.syn_utils import denorm +from paddlespeech.t2s.exps.syn_utils import get_chunks +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_sess +from paddlespeech.t2s.exps.syn_utils import get_streaming_am_sess +from paddlespeech.t2s.utils import str2bool + + +def ort_predict(args): + + # frontend + frontend = get_frontend(args) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + sentences = get_sentences(args) + + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + fs = 24000 if am_dataset != 'ljspeech' else 22050 + + # am + am_encoder_infer_sess, am_decoder_sess, am_postnet_sess = get_streaming_am_sess( + args) + am_mu, am_std = np.load(args.am_stat) + + # vocoder + voc_sess = get_sess(args, filed='voc') + + # frontend warmup + # Loading model cost 0.5+ seconds + if args.lang == 'zh': + frontend.get_input_ids("你好,欢迎使用飞桨框架进行深度学习研究!", merge_sentences=True) + else: + print("lang should in be 'zh' here!") + + # am warmup + for T in [27, 38, 54]: + phone_ids = np.random.randint(1, 266, size=(T, )) + am_encoder_infer_sess.run(None, input_feed={'text': phone_ids}) + + am_decoder_input = np.random.rand(1, T * 15, 384).astype('float32') + am_decoder_sess.run(None, input_feed={'xs': am_decoder_input}) + + am_postnet_input = np.random.rand(1, 80, T * 15).astype('float32') + am_postnet_sess.run(None, input_feed={'xs': am_postnet_input}) + + # voc warmup + for T in [227, 308, 544]: + data = np.random.rand(T, 80).astype("float32") + voc_sess.run(None, input_feed={"logmel": data}) + print("warm up done!") + + N = 0 + T = 0 + merge_sentences = True + get_tone_ids = False + chunk_size = args.chunk_size + pad_size = args.pad_size + + for utt_id, sentence in sentences: + with timer() as t: + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in be 'zh' here!") + # merge_sentences=True here, so we only use the first item of phone_ids + phone_ids = phone_ids[0].numpy() + orig_hs = am_encoder_infer_sess.run( + None, input_feed={'text': phone_ids}) + if args.am_streaming: + hss = get_chunks(orig_hs[0], chunk_size, pad_size) + chunk_num = len(hss) + mel_list = [] + for i, hs in enumerate(hss): + am_decoder_output = am_decoder_sess.run( + None, input_feed={'xs': hs}) + am_postnet_output = am_postnet_sess.run( + None, + input_feed={ + 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) + }) + am_output_data = am_decoder_output + np.transpose( + am_postnet_output[0], (0, 2, 1)) + normalized_mel = am_output_data[0][0] + + sub_mel = denorm(normalized_mel, am_mu, am_std) + # clip output part of pad + if i == 0: + sub_mel = sub_mel[:-pad_size] + elif i == chunk_num - 1: + # 最后一块的右侧一定没有 pad 够 + sub_mel = sub_mel[pad_size:] + else: + # 倒数几块的右侧也可能没有 pad 够 + sub_mel = sub_mel[pad_size:(chunk_size + pad_size) - + sub_mel.shape[0]] + mel_list.append(sub_mel) + mel = np.concatenate(mel_list, axis=0) + else: + am_decoder_output = am_decoder_sess.run( + None, input_feed={'xs': orig_hs[0]}) + am_postnet_output = am_postnet_sess.run( + None, + input_feed={ + 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) + }) + am_output_data = am_decoder_output + np.transpose( + am_postnet_output[0], (0, 2, 1)) + normalized_mel = am_output_data[0] + mel = denorm(normalized_mel, am_mu, am_std) + mel = mel[0] + # vocoder + + wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) + + N += len(wav[0]) + T += t.elapse + speed = len(wav[0]) / t.elapse + rtf = fs / speed + sf.write( + str(output_dir / (utt_id + ".wav")), + np.array(wav)[0], + samplerate=fs) + print( + f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Infernce with onnxruntime.") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=['fastspeech2_csmsc'], + help='Choose acoustic model type of tts task.') + parser.add_argument( + "--am_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training acoustic model." + ) + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--tones_dict", type=str, default=None, help="tone vocabulary file.") + + # voc + parser.add_argument( + '--voc', + type=str, + default='hifigan_csmsc', + choices=['hifigan_csmsc', 'mb_melgan_csmsc', 'pwgan_csmsc'], + help='Choose vocoder type of tts task.') + # other + parser.add_argument( + "--inference_dir", type=str, help="dir to save inference models") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line") + parser.add_argument("--output_dir", type=str, help="output dir") + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + + # inference + parser.add_argument( + "--use_trt", + type=str2bool, + default=False, + help="Whether to use inference engin TensorRT.", ) + + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu"], + help="Device selected for inference.", ) + parser.add_argument('--cpu_threads', type=int, default=1) + + # streaming related + parser.add_argument( + "--am_streaming", + type=str2bool, + default=False, + help="whether use streaming acoustic model") + parser.add_argument( + "--chunk_size", type=int, default=42, help="chunk size of am streaming") + parser.add_argument( + "--pad_size", type=int, default=12, help="pad size of am streaming") + + args, _ = parser.parse_known_args() + return args + + +def main(): + args = parse_args() + + ort_predict(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index c52cb372710bd4308d9a54507284e7d10cafa6a1..21aa5bf8cbd5e089460bfa833f84e0e9d834a871 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math import os +from pathlib import Path import numpy as np +import onnxruntime as ort import paddle +from paddle import inference from paddle import jit from paddle.static import InputSpec @@ -62,6 +66,21 @@ model_alias = { } +def denorm(data, mean, std): + return data * std + mean + + +def get_chunks(data, chunk_size, pad_size): + data_len = data.shape[1] + chunks = [] + n = math.ceil(data_len / chunk_size) + for i in range(n): + start = max(0, i * chunk_size - pad_size) + end = min((i + 1) * chunk_size + pad_size, data_len) + chunks.append(data[:, start:end, :]) + return chunks + + # input def get_sentences(args): # construct dataset for evaluation @@ -241,3 +260,221 @@ def voc_to_static(args, voc_inference): paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc)) voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc)) return voc_inference + + +# inference +def get_predictor(args, filed='am'): + full_name = '' + if filed == 'am': + full_name = args.am + elif filed == 'voc': + full_name = args.voc + config = inference.Config( + str(Path(args.inference_dir) / (full_name + ".pdmodel")), + str(Path(args.inference_dir) / (full_name + ".pdiparams"))) + if args.device == "gpu": + config.enable_use_gpu(100, 0) + elif args.device == "cpu": + config.disable_gpu() + config.enable_memory_optim() + predictor = inference.create_predictor(config) + return predictor + + +def get_am_output(args, am_predictor, frontend, merge_sentences, input): + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + am_input_names = am_predictor.get_input_names() + get_tone_ids = False + get_spk_id = False + if am_name == 'speedyspeech': + get_tone_ids = True + if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + get_spk_id = True + spk_id = np.array([args.spk_id]) + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + elif args.lang == 'en': + input_ids = frontend.get_input_ids( + input, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + tones = tone_ids[0].numpy() + tones_handle = am_predictor.get_input_handle(am_input_names[1]) + tones_handle.reshape(tones.shape) + tones_handle.copy_from_cpu(tones) + if get_spk_id: + spk_id_handle = am_predictor.get_input_handle(am_input_names[1]) + spk_id_handle.reshape(spk_id.shape) + spk_id_handle.copy_from_cpu(spk_id) + phones = phone_ids[0].numpy() + phones_handle = am_predictor.get_input_handle(am_input_names[0]) + phones_handle.reshape(phones.shape) + phones_handle.copy_from_cpu(phones) + + am_predictor.run() + am_output_names = am_predictor.get_output_names() + am_output_handle = am_predictor.get_output_handle(am_output_names[0]) + am_output_data = am_output_handle.copy_to_cpu() + return am_output_data + + +def get_voc_output(voc_predictor, input): + voc_input_names = voc_predictor.get_input_names() + mel_handle = voc_predictor.get_input_handle(voc_input_names[0]) + mel_handle.reshape(input.shape) + mel_handle.copy_from_cpu(input) + + voc_predictor.run() + voc_output_names = voc_predictor.get_output_names() + voc_output_handle = voc_predictor.get_output_handle(voc_output_names[0]) + wav = voc_output_handle.copy_to_cpu() + return wav + + +# streaming am +def get_streaming_am_predictor(args): + full_name = args.am + am_encoder_infer_config = inference.Config( + str( + Path(args.inference_dir) / + (full_name + "_am_encoder_infer" + ".pdmodel")), + str( + Path(args.inference_dir) / + (full_name + "_am_encoder_infer" + ".pdiparams"))) + am_decoder_config = inference.Config( + str( + Path(args.inference_dir) / + (full_name + "_am_decoder" + ".pdmodel")), + str( + Path(args.inference_dir) / + (full_name + "_am_decoder" + ".pdiparams"))) + am_postnet_config = inference.Config( + str( + Path(args.inference_dir) / + (full_name + "_am_postnet" + ".pdmodel")), + str( + Path(args.inference_dir) / + (full_name + "_am_postnet" + ".pdiparams"))) + if args.device == "gpu": + am_encoder_infer_config.enable_use_gpu(100, 0) + am_decoder_config.enable_use_gpu(100, 0) + am_postnet_config.enable_use_gpu(100, 0) + elif args.device == "cpu": + am_encoder_infer_config.disable_gpu() + am_decoder_config.disable_gpu() + am_postnet_config.disable_gpu() + + am_encoder_infer_config.enable_memory_optim() + am_decoder_config.enable_memory_optim() + am_postnet_config.enable_memory_optim() + + am_encoder_infer_predictor = inference.create_predictor( + am_encoder_infer_config) + am_decoder_predictor = inference.create_predictor(am_decoder_config) + am_postnet_predictor = inference.create_predictor(am_postnet_config) + return am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor + + +def get_am_sublayer_output(am_sublayer_predictor, input): + am_sublayer_input_names = am_sublayer_predictor.get_input_names() + input_handle = am_sublayer_predictor.get_input_handle( + am_sublayer_input_names[0]) + input_handle.reshape(input.shape) + input_handle.copy_from_cpu(input) + + am_sublayer_predictor.run() + am_sublayer_names = am_sublayer_predictor.get_output_names() + am_sublayer_handle = am_sublayer_predictor.get_output_handle( + am_sublayer_names[0]) + am_sublayer_output = am_sublayer_handle.copy_to_cpu() + return am_sublayer_output + + +def get_streaming_am_output(args, am_encoder_infer_predictor, + am_decoder_predictor, am_postnet_predictor, + frontend, merge_sentences, input): + get_tone_ids = False + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + else: + print("lang should be 'zh' here!") + + phones = phone_ids[0].numpy() + am_encoder_infer_output = get_am_sublayer_output( + am_encoder_infer_predictor, input=phones) + + am_decoder_output = get_am_sublayer_output( + am_decoder_predictor, input=am_encoder_infer_output) + + am_postnet_output = get_am_sublayer_output( + am_postnet_predictor, input=np.transpose(am_decoder_output, (0, 2, 1))) + am_output_data = am_decoder_output + np.transpose(am_postnet_output, + (0, 2, 1)) + normalized_mel = am_output_data[0] + return normalized_mel + + +def get_sess(args, filed='am'): + full_name = '' + if filed == 'am': + full_name = args.am + elif filed == 'voc': + full_name = args.voc + model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + + if args.device == "gpu": + # fastspeech2/mb_melgan can't use trt now! + if args.use_trt: + providers = ['TensorrtExecutionProvider'] + else: + providers = ['CUDAExecutionProvider'] + elif args.device == "cpu": + providers = ['CPUExecutionProvider'] + sess_options.intra_op_num_threads = args.cpu_threads + sess = ort.InferenceSession( + model_dir, providers=providers, sess_options=sess_options) + return sess + + +# streaming am +def get_streaming_am_sess(args): + full_name = args.am + am_encoder_infer_model_dir = str( + Path(args.inference_dir) / (full_name + "_am_encoder_infer" + ".onnx")) + am_decoder_model_dir = str( + Path(args.inference_dir) / (full_name + "_am_decoder" + ".onnx")) + am_postnet_model_dir = str( + Path(args.inference_dir) / (full_name + "_am_postnet" + ".onnx")) + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + if args.device == "gpu": + # fastspeech2/mb_melgan can't use trt now! + if args.use_trt: + providers = ['TensorrtExecutionProvider'] + else: + providers = ['CUDAExecutionProvider'] + elif args.device == "cpu": + providers = ['CPUExecutionProvider'] + sess_options.intra_op_num_threads = args.cpu_threads + am_encoder_infer_sess = ort.InferenceSession( + am_encoder_infer_model_dir, + providers=providers, + sess_options=sess_options) + am_decoder_sess = ort.InferenceSession( + am_decoder_model_dir, providers=providers, sess_options=sess_options) + am_postnet_sess = ort.InferenceSession( + am_postnet_model_dir, providers=providers, sess_options=sess_options) + return am_encoder_infer_sess, am_decoder_sess, am_postnet_sess diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py index 7b9906c1076edda751c4a10772df0f82e8f1f39b..4f7a84e91f0e6d668c0b4fc6170296e146cb618f 100644 --- a/paddlespeech/t2s/exps/synthesize_streaming.py +++ b/paddlespeech/t2s/exps/synthesize_streaming.py @@ -12,39 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse -import math +import os from pathlib import Path import numpy as np import paddle import soundfile as sf import yaml +from paddle import jit +from paddle.static import InputSpec from timer import timer from yacs.config import CfgNode from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.t2s.exps.syn_utils import denorm +from paddlespeech.t2s.exps.syn_utils import get_chunks from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.exps.syn_utils import model_alias +from paddlespeech.t2s.exps.syn_utils import voc_to_static from paddlespeech.t2s.utils import str2bool -def denorm(data, mean, std): - return data * std + mean - - -def get_chunks(data, chunk_size, pad_size): - data_len = data.shape[1] - chunks = [] - n = math.ceil(data_len / chunk_size) - for i in range(n): - start = max(0, i * chunk_size - pad_size) - end = min((i + 1) * chunk_size + pad_size, data_len) - chunks.append(data[:, start:end, :]) - return chunks - - def evaluate(args): # Init body. @@ -84,9 +74,49 @@ def evaluate(args): am_mu = paddle.to_tensor(am_mu) am_std = paddle.to_tensor(am_std) + # am sub layers + am_encoder_infer = am.encoder_infer + am_decoder = am.decoder + am_postnet = am.postnet + # vocoder voc_inference = get_voc_inference(args, voc_config) + # whether dygraph to static + if args.inference_dir: + # fastspeech2 cnndecoder to static + # am.encoder_infer + am_encoder_infer = jit.to_static( + am_encoder_infer, input_spec=[InputSpec([-1], dtype=paddle.int64)]) + paddle.jit.save(am_encoder_infer, + os.path.join(args.inference_dir, + args.am + "_am_encoder_infer")) + am_encoder_infer = paddle.jit.load( + os.path.join(args.inference_dir, args.am + "_am_encoder_infer")) + + # am.decoder + am_decoder = jit.to_static( + am_decoder, + input_spec=[InputSpec([1, -1, 384], dtype=paddle.float32)]) + paddle.jit.save(am_decoder, + os.path.join(args.inference_dir, + args.am + "_am_decoder")) + am_decoder = paddle.jit.load( + os.path.join(args.inference_dir, args.am + "_am_decoder")) + + # am.postnet + am_postnet = jit.to_static( + am_postnet, + input_spec=[InputSpec([1, 80, -1], dtype=paddle.float32)]) + paddle.jit.save(am_postnet, + os.path.join(args.inference_dir, + args.am + "_am_postnet")) + am_postnet = paddle.jit.load( + os.path.join(args.inference_dir, args.am + "_am_postnet")) + + # vocoder + voc_inference = voc_to_static(args, voc_inference) + output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) merge_sentences = True @@ -107,20 +137,19 @@ def evaluate(args): phone_ids = input_ids["phone_ids"] else: - print("lang should in be 'zh' here!") + print("lang should be 'zh' here!") # merge_sentences=True here, so we only use the first item of phone_ids phone_ids = phone_ids[0] with paddle.no_grad(): # acoustic model - orig_hs, h_masks = am.encoder_infer(phone_ids) - + orig_hs = am_encoder_infer(phone_ids) if args.am_streaming: hss = get_chunks(orig_hs, chunk_size, pad_size) chunk_num = len(hss) mel_list = [] for i, hs in enumerate(hss): - before_outs, _ = am.decoder(hs) - after_outs = before_outs + am.postnet( + before_outs = am_decoder(hs) + after_outs = before_outs + am_postnet( before_outs.transpose((0, 2, 1))).transpose( (0, 2, 1)) normalized_mel = after_outs[0] @@ -139,8 +168,8 @@ def evaluate(args): mel = paddle.concat(mel_list, axis=0) else: - before_outs, _ = am.decoder(orig_hs) - after_outs = before_outs + am.postnet( + before_outs = am_decoder(orig_hs) + after_outs = before_outs + am_postnet( before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) normalized_mel = after_outs[0] mel = denorm(normalized_mel, am_mu, am_std) @@ -201,16 +230,9 @@ def parse_args(): default='pwgan_csmsc', choices=[ 'pwgan_csmsc', - 'pwgan_ljspeech', - 'pwgan_aishell3', - 'pwgan_vctk', 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc', - 'hifigan_ljspeech', - 'hifigan_aishell3', - 'hifigan_vctk', - 'wavernn_csmsc', ], help='Choose vocoder type of tts task.') parser.add_argument( @@ -233,13 +255,19 @@ def parse_args(): default='zh', help='Choose model language. zh or en') + parser.add_argument( + "--inference_dir", + type=str, + default=None, + help="dir to save inference models") + parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") parser.add_argument( "--text", type=str, help="text to synthesize, a 'utt_id sentence' pair per line.") - + # streaming related parser.add_argument( "--am_streaming", type=str2bool, diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 8e52f91625e026e96ed4a7bfdb18312846c968ca..48595bb25ca2241b74ebe22be6325708564b5699 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -590,15 +590,17 @@ class FastSpeech2(nn.Layer): h_masks = self._source_mask(olens_in) else: h_masks = None - if return_after_enc: return hs, h_masks - # (B, Lmax, adim) - zs, _ = self.decoder(hs, h_masks) - # (B, Lmax, odim) + if self.decoder_type == 'cnndecoder': + # remove output masks for dygraph to static graph + zs = self.decoder(hs, h_masks) before_outs = zs else: + # (B, Lmax, adim) + zs, _ = self.decoder(hs, h_masks) + # (B, Lmax, odim) before_outs = self.feat_out(zs).reshape( (paddle.shape(zs)[0], -1, self.odim)) @@ -633,7 +635,8 @@ class FastSpeech2(nn.Layer): tone_id = tone_id.unsqueeze(0) # (1, L, odim) - hs, h_masks = self._forward( + # use *_ to avoid bug in dygraph to static graph + hs, *_ = self._forward( xs, ilens, is_inference=True, @@ -642,7 +645,7 @@ class FastSpeech2(nn.Layer): spk_emb=spk_emb, spk_id=spk_id, tone_id=tone_id) - return hs, h_masks + return hs def inference( self, diff --git a/paddlespeech/t2s/modules/transformer/encoder.py b/paddlespeech/t2s/modules/transformer/encoder.py index d05516c2280dbe7f3665c67b8ed4c69883c94b22..11986360a30ed7dfd7dfcaed4a50bdf259e26983 100644 --- a/paddlespeech/t2s/modules/transformer/encoder.py +++ b/paddlespeech/t2s/modules/transformer/encoder.py @@ -602,7 +602,7 @@ class CNNDecoder(nn.Layer): if masks is not None: outputs = outputs * masks outputs = outputs.transpose([0, 2, 1]) - return outputs, masks + return outputs class CNNPostnet(nn.Layer): diff --git a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py index f8b0a027c6499ded95c901ebbc1b652652176f0b..a8043c227bbc070fffa254c79f6f468d9b4174f0 100644 --- a/paddlespeech/vector/cluster/diarization.py +++ b/paddlespeech/vector/cluster/diarization.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 SpeechBrain Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle and SpeechBrain Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,12 +18,14 @@ This script has an optional dependency on open source sklearn library. A few sklearn functions are modified in this script as per requirement. """ import argparse +import copy import warnings from distutils.util import strtobool import numpy as np import scipy import sklearn +from scipy import linalg from scipy import sparse from scipy.sparse.csgraph import connected_components from scipy.sparse.csgraph import laplacian as csgraph_laplacian @@ -346,6 +348,8 @@ class EmbeddingMeta: --------- segset : list List of session IDs as an array of strings. + modelset : list + List of model IDs as an array of strings. stats : tensor An ndarray of float64. Each line contains embedding from the corresponding session. @@ -354,15 +358,20 @@ class EmbeddingMeta: def __init__( self, segset=None, + modelset=None, stats=None, ): if segset is None: - self.segset = numpy.empty(0, dtype="|O") - self.stats = numpy.array([], dtype=np.float64) + self.segset = np.empty(0, dtype="|O") + self.modelset = np.empty(0, dtype="|O") + self.stats = np.array([], dtype=np.float64) else: self.segset = segset + self.modelset = modelset self.stats = stats + self.stat0 = np.array([[1.0]] * self.stats.shape[0]) + def norm_stats(self): """ Divide all first-order statistics by their Euclidean norm. @@ -371,6 +380,188 @@ class EmbeddingMeta: vect_norm = np.clip(np.linalg.norm(self.stats, axis=1), 1e-08, np.inf) self.stats = (self.stats.transpose() / vect_norm).transpose() + def get_mean_stats(self): + """ + Return the mean of first order statistics. + """ + mu = np.mean(self.stats, axis=0) + return mu + + def get_total_covariance_stats(self): + """ + Compute and return the total covariance matrix of the first-order statistics. + """ + C = self.stats - self.stats.mean(axis=0) + return np.dot(C.transpose(), C) / self.stats.shape[0] + + def get_model_stat0(self, mod_id): + """Return zero-order statistics of a given model + + Arguments + --------- + mod_id : str + ID of the model which stat0 will be returned. + """ + S = self.stat0[self.modelset == mod_id, :] + return S + + def get_model_stats(self, mod_id): + """Return first-order statistics of a given model. + + Arguments + --------- + mod_id : str + ID of the model which stat1 will be returned. + """ + return self.stats[self.modelset == mod_id, :] + + def sum_stat_per_model(self): + """ + Sum the zero- and first-order statistics per model and store them + in a new EmbeddingMeta. + Returns a EmbeddingMeta object with the statistics summed per model + and a numpy array with session_per_model. + """ + + sts_per_model = EmbeddingMeta() + sts_per_model.modelset = np.unique( + self.modelset) # nd: get uniq spkr ids + sts_per_model.segset = copy.deepcopy(sts_per_model.modelset) + sts_per_model.stat0 = np.zeros( + (sts_per_model.modelset.shape[0], self.stat0.shape[1]), + dtype=np.float64, ) + sts_per_model.stats = np.zeros( + (sts_per_model.modelset.shape[0], self.stats.shape[1]), + dtype=np.float64, ) + + session_per_model = np.zeros(np.unique(self.modelset).shape[0]) + + # For each model sum the stats + for idx, model in enumerate(sts_per_model.modelset): + sts_per_model.stat0[idx, :] = self.get_model_stat0(model).sum( + axis=0) + sts_per_model.stats[idx, :] = self.get_model_stats(model).sum( + axis=0) + session_per_model[idx] += self.get_model_stats(model).shape[0] + return sts_per_model, session_per_model + + def center_stats(self, mu): + """ + Center first order statistics. + + Arguments + --------- + mu : array + Array to center on. + """ + + dim = self.stats.shape[1] / self.stat0.shape[1] + index_map = np.repeat(np.arange(self.stat0.shape[1]), dim) + self.stats = self.stats - (self.stat0[:, index_map] * + mu.astype(np.float64)) + + def rotate_stats(self, R): + """ + Rotate first-order statistics by a right-product. + + Arguments + --------- + R : ndarray + Matrix to use for right product on the first order statistics. + """ + self.stats = np.dot(self.stats, R) + + def whiten_stats(self, mu, sigma, isSqrInvSigma=False): + """ + Whiten first-order statistics + If sigma.ndim == 1, case of a diagonal covariance. + If sigma.ndim == 2, case of a single Gaussian with full covariance. + If sigma.ndim == 3, case of a full covariance UBM. + + Arguments + --------- + mu : array + Mean vector to be subtracted from the statistics. + sigma : narray + Co-variance matrix or covariance super-vector. + isSqrInvSigma : bool + True if the input Sigma matrix is the inverse of the square root of a covariance matrix. + """ + + if sigma.ndim == 1: + self.center_stats(mu) + self.stats = self.stats / np.sqrt(sigma.astype(np.float64)) + + elif sigma.ndim == 2: + # Compute the inverse square root of the co-variance matrix Sigma + sqr_inv_sigma = sigma + + if not isSqrInvSigma: + # eigen_values, eigen_vectors = scipy.linalg.eigh(sigma) + eigen_values, eigen_vectors = linalg.eigh(sigma) + ind = eigen_values.real.argsort()[::-1] + eigen_values = eigen_values.real[ind] + eigen_vectors = eigen_vectors.real[:, ind] + + sqr_inv_eval_sigma = 1 / np.sqrt(eigen_values.real) + sqr_inv_sigma = np.dot(eigen_vectors, + np.diag(sqr_inv_eval_sigma)) + else: + pass + + # Whitening of the first-order statistics + self.center_stats(mu) # CENTERING + self.rotate_stats(sqr_inv_sigma) + + elif sigma.ndim == 3: + # we assume that sigma is a 3D ndarray of size D x n x n + # where D is the number of distributions and n is the dimension of a single distribution + n = self.stats.shape[1] // self.stat0.shape[1] + sess_nb = self.stat0.shape[0] + self.center_stats(mu) + self.stats = (np.einsum("ikj,ikl->ilj", + self.stats.T.reshape(-1, n, sess_nb), sigma) + .reshape(-1, sess_nb).T) + + else: + raise Exception("Wrong dimension of Sigma, must be 1 or 2") + + def align_models(self, model_list): + """ + Align models of the current EmbeddingMeta to match a list of models + provided as input parameter. The size of the StatServer might be + reduced to match the input list of models. + + Arguments + --------- + model_list : ndarray of strings + List of models to match. + """ + indx = np.array( + [np.argwhere(self.modelset == v)[0][0] for v in model_list]) + self.segset = self.segset[indx] + self.modelset = self.modelset[indx] + self.stat0 = self.stat0[indx, :] + self.stats = self.stats[indx, :] + + def align_segments(self, segment_list): + """ + Align segments of the current EmbeddingMeta to match a list of segment + provided as input parameter. The size of the StatServer might be + reduced to match the input list of segments. + + Arguments + --------- + segment_list: ndarray of strings + list of segments to match + """ + indx = np.array( + [np.argwhere(self.segset == v)[0][0] for v in segment_list]) + self.segset = self.segset[indx] + self.modelset = self.modelset[indx] + self.stat0 = self.stat0[indx, :] + self.stats = self.stats[indx, :] + class SpecClustUnorm: """ diff --git a/paddlespeech/vector/cluster/plda.py b/paddlespeech/vector/cluster/plda.py new file mode 100644 index 0000000000000000000000000000000000000000..81def435692fc8ad77cc87ac79a814235b74987e --- /dev/null +++ b/paddlespeech/vector/cluster/plda.py @@ -0,0 +1,575 @@ +# Copyright (c) 2022 PaddlePaddle and SpeechBrain Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A popular speaker recognition/diarization model (LDA and PLDA). + +Relevant Papers + - This implementation of PLDA is based on the following papers. + + - PLDA model Training + * Ye Jiang et. al, "PLDA Modeling in I-Vector and Supervector Space for Speaker Verification," in Interspeech, 2012. + * Patrick Kenny et. al, "PLDA for speaker verification with utterances of arbitrary duration," in ICASSP, 2013. + + - PLDA scoring (fast scoring) + * Daniel Garcia-Romero et. al, “Analysis of i-vector length normalization in speaker recognition systems,” in Interspeech, 2011. + * Weiwei-LIN et. al, "Fast Scoring for PLDA with Uncertainty Propagation," in Odyssey, 2016. + * Kong Aik Lee et. al, "Multi-session PLDA Scoring of I-vector for Partially Open-Set Speaker Detection," in Interspeech 2013. + +Credits + This code is adapted from: https://git-lium.univ-lemans.fr/Larcher/sidekit +""" +import copy +import pickle + +import numpy +from scipy import linalg + +from paddlespeech.vector.cluster.diarization import EmbeddingMeta + + +def ismember(list1, list2): + c = [item in list2 for item in list1] + return c + + +class Ndx: + """ + A class that encodes trial index information. It has a list of + model names and a list of test segment names and a matrix + indicating which combinations of model and test segment are + trials of interest. + + Arguments + --------- + modelset : list + List of unique models in a ndarray. + segset : list + List of unique test segments in a ndarray. + trialmask : 2D ndarray of bool. + Rows correspond to the models and columns to the test segments. True, if the trial is of interest. + """ + + def __init__(self, + ndx_file_name="", + models=numpy.array([]), + testsegs=numpy.array([])): + """ + Initialize a Ndx object by loading information from a file. + + Arguments + --------- + ndx_file_name : str + Name of the file to load. + """ + self.modelset = numpy.empty(0, dtype="|O") + self.segset = numpy.empty(0, dtype="|O") + self.trialmask = numpy.array([], dtype="bool") + + if ndx_file_name == "": + # This is needed to make sizes same + d = models.shape[0] - testsegs.shape[0] + if d != 0: + if d > 0: + last = str(testsegs[-1]) + pad = numpy.array([last] * d) + testsegs = numpy.hstack((testsegs, pad)) + # pad = testsegs[-d:] + # testsegs = numpy.concatenate((testsegs, pad), axis=1) + else: + d = abs(d) + last = str(models[-1]) + pad = numpy.array([last] * d) + models = numpy.hstack((models, pad)) + # pad = models[-d:] + # models = numpy.concatenate((models, pad), axis=1) + + modelset = numpy.unique(models) + segset = numpy.unique(testsegs) + + trialmask = numpy.zeros( + (modelset.shape[0], segset.shape[0]), dtype="bool") + for m in range(modelset.shape[0]): + segs = testsegs[numpy.array(ismember(models, modelset[m]))] + trialmask[m, ] = ismember(segset, segs) # noqa E231 + + self.modelset = modelset + self.segset = segset + self.trialmask = trialmask + assert self.validate(), "Wrong Ndx format" + + else: + ndx = Ndx.read(ndx_file_name) + self.modelset = ndx.modelset + self.segset = ndx.segset + self.trialmask = ndx.trialmask + + def save_ndx_object(self, output_file_name): + with open(output_file_name, "wb") as output: + pickle.dump(self, output, pickle.HIGHEST_PROTOCOL) + + def filter(self, modlist, seglist, keep): + """ + Removes some of the information in an Ndx. Useful for creating a + gender specific Ndx from a pooled gender Ndx. Depending on the + value of \'keep\', the two input lists indicate the strings to + retain or the strings to discard. + + Arguments + --------- + modlist : array + A cell array of strings which will be compared with the modelset of 'inndx'. + seglist : array + A cell array of strings which will be compared with the segset of 'inndx'. + keep : bool + Indicating whether modlist and seglist are the models to keep or discard. + """ + if keep: + keepmods = modlist + keepsegs = seglist + else: + keepmods = diff(self.modelset, modlist) + keepsegs = diff(self.segset, seglist) + + keepmodidx = numpy.array(ismember(self.modelset, keepmods)) + keepsegidx = numpy.array(ismember(self.segset, keepsegs)) + + outndx = Ndx() + outndx.modelset = self.modelset[keepmodidx] + outndx.segset = self.segset[keepsegidx] + tmp = self.trialmask[numpy.array(keepmodidx), :] + outndx.trialmask = tmp[:, numpy.array(keepsegidx)] + + assert outndx.validate, "Wrong Ndx format" + + if self.modelset.shape[0] > outndx.modelset.shape[0]: + print( + "Number of models reduced from %d to %d" % + self.modelset.shape[0], + outndx.modelset.shape[0], ) + if self.segset.shape[0] > outndx.segset.shape[0]: + print( + "Number of test segments reduced from %d to %d", + self.segset.shape[0], + outndx.segset.shape[0], ) + return outndx + + def validate(self): + """ + Checks that an object of type Ndx obeys certain rules that + must always be true. Returns a boolean value indicating whether the object is valid + """ + ok = isinstance(self.modelset, numpy.ndarray) + ok &= isinstance(self.segset, numpy.ndarray) + ok &= isinstance(self.trialmask, numpy.ndarray) + + ok &= self.modelset.ndim == 1 + ok &= self.segset.ndim == 1 + ok &= self.trialmask.ndim == 2 + + ok &= self.trialmask.shape == (self.modelset.shape[0], + self.segset.shape[0], ) + return ok + + +class Scores: + """ + A class for storing scores for trials. The modelset and segset + fields are lists of model and test segment names respectively. + The element i,j of scoremat and scoremask corresponds to the + trial involving model i and test segment j. + + Arguments + --------- + modelset : list + List of unique models in a ndarray. + segset : list + List of unique test segments in a ndarray. + scoremask : 2D ndarray of bool + Indicates the trials of interest, i.e., + the entry i,j in scoremat should be ignored if scoremask[i,j] is False. + scoremat : 2D ndarray + Scores matrix. + """ + + def __init__(self, scores_file_name=""): + """ + Initialize a Scores object by loading information from a file HDF5 format. + + Arguments + --------- + scores_file_name : str + Name of the file to load. + """ + self.modelset = numpy.empty(0, dtype="|O") + self.segset = numpy.empty(0, dtype="|O") + self.scoremask = numpy.array([], dtype="bool") + self.scoremat = numpy.array([]) + + if scores_file_name == "": + pass + else: + tmp = Scores.read(scores_file_name) + self.modelset = tmp.modelset + self.segset = tmp.segset + self.scoremask = tmp.scoremask + self.scoremat = tmp.scoremat + + def __repr__(self): + ch = "modelset:\n" + ch += self.modelset + "\n" + ch += "segset:\n" + ch += self.segset + "\n" + ch += "scoremask:\n" + ch += self.scoremask.__repr__() + "\n" + ch += "scoremat:\n" + ch += self.scoremat.__repr__() + "\n" + + +def fa_model_loop( + batch_start, + mini_batch_indices, + factor_analyser, + stat0, + stats, + e_h, + e_hh, ): + """ + A function for PLDA estimation. + + Arguments + --------- + batch_start : int + Index to start at in the list. + mini_batch_indices : list + Indices of the elements in the list (should start at zero). + factor_analyser : instance of PLDA class + PLDA class object. + stat0 : tensor + Matrix of zero-order statistics. + stats: tensor + Matrix of first-order statistics. + e_h : tensor + An accumulator matrix. + e_hh: tensor + An accumulator matrix. + """ + rank = factor_analyser.F.shape[1] + if factor_analyser.Sigma.ndim == 2: + A = factor_analyser.F.T.dot(factor_analyser.F) + inv_lambda_unique = dict() + for sess in numpy.unique(stat0[:, 0]): + inv_lambda_unique[sess] = linalg.inv(sess * A + numpy.eye(A.shape[ + 0])) + + tmp = numpy.zeros( + (factor_analyser.F.shape[1], factor_analyser.F.shape[1]), + dtype=numpy.float64, ) + + for idx in mini_batch_indices: + if factor_analyser.Sigma.ndim == 1: + inv_lambda = linalg.inv( + numpy.eye(rank) + (factor_analyser.F.T * stat0[ + idx + batch_start, :]).dot(factor_analyser.F)) + else: + inv_lambda = inv_lambda_unique[stat0[idx + batch_start, 0]] + + aux = factor_analyser.F.T.dot(stats[idx + batch_start, :]) + numpy.dot(aux, inv_lambda, out=e_h[idx]) + e_hh[idx] = inv_lambda + numpy.outer(e_h[idx], e_h[idx], tmp) + + +def _check_missing_model(enroll, test, ndx): + # Remove missing models and test segments + clean_ndx = ndx.filter(enroll.modelset, test.segset, True) + + # Align EmbeddingMeta to match the clean_ndx + enroll.align_models(clean_ndx.modelset) + test.align_segments(clean_ndx.segset) + + return clean_ndx + + +class PLDA: + """ + A class to train PLDA model from embeddings. + + The input is in paddlespeech.vector.cluster.diarization.EmbeddingMeta format. + Trains a simplified PLDA model no within-class covariance matrix but full residual covariance matrix. + + Arguments + --------- + mean : tensor + Mean of the vectors. + F : tensor + Eigenvoice matrix. + Sigma : tensor + Residual matrix. + """ + + def __init__( + self, + mean=None, + F=None, + Sigma=None, + rank_f=100, + nb_iter=10, + scaling_factor=1.0, ): + self.mean = None + self.F = None + self.Sigma = None + self.rank_f = rank_f + self.nb_iter = nb_iter + self.scaling_factor = scaling_factor + + if mean is not None: + self.mean = mean + if F is not None: + self.F = F + if Sigma is not None: + self.Sigma = Sigma + + def plda( + self, + emb_meta=None, + output_file_name=None, ): + """ + Trains PLDA model with no within class covariance matrix but full residual covariance matrix. + + Arguments + --------- + emb_meta : paddlespeech.vector.cluster.diarization.EmbeddingMeta + Contains vectors and meta-information to perform PLDA + rank_f : int + Rank of the between-class covariance matrix. + nb_iter : int + Number of iterations to run. + scaling_factor : float + Scaling factor to downscale statistics (value between 0 and 1). + output_file_name : str + Name of the output file where to store PLDA model. + """ + + # Dimension of the vector (x-vectors stored in stats) + vect_size = emb_meta.stats.shape[1] + + # Initialize mean and residual covariance from the training data + self.mean = emb_meta.get_mean_stats() + self.Sigma = emb_meta.get_total_covariance_stats() + + # Sum stat0 and stat1 for each speaker model + model_shifted_stat, session_per_model = emb_meta.sum_stat_per_model() + + # Number of speakers (classes) in training set + class_nb = model_shifted_stat.modelset.shape[0] + + # Multiply statistics by scaling_factor + model_shifted_stat.stat0 *= self.scaling_factor + model_shifted_stat.stats *= self.scaling_factor + session_per_model *= self.scaling_factor + + # Covariance for stats + sigma_obs = emb_meta.get_total_covariance_stats() + evals, evecs = linalg.eigh(sigma_obs) + + # Initial F (eigen voice matrix) from rank + idx = numpy.argsort(evals)[::-1] + evecs = evecs.real[:, idx[:self.rank_f]] + self.F = evecs[:, :self.rank_f] + + # Estimate PLDA model by iterating the EM algorithm + for it in range(self.nb_iter): + + # E-step + + # Copy stats as they will be whitened with a different Sigma for each iteration + local_stat = copy.deepcopy(model_shifted_stat) + + # Whiten statistics (with the new mean and Sigma) + local_stat.whiten_stats(self.mean, self.Sigma) + + # Whiten the EigenVoice matrix + eigen_values, eigen_vectors = linalg.eigh(self.Sigma) + ind = eigen_values.real.argsort()[::-1] + eigen_values = eigen_values.real[ind] + eigen_vectors = eigen_vectors.real[:, ind] + sqr_inv_eval_sigma = 1 / numpy.sqrt(eigen_values.real) + sqr_inv_sigma = numpy.dot(eigen_vectors, + numpy.diag(sqr_inv_eval_sigma)) + self.F = sqr_inv_sigma.T.dot(self.F) + + # Replicate self.stat0 + index_map = numpy.zeros(vect_size, dtype=int) + _stat0 = local_stat.stat0[:, index_map] + + e_h = numpy.zeros((class_nb, self.rank_f)) + e_hh = numpy.zeros((class_nb, self.rank_f, self.rank_f)) + + # loop on model id's + fa_model_loop( + batch_start=0, + mini_batch_indices=numpy.arange(class_nb), + factor_analyser=self, + stat0=_stat0, + stats=local_stat.stats, + e_h=e_h, + e_hh=e_hh, ) + + # Accumulate for minimum divergence step + _R = numpy.sum(e_hh, axis=0) / session_per_model.shape[0] + + _C = e_h.T.dot(local_stat.stats).dot(linalg.inv(sqr_inv_sigma)) + _A = numpy.einsum("ijk,i->jk", e_hh, local_stat.stat0.squeeze()) + + # M-step + self.F = linalg.solve(_A, _C).T + + # Update the residual covariance + self.Sigma = sigma_obs - self.F.dot(_C) / session_per_model.sum() + + # Minimum Divergence step + self.F = self.F.dot(linalg.cholesky(_R)) + + def scoring( + self, + enroll, + test, + ndx, + test_uncertainty=None, + Vtrans=None, + p_known=0.0, + scaling_factor=1.0, + check_missing=True, ): + """ + Compute the PLDA scores between to sets of vectors. The list of + trials to perform is given in an Ndx object. PLDA matrices have to be + pre-computed. i-vectors/x-vectors are supposed to be whitened before. + + Arguments + --------- + enroll : paddlespeech.vector.cluster.diarization.EmbeddingMeta + A EmbeddingMeta in which stats are xvectors. + test : paddlespeech.vector.cluster.diarization.EmbeddingMeta + A EmbeddingMeta in which stats are xvectors. + ndx : paddlespeech.vector.cluster.plda.Ndx + An Ndx object defining the list of trials to perform. + p_known : float + Probability of having a known speaker for open-set + identification case (=1 for the verification task and =0 for the + closed-set case). + check_missing : bool + If True, check that all models and segments exist. + """ + + enroll_ctr = copy.deepcopy(enroll) + test_ctr = copy.deepcopy(test) + + # Remove missing models and test segments + if check_missing: + clean_ndx = _check_missing_model(enroll_ctr, test_ctr, ndx) + else: + clean_ndx = ndx + + # Center the i-vectors around the PLDA mean + enroll_ctr.center_stats(self.mean) + test_ctr.center_stats(self.mean) + + # Compute constant component of the PLDA distribution + invSigma = linalg.inv(self.Sigma) + I_spk = numpy.eye(self.F.shape[1], dtype="float") + + K = self.F.T.dot(invSigma * scaling_factor).dot(self.F) + K1 = linalg.inv(K + I_spk) + K2 = linalg.inv(2 * K + I_spk) + + # Compute the Gaussian distribution constant + alpha1 = numpy.linalg.slogdet(K1)[1] + alpha2 = numpy.linalg.slogdet(K2)[1] + plda_cst = alpha2 / 2.0 - alpha1 + + # Compute intermediate matrices + Sigma_ac = numpy.dot(self.F, self.F.T) + Sigma_tot = Sigma_ac + self.Sigma + Sigma_tot_inv = linalg.inv(Sigma_tot) + + Tmp = linalg.inv(Sigma_tot - Sigma_ac.dot(Sigma_tot_inv).dot(Sigma_ac)) + Phi = Sigma_tot_inv - Tmp + Psi = Sigma_tot_inv.dot(Sigma_ac).dot(Tmp) + + # Compute the different parts of PLDA score + model_part = 0.5 * numpy.einsum("ij, ji->i", + enroll_ctr.stats.dot(Phi), + enroll_ctr.stats.T) + seg_part = 0.5 * numpy.einsum("ij, ji->i", + test_ctr.stats.dot(Phi), test_ctr.stats.T) + + # Compute verification scores + score = Scores() # noqa F821 + score.modelset = clean_ndx.modelset + score.segset = clean_ndx.segset + score.scoremask = clean_ndx.trialmask + + score.scoremat = model_part[:, numpy.newaxis] + seg_part + plda_cst + score.scoremat += enroll_ctr.stats.dot(Psi).dot(test_ctr.stats.T) + score.scoremat *= scaling_factor + + # Case of open-set identification, we compute the log-likelihood + # by taking into account the probability of having a known impostor + # or an out-of set class + if p_known != 0: + N = score.scoremat.shape[0] + open_set_scores = numpy.empty(score.scoremat.shape) + tmp = numpy.exp(score.scoremat) + for ii in range(N): + # open-set term + open_set_scores[ii, :] = score.scoremat[ii, :] - numpy.log( + p_known * tmp[~(numpy.arange(N) == ii)].sum(axis=0) / ( + N - 1) + (1 - p_known)) + score.scoremat = open_set_scores + + return score + + +if __name__ == '__main__': + import random + + dim, N, n_spkrs = 10, 100, 10 + train_xv = numpy.random.rand(N, dim) + md = ['md' + str(random.randrange(1, n_spkrs, 1)) for i in range(N)] # spk + modelset = numpy.array(md, dtype="|O") + sg = ['sg' + str(i) for i in range(N)] # utt + segset = numpy.array(sg, dtype="|O") + stat0 = numpy.array([[1.0]] * N) + xvectors_stat = EmbeddingMeta( + modelset=modelset, segset=segset, stats=train_xv) + # Training PLDA model: M ~ (mean, F, Sigma) + plda = PLDA(rank_f=5) + plda.plda(xvectors_stat) + print(plda.mean.shape) #(10,) + print(plda.F.shape) #(10, 5) + print(plda.Sigma.shape) #(10, 10) + # Enrollment (20 utts), + en_N = 20 + en_xv = numpy.random.rand(en_N, dim) + en_sgs = ['en' + str(i) for i in range(en_N)] + en_sets = numpy.array(en_sgs, dtype="|O") + en_stat = EmbeddingMeta(modelset=en_sets, segset=en_sets, stats=en_xv) + # Test (30 utts) + te_N = 30 + te_xv = numpy.random.rand(te_N, dim) + te_sgs = ['te' + str(i) for i in range(te_N)] + te_sets = numpy.array(te_sgs, dtype="|O") + te_stat = EmbeddingMeta(modelset=te_sets, segset=te_sets, stats=te_xv) + ndx = Ndx(models=en_sets, testsegs=te_sets) # trials + # PLDA Scoring + scores_plda = plda.scoring(en_stat, te_stat, ndx) + print(scores_plda.scoremat.shape) #(20, 30) diff --git a/paddlespeech/vector/io/dataset_from_json.py b/paddlespeech/vector/io/dataset_from_json.py index 5c5ce13e39a828cd501d1ece6bcf40d2e3f97fb8..a4d8c4524a504fc7349ca908fe416df7b52b6e35 100644 --- a/paddlespeech/vector/io/dataset_from_json.py +++ b/paddlespeech/vector/io/dataset_from_json.py @@ -26,14 +26,14 @@ from paddleaudio.compliance.librosa import mfcc class meta_info: """the audio meta info in the vector JSONDataset Args: - id (str): the segment name + utt_id (str): the segment name duration (float): segment time wav (str): wav file path start (int): start point in the original wav file stop (int): stop point in the original wav file lab_id (str): the record id """ - id: str + utt_id: str duration: float wav: str start: int diff --git a/setup.py b/setup.py index 82ff6341265a407c62b2599e6e73493c8b9087e1..bc466baae19f5144edc947740522904094f45dda 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ base = [ "loguru", "matplotlib", "nara_wpe", + "onnxruntime", "pandas", "paddleaudio", "paddlenlp", @@ -64,6 +65,7 @@ base = [ "webrtcvad", "yacs~=0.1.8", "prettytable", + "zhon", ] server = [ @@ -90,7 +92,6 @@ requirements = { "unidecode", "yq", "pre-commit", - "zhon", ] } diff --git a/speechx/examples/README.md b/speechx/examples/README.md index 35174a0d756be8d2204788628fb94e21f34ca397..50f5f902fefc190c36f154a8eec243bf39e24f45 100644 --- a/speechx/examples/README.md +++ b/speechx/examples/README.md @@ -1,12 +1,10 @@ # Examples for SpeechX -* dev - for speechx developer, using for test. -* ngram - using to build NGram ARPA lm. * ds2_ol - ds2 streaming test under `aishell-1` test dataset. - The entrypoint is `ds2_ol/aishell/run.sh` + The entrypoint is `ds2_ol/aishell/run.sh` -## How to run +## How to run `run.sh` is the entry point. @@ -17,9 +15,23 @@ pushd ds2_ol/aishell bash run.sh ``` -## Display Model with [Netron](https://github.com/lutzroeder/netron) +## Display Model with [Netron](https://github.com/lutzroeder/netron) ``` pip install netron netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20 ``` + +## For Developer + +> Warning: Only for developer, make sure you know what's it. + +* dev - for speechx developer, using for test. + +## Build WFST + +> Warning: Using below example when you know what's it. + +* text_lm - process text for build lm +* ngram - using to build NGram ARPA lm. +* wfst - build wfst for TLG. diff --git a/speechx/examples/ds2_ol/aishell/README.md b/speechx/examples/ds2_ol/aishell/README.md index eec67c3b24decf1bd89543d86a85b7e62075cadd..f4a81516235e3bedb53bd3c63db884243bfb69ea 100644 --- a/speechx/examples/ds2_ol/aishell/README.md +++ b/speechx/examples/ds2_ol/aishell/README.md @@ -10,12 +10,18 @@ Other -> 0.00 % N=0 C=0 S=0 D=0 I=0 ## CTC Prefix Beam Search w LM +LM: zh_giga.no_cna_cmn.prune01244.klm ``` - +Overall -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327 +Mandarin -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327 +Other -> 0.00 % N=0 C=0 S=0 D=0 I=0 ``` ## CTC WFST +LM: aishell train +``` +Overall -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1819 +Mandarin -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1818 +Other -> 0.00 % N=0 C=0 S=0 D=0 I=1 ``` - -``` \ No newline at end of file diff --git a/speechx/examples/ds2_ol/aishell/path.sh b/speechx/examples/ds2_ol/aishell/path.sh index 8e26e6e7eef1421827c92bda7b4b5679677de8c9..0a300f362b8d46c6045e51298ca52fbe18db6f60 100644 --- a/speechx/examples/ds2_ol/aishell/path.sh +++ b/speechx/examples/ds2_ol/aishell/path.sh @@ -11,4 +11,4 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin export LC_AL=C SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat -export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN +export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN \ No newline at end of file diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh index 3a1c19ee4ce79407b4e80aac2dd77f671166f476..6a59ca9b8c08b58c47ea9a00acf813f88f490bca 100755 --- a/speechx/examples/ds2_ol/aishell/run.sh +++ b/speechx/examples/ds2_ol/aishell/run.sh @@ -5,7 +5,10 @@ set -e . path.sh nj=40 +stage=0 +stop_stage=100 +. utils/parse_options.sh # 1. compile if [ ! -d ${SPEECHX_EXAMPLES} ]; then @@ -26,102 +29,112 @@ vocb_dir=$ckpt_dir/data/lang_char/ mkdir -p exp exp=$PWD/exp -aishell_wav_scp=aishell_test.scp -if [ ! -d $data/test ]; then - pushd $data - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip - unzip aishell_test.zip - popd - - realpath $data/test/*/*.wav > $data/wavlist - awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id - paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp -fi - - -if [ ! -d $ckpt_dir ]; then - mkdir -p $ckpt_dir - wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz - tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir -fi - -lm=$data/zh_giga.no_cna_cmn.prune01244.klm -if [ ! -f $lm ]; then - pushd $data - wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm - popd +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then + aishell_wav_scp=aishell_test.scp + if [ ! -d $data/test ]; then + pushd $data + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip + unzip aishell_test.zip + popd + + realpath $data/test/*/*.wav > $data/wavlist + awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id + paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp + fi + + + if [ ! -d $ckpt_dir ]; then + mkdir -p $ckpt_dir + wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz + tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir + fi + + lm=$data/zh_giga.no_cna_cmn.prune01244.klm + if [ ! -f $lm ]; then + pushd $data + wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm + popd + fi fi # 3. make feature +text=$data/test/text label_file=./aishell_result wer=./aishell_wer export GLOG_logtostderr=1 -# 3. gen linear feat -cmvn=$PWD/cmvn.ark -cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + # 3. gen linear feat + cmvn=$data/cmvn.ark + cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn -./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj + ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj -utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ -linear-spectrogram-wo-db-norm-ol \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \ - --cmvn_file=$cmvn \ - --streaming_chunk=0.36 - -text=$data/test/text + utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ + linear-spectrogram-wo-db-norm-ol \ + --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ + --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \ + --cmvn_file=$cmvn \ + --streaming_chunk=0.36 +fi -# 4. recognizer -utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \ - ctc-prefix-beam-search-decoder-ol \ - --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --dict_file=$vocb_dir/vocab.txt \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result - -cat $data/split${nj}/*/result > ${label_file} -utils/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer} - -# 4. decode with lm -utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \ - ctc-prefix-beam-search-decoder-ol \ - --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --dict_file=$vocb_dir/vocab.txt \ - --lm_path=$lm \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm - - -cat $data/split${nj}/*/result_lm > ${label_file}_lm -utils/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm - - -graph_dir=./aishell_graph -if [ ! -d $ ]; then - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip - unzip -d aishell_graph.zip +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then + # recognizer + utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \ + ctc-prefix-beam-search-decoder-ol \ + --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ + --model_path=$model_dir/avg_1.jit.pdmodel \ + --param_path=$model_dir/avg_1.jit.pdiparams \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --dict_file=$vocb_dir/vocab.txt \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result + + cat $data/split${nj}/*/result > $exp/${label_file} + utils/compute-wer.py --char=1 --v=1 $exp/${label_file} $text > $exp/${wer} fi +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ];then + # decode with lm + utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \ + ctc-prefix-beam-search-decoder-ol \ + --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ + --model_path=$model_dir/avg_1.jit.pdmodel \ + --param_path=$model_dir/avg_1.jit.pdiparams \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --dict_file=$vocb_dir/vocab.txt \ + --lm_path=$lm \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm + + cat $data/split${nj}/*/result_lm > $exp/${label_file}_lm + utils/compute-wer.py --char=1 --v=1 $exp/${label_file}_lm $text > $exp/${wer}_lm +fi -# 5. test TLG decoder -utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ - wfst-decoder-ol \ - --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --word_symbol_table=$graph_dir/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --graph_path=$graph_dir/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg +wfst=$data/wfst/ +mkdir -p $wfst +if [ ! -f $wfst/aishell_graph.zip ]; then + pushd $wfst + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip + unzip aishell_graph.zip + popd +fi -cat $data/split${nj}/*/result_tlg > ${label_file}_tlg -utils/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg \ No newline at end of file +graph_dir=$wfst/aishell_graph +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + # TLG decoder + utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ + wfst-decoder-ol \ + --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ + --model_path=$model_dir/avg_1.jit.pdmodel \ + --param_path=$model_dir/avg_1.jit.pdiparams \ + --word_symbol_table=$graph_dir/words.txt \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --graph_path=$graph_dir/TLG.fst --max_active=7500 \ + --acoustic_scale=1.2 \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg + + cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg + utils/compute-wer.py --char=1 --v=1 $exp/${label_file}_tlg $text > $exp/${wer}_tlg +fi \ No newline at end of file diff --git a/speechx/examples/ngram/README.md b/speechx/examples/ngram/README.md deleted file mode 100644 index b120715fc381d8055eab1368487dff3c033eae2a..0000000000000000000000000000000000000000 --- a/speechx/examples/ngram/README.md +++ /dev/null @@ -1 +0,0 @@ -# NGram Train diff --git a/speechx/examples/ngram/en/README.md b/speechx/examples/ngram/en/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/speechx/examples/ngram/zh/README.md b/speechx/examples/ngram/zh/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e11bd3439618e0d2dc92b1b30453a8a4eb5f976c --- /dev/null +++ b/speechx/examples/ngram/zh/README.md @@ -0,0 +1,101 @@ +# ngram train for mandarin + +Quick run: +``` +bash run.sh --stage -1 +``` + +## input + +input files: +``` +data/ +├── lexicon.txt +├── text +└── vocab.txt +``` + +``` +==> data/text <== +BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限 购 +BAC009S0002W0123 也 成为 地方 政府 的 眼中 钉 +BAC009S0002W0124 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后 +BAC009S0002W0125 各地 政府 便 纷纷 跟进 +BAC009S0002W0126 仅 一 个 多 月 的 时间 里 +BAC009S0002W0127 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外 +BAC009S0002W0128 四十六 个 限 购 城市 当中 +BAC009S0002W0129 四十一 个 已 正式 取消 或 变相 放松 了 限 购 +BAC009S0002W0130 财政 金融 政策 紧随 其后 而来 +BAC009S0002W0131 显示 出 了 极 强 的 威力 + +==> data/lexicon.txt <== +SIL sil + sil +啊 aa a1 +啊 aa a2 +啊 aa a4 +啊 aa a5 +啊啊啊 aa a2 aa a2 aa a2 +啊啊啊 aa a5 aa a5 aa a5 +坐地 z uo4 d i4 +坐实 z uo4 sh ix2 +坐视 z uo4 sh ix4 +坐稳 z uo4 uu un3 +坐拥 z uo4 ii iong1 +坐诊 z uo4 zh en3 +坐庄 z uo4 zh uang1 +坐姿 z uo4 z iy1 + +==> data/vocab.txt <== + + +A +B +C +D +E +龙 +龚 +龛 + +``` + +## output + +``` +data/ +├── local +│ ├── dict +│ │ ├── lexicon.txt +│ │ └── units.txt +│ └── lm +│ ├── heldout +│ ├── lm.arpa +│ ├── text +│ ├── text.no_oov +│ ├── train +│ ├── unigram.counts +│ ├── word.counts +│ └── wordlist +``` + +``` +/workspace/srilm/bin/i686-m64/ngram-count +Namespace(bpemodel=None, in_lexicon='data/lexicon.txt', out_lexicon='data/local/dict/lexicon.txt', unit_file='data/vocab.txt') +Ignoring words 矽, which contains oov unit +Ignoring words 傩, which contains oov unit +Ignoring words 堀, which contains oov unit +Ignoring words 莼, which contains oov unit +Ignoring words 菰, which contains oov unit +Ignoring words 摭, which contains oov unit +Ignoring words 帙, which contains oov unit +Ignoring words 迨, which contains oov unit +Ignoring words 孥, which contains oov unit +Ignoring words 瑗, which contains oov unit +... +... +... +file data/local/lm/heldout: 10000 sentences, 89496 words, 0 OOVs +0 zeroprobs, logprob= -270337.9 ppl= 521.2819 ppl1= 1048.745 +build LM done. +``` diff --git a/speechx/examples/ngram/zh/local/aishell_train_lms.sh b/speechx/examples/ngram/zh/local/aishell_train_lms.sh new file mode 100755 index 0000000000000000000000000000000000000000..762661513a24b84477c6d0024f11af129001e9af --- /dev/null +++ b/speechx/examples/ngram/zh/local/aishell_train_lms.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# To be run from one directory above this script. +. ./path.sh + +text=data/local/lm/text +lexicon=data/local/dict/lexicon.txt + +for f in "$text" "$lexicon"; do + [ ! -f $x ] && echo "$0: No such file $f" && exit 1; +done + +# Check SRILM tools +if ! which ngram-count > /dev/null; then + echo "srilm tools are not found, please download it and install it from: " + echo "http://www.speech.sri.com/projects/srilm/download.html" + echo "Then add the tools to your PATH" + exit 1 +fi + +# This script takes no arguments. It assumes you have already run +# aishell_data_prep.sh. +# It takes as input the files +# data/local/lm/text +# data/local/dict/lexicon.txt +dir=data/local/lm +mkdir -p $dir + +cleantext=$dir/text.no_oov + +# oov to +# lexicon line: word char0 ... charn +# text line: utt word0 ... wordn -> line: word0 ... wordn +cat $text | awk -v lex=$lexicon 'BEGIN{while((getline0){ seen[$1]=1; } } + {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf(" ");} } printf("\n");}' \ + > $cleantext || exit 1; + +# compute word counts, sort in descending order +# line: count word +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \ + sort -nr > $dir/word.counts || exit 1; + +# Get counts from acoustic training transcripts, and add one-count +# for each word in the lexicon (but not silence, we don't want it +# in the LM-- we'll add it optionally later). +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \ + cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \ + sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1; + +# word with +cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo ""; echo "" ) > $dir/wordlist + +# hold out to compute ppl +heldout_sent=10000 # Don't change this if you want result to be comparable with kaldi_lm results + +mkdir -p $dir +cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n $dir/heldout +cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n $dir/train + +ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \ + -map-unk "" -kndiscount -interpolate -lm $dir/lm.arpa +ngram -lm $dir/lm.arpa -ppl $dir/heldout \ No newline at end of file diff --git a/speechx/examples/ngram/zh/local/text_to_lexicon.py b/speechx/examples/ngram/zh/local/text_to_lexicon.py new file mode 100755 index 0000000000000000000000000000000000000000..0ccd07c7b2dfe1b2897135b966faec51fe39fdcb --- /dev/null +++ b/speechx/examples/ngram/zh/local/text_to_lexicon.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +import argparse +from collections import Counter + +def main(args): + counter = Counter() + with open(args.text, 'r') as fin, open(args.lexicon, 'w') as fout: + for line in fin: + line = line.strip() + if args.has_key: + utt, text = line.split(maxsplit=1) + words = text.split() + else: + words = line.split() + + counter.update(words) + + for word in counter: + val = " ".join(list(word)) + fout.write(f"{word}\t{val}\n") + fout.flush() + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='text(line:utt1 中国 人) to lexicon(line:中国 中 国).') + parser.add_argument( + '--has_key', + default=True, + help='text path, with utt or not') + parser.add_argument( + '--text', + required=True, + help='text path. line: utt1 中国 人 or 中国 人') + parser.add_argument( + '--lexicon', + required=True, + help='lexicon path. line:中国 中 国') + args = parser.parse_args() + print(args) + + main(args) diff --git a/speechx/examples/ngram/zh/path.sh b/speechx/examples/ngram/zh/path.sh new file mode 100644 index 0000000000000000000000000000000000000000..a3fb3d75878a3b7a641d7f13e464aa324a9b0ea6 --- /dev/null +++ b/speechx/examples/ngram/zh/path.sh @@ -0,0 +1,12 @@ +# This contains the locations of binarys build required for running the examples. + +MAIN_ROOT=`realpath $PWD/../../../../` +SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx` + +export LC_AL=C + +# srilm +export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10 +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs +export SRILM=${MAIN_ROOT}/tools/srilm +export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64 diff --git a/speechx/examples/ngram/zh/run.sh b/speechx/examples/ngram/zh/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..f24ad0a7cc3d605d7a3af83d4d519f208abc3acf --- /dev/null +++ b/speechx/examples/ngram/zh/run.sh @@ -0,0 +1,68 @@ +#!/bin/bash +set -eo pipefail + +. path.sh + +stage=-1 +stop_stage=100 +corpus=aishell + +unit=data/vocab.txt # vocab file, line: char/spm_pice +lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt +text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt + +. utils/parse_options.sh + +data=$PWD/data +mkdir -p $data + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + if [ ! -f $data/speech.ngram.zh.tar.gz ];then + pushd $data + wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/speech.ngram.zh.tar.gz + tar xvzf speech.ngram.zh.tar.gz + popd + fi +fi + +if [ ! -f $unit ]; then + echo "$0: No such file $unit" + exit 1; +fi + +if ! which ngram-count; then + pushd $MAIN_ROOT/tools + make srilm.done + popd +fi + +mkdir -p data/local/dict +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # 7.1 Prepare dict + # line: char/spm_pices + cp $unit data/local/dict/units.txt + + if [ ! -f $lexicon ];then + local/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon + echo "Generate $lexicon from $text" + fi + + # filter by vocab + # line: word ph0 ... phn -> line: word char0 ... charn + utils/fst/prepare_dict.py \ + --unit_file $unit \ + --in_lexicon ${lexicon} \ + --out_lexicon data/local/dict/lexicon.txt +fi + +lm=data/local/lm +mkdir -p $lm + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # 7.2 Train lm + cp $text $lm/text + local/aishell_train_lms.sh +fi + +echo "build LM done." +exit 0 diff --git a/speechx/examples/ngram/zh/utils b/speechx/examples/ngram/zh/utils new file mode 120000 index 0000000000000000000000000000000000000000..c2519a9dd0bc11c4ca3de4bf89e16634dadcd4c9 --- /dev/null +++ b/speechx/examples/ngram/zh/utils @@ -0,0 +1 @@ +../../../../utils/ \ No newline at end of file diff --git a/speechx/examples/text_lm/.gitignore b/speechx/examples/text_lm/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1269488f7fb1f4b56a8c0e5eb48cecbfadfa9219 --- /dev/null +++ b/speechx/examples/text_lm/.gitignore @@ -0,0 +1 @@ +data diff --git a/speechx/examples/text_lm/README.md b/speechx/examples/text_lm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..627ed3dfbe33805d4bec4936539a7de182e3b115 --- /dev/null +++ b/speechx/examples/text_lm/README.md @@ -0,0 +1,15 @@ +# Text PreProcess for building ngram LM + +Output `text` file like this: + +``` +BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限 购 +BAC009S0002W0123 也 成为 地方 政府 的 眼中 钉 +BAC009S0002W0124 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后 +BAC009S0002W0125 各地 政府 便 纷纷 跟进 +BAC009S0002W0126 仅 一 个 多 月 的 时间 里 +BAC009S0002W0127 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外 +BAC009S0002W0128 四十六 个 限 购 城市 当中 +BAC009S0002W0129 四十一 个 已 正式 取消 或 变相 放松 了 限 购 +BAC009S0002W0130 财政 金融 政策 紧随 其后 而来 +``` diff --git a/speechx/examples/text_lm/path.sh b/speechx/examples/text_lm/path.sh new file mode 100644 index 0000000000000000000000000000000000000000..541f852c45047b8161d86fdd4fd829669907178a --- /dev/null +++ b/speechx/examples/text_lm/path.sh @@ -0,0 +1,4 @@ +MAIN_ROOT=`realpath $PWD/../../../../` +SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx` + +export LC_AL=C diff --git a/speechx/examples/text_lm/run.sh b/speechx/examples/text_lm/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..0a733b498dc46b4a7af1a26babcfb6e632e86e22 --- /dev/null +++ b/speechx/examples/text_lm/run.sh @@ -0,0 +1,24 @@ +#!/bin/bash +set -eo pipefail + +. path.sh + +stage=0 +stop_stage=100 +has_key=true +token_type=word + +. utils/parse_options.sh || exit -1; + +text=data/text + +if [ ! -f $text ]; then + echo "$0: Not find $1"; + exit -1; +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then + echo "text tn & wordseg preprocess" + rm -rf ${text}.tn + python3 utils/zh_tn.py --has_key $has_key --token_type $token_type ${text} ${text}.tn +fi \ No newline at end of file diff --git a/speechx/examples/text_lm/utils b/speechx/examples/text_lm/utils new file mode 120000 index 0000000000000000000000000000000000000000..256f914abcaa47d966c44878b88a300437f110fb --- /dev/null +++ b/speechx/examples/text_lm/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/speechx/examples/wfst/.gitignore b/speechx/examples/wfst/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1269488f7fb1f4b56a8c0e5eb48cecbfadfa9219 --- /dev/null +++ b/speechx/examples/wfst/.gitignore @@ -0,0 +1 @@ +data diff --git a/speechx/examples/wfst/README.md b/speechx/examples/wfst/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4f4674a4f7bbe53e466fd454d7fa7df99b9a7159 --- /dev/null +++ b/speechx/examples/wfst/README.md @@ -0,0 +1,186 @@ +# Built TLG wfst + +## Input +``` +data/local/ +├── dict +│ ├── lexicon.txt +│ └── units.txt +└── lm + ├── heldout + ├── lm.arpa + ├── text + ├── text.no_oov + ├── train + ├── unigram.counts + ├── word.counts + └── wordlist +``` + +``` +==> data/local/dict/lexicon.txt <== +啊 啊 +啊啊啊 啊 啊 啊 +阿 阿 +阿尔 阿 尔 +阿根廷 阿 根 廷 +阿九 阿 九 +阿克 阿 克 +阿拉伯数字 阿 拉 伯 数 字 +阿拉法特 阿 拉 法 特 +阿拉木图 阿 拉 木 图 + +==> data/local/dict/units.txt <== + + +A +B +C +D +E +F +G +H + +==> data/local/lm/heldout <== +而 对 楼市 成交 抑制 作用 最 大 的 限 购 +也 成为 地方 政府 的 眼中 钉 +自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后 +各地 政府 便 纷纷 跟进 +仅 一 个 多 月 的 时间 里 +除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外 +四十六 个 限 购 城市 当中 +四十一 个 已 正式 取消 或 变相 放松 了 限 购 +财政 金融 政策 紧随 其后 而来 +显示 出 了 极 强 的 威力 + +==> data/local/lm/lm.arpa <== + +\data\ +ngram 1=129356 +ngram 2=504661 +ngram 3=123455 + +\1-grams: +-1.531278 +-3.828829 -0.1600094 +-6.157292 + +==> data/local/lm/text <== +BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限 购 +BAC009S0002W0123 也 成为 地方 政府 的 眼中 钉 +BAC009S0002W0124 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后 +BAC009S0002W0125 各地 政府 便 纷纷 跟进 +BAC009S0002W0126 仅 一 个 多 月 的 时间 里 +BAC009S0002W0127 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外 +BAC009S0002W0128 四十六 个 限 购 城市 当中 +BAC009S0002W0129 四十一 个 已 正式 取消 或 变相 放松 了 限 购 +BAC009S0002W0130 财政 金融 政策 紧随 其后 而来 +BAC009S0002W0131 显示 出 了 极 强 的 威力 + +==> data/local/lm/text.no_oov <== + 而 对 楼市 成交 抑制 作用 最 大 的 限 购 + 也 成为 地方 政府 的 眼中 钉 + 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后 + 各地 政府 便 纷纷 跟进 + 仅 一 个 多 月 的 时间 里 + 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外 + 四十六 个 限 购 城市 当中 + 四十一 个 已 正式 取消 或 变相 放松 了 限 购 + 财政 ���融 政策 紧随 其后 而来 + 显示 出 了 极 强 的 威力 + +==> data/local/lm/train <== +汉莎 不 得 不 通过 这样 的 方式 寻求 新 的 发展 点 +并 计划 朝云 计算 方面 发展 +汉莎 的 基础 设施 部门 拥有 一千四百 名 员工 +媒体 就 曾 披露 这笔 交易 +虽然 双方 已经 正式 签署 了 外包 协议 +但是 这笔 交易 还 需要 得到 反 垄断 部门 的 批准 +陈 黎明 一九八九 年 获得 美国 康乃尔 大学 硕士 学位 +并 于 二零零三 年 顺利 完成 美国 哈佛 商学 院 高级 管理 课程 +曾 在 多家 国际 公司 任职 +拥有 业务 开发 商务 及 企业 治理 + +==> data/local/lm/unigram.counts <== + 57487 的 + 13099 在 + 11862 一 + 11397 了 + 10998 不 + 9913 是 + 7952 有 + 6250 和 + 6152 个 + 5422 将 + +==> data/local/lm/word.counts <== + 57486 的 + 13098 在 + 11861 一 + 11396 了 + 10997 不 + 9912 是 + 7951 有 + 6249 和 + 6151 个 + 5421 将 + +==> data/local/lm/wordlist <== +的 +在 +一 +了 +不 +是 +有 +和 +个 +将 +``` + +## Output + +``` +fstaddselfloops 'echo 4234 |' 'echo 123660 |' +Lexicon and Token FSTs compiling succeeded +arpa2fst --read-symbol-table=data/lang_test/words.txt --keep-symbols=true - +LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:94) Reading \data\ section. +LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:149) Reading \1-grams: section. +LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:149) Reading \2-grams: section. +LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:149) Reading \3-grams: section. +Checking how stochastic G is (the first of these numbers should be small): +fstisstochastic data/lang_test/G.fst +0 -1.14386 +fsttablecompose data/lang_test/L.fst data/lang_test/G.fst +fstminimizeencoded +fstdeterminizestar --use-log=true +fsttablecompose data/lang_test/T.fst data/lang_test/LG.fst +Composing decoding graph TLG.fst succeeded +Aishell build TLG done. +``` + +``` +data/ +├── lang_test +│ ├── G.fst +│ ├── L.fst +│ ├── LG.fst +│ ├── T.fst +│ ├── TLG.fst +│ ├── tokens.txt +│ ├── units.txt +│ └── words.txt +└── local + ├── lang + │ ├── L.fst + │ ├── T.fst + │ ├── tokens.txt + │ ├── units.txt + │ └── words.txt + └── tmp + ├── disambig.list + ├── lexiconp_disambig.txt + ├── lexiconp.txt + └── units.list +``` \ No newline at end of file diff --git a/speechx/examples/wfst/path.sh b/speechx/examples/wfst/path.sh new file mode 100644 index 0000000000000000000000000000000000000000..a07c1297de8c54daf5f78d6b5deb0d87de56988f --- /dev/null +++ b/speechx/examples/wfst/path.sh @@ -0,0 +1,19 @@ +# This contains the locations of binarys build required for running the examples. + +MAIN_ROOT=`realpath $PWD/../../../` +SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx` + +export LC_AL=C + +# srilm +export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10 +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs +export SRILM=${MAIN_ROOT}/tools/srilm +export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64 + +# Kaldi +export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present, can not using Kaldi!" +[ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh diff --git a/speechx/examples/wfst/run.sh b/speechx/examples/wfst/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..1354646af1050fe34f08883ef8abc82038c66e7e --- /dev/null +++ b/speechx/examples/wfst/run.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -eo pipefail + +. path.sh + +stage=-1 +stop_stage=100 + +. utils/parse_options.sh + +if ! which fstprint ; then + pushd $MAIN_ROOT/tools + make kaldi.done + popd +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # build T & L + # utils/fst/compile_lexicon_token_fst.sh + utils/fst/compile_lexicon_token_fst.sh \ + data/local/dict data/local/tmp data/local/lang + + # build G & LG & TLG + # utils/fst/make_tlg.sh + utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; +fi + +echo "build TLG done." +exit 0 diff --git a/speechx/examples/wfst/utils b/speechx/examples/wfst/utils new file mode 120000 index 0000000000000000000000000000000000000000..256f914abcaa47d966c44878b88a300437f110fb --- /dev/null +++ b/speechx/examples/wfst/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/speechx/tools/install_srilm.sh b/speechx/tools/install_srilm.sh deleted file mode 100755 index 813109dbb80ea0e22791e6f32034544073df91e6..0000000000000000000000000000000000000000 --- a/speechx/tools/install_srilm.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash - -current_path=`pwd` -current_dir=`basename "$current_path"` - -if [ "tools" != "$current_dir" ]; then - echo "You should run this script in tools/ directory!!" - exit 1 -fi - -if [ ! -d liblbfgs-1.10 ]; then - echo Installing libLBFGS library to support MaxEnt LMs - bash extras/install_liblbfgs.sh || exit 1 -fi - -! command -v gawk > /dev/null && \ - echo "GNU awk is not installed so SRILM will probably not work correctly: refusing to install" && exit 1; - -if [ $# -ne 3 ]; then - echo "SRILM download requires some information about you" - echo - echo "Usage: $0 " - exit 1 -fi - -srilm_url="http://www.speech.sri.com/projects/srilm/srilm_download.php" -post_data="WWW_file=srilm-1.7.3.tar.gz&WWW_name=$1&WWW_org=$2&WWW_email=$3" - -if ! wget --post-data "$post_data" -O ./srilm.tar.gz "$srilm_url"; then - echo 'There was a problem downloading the file.' - echo 'Check you internet connection and try again.' - exit 1 -fi - -mkdir -p srilm -cd srilm - - -if [ -f ../srilm.tgz ]; then - tar -xvzf ../srilm.tgz # Old SRILM format -elif [ -f ../srilm.tar.gz ]; then - tar -xvzf ../srilm.tar.gz # Changed format type from tgz to tar.gz -fi - -major=`gawk -F. '{ print $1 }' RELEASE` -minor=`gawk -F. '{ print $2 }' RELEASE` -micro=`gawk -F. '{ print $3 }' RELEASE` - -if [ $major -le 1 ] && [ $minor -le 7 ] && [ $micro -le 1 ]; then - echo "Detected version 1.7.1 or earlier. Applying patch." - patch -p0 < ../extras/srilm.patch -fi - -# set the SRILM variable in the top-level Makefile to this directory. -cp Makefile tmpf - -cat tmpf | gawk -v pwd=`pwd` '/SRILM =/{printf("SRILM = %s\n", pwd); next;} {print;}' \ - > Makefile || exit 1 -rm tmpf - -mtype=`sbin/machine-type` - -echo HAVE_LIBLBFGS=1 >> common/Makefile.machine.$mtype -grep ADDITIONAL_INCLUDES common/Makefile.machine.$mtype | \ - sed 's|$| -I$(SRILM)/../liblbfgs-1.10/include|' \ - >> common/Makefile.machine.$mtype - -grep ADDITIONAL_LDFLAGS common/Makefile.machine.$mtype | \ - sed 's|$| -L$(SRILM)/../liblbfgs-1.10/lib/ -Wl,-rpath -Wl,$(SRILM)/../liblbfgs-1.10/lib/|' \ - >> common/Makefile.machine.$mtype - -make || exit - -cd .. -( - [ ! -z "${SRILM}" ] && \ - echo >&2 "SRILM variable is aleady defined. Undefining..." && \ - unset SRILM - - [ -f ./env.sh ] && . ./env.sh - - [ ! -z "${SRILM}" ] && \ - echo >&2 "SRILM config is already in env.sh" && exit - - wd=`pwd` - wd=`readlink -f $wd || pwd` - - echo "export SRILM=$wd/srilm" - dirs="\${PATH}" - for directory in $(cd srilm && find bin -type d ) ; do - dirs="$dirs:\${SRILM}/$directory" - done - echo "export PATH=$dirs" -) >> env.sh - -echo >&2 "Installation of SRILM finished successfully" -echo >&2 "Please source the tools/env.sh in your path.sh to enable it" diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh index 96ab84d65312d4ea7d4974fa86ab85e991108f27..87c24b099ce01f9bd1b319809e3860137faec24b 100755 --- a/tests/unit/cli/test_cli.sh +++ b/tests/unit/cli/test_cli.sh @@ -1,5 +1,6 @@ #!/bin/bash set -e + # Audio classification wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/dog.wav paddlespeech cls --input ./cat.wav --topk 10 @@ -28,26 +29,16 @@ paddlespeech tts --am tacotron2_csmsc --input "你好,欢迎使用百度飞桨 paddlespeech tts --am tacotron2_csmsc --voc wavernn_csmsc --input "你好,欢迎使用百度飞桨深度学习框架!" paddlespeech tts --am tacotron2_ljspeech --voc pwgan_ljspeech --lang en --input "Life was like a box of chocolates, you never know what you're gonna get." - # Speech Translation (only support linux) paddlespeech st --input ./en.wav - -# batch process -echo -e "1 欢迎光临。\n2 谢谢惠顾。" | paddlespeech tts - -# shell pipeline -paddlespeech asr --input ./zh.wav | paddlespeech text --task punc - -# stats -paddlespeech stats --task asr -paddlespeech stats --task tts -paddlespeech stats --task cls - # Speaker Verification wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav paddlespeech vector --task spk --input 85236145389.wav +# batch process +echo -e "1 欢迎光临。\n2 谢谢惠顾。" | paddlespeech tts + echo -e "demo1 85236145389.wav \n demo2 85236145389.wav" > vec.job paddlespeech vector --task spk --input vec.job @@ -55,4 +46,13 @@ echo -e "demo3 85236145389.wav \n demo4 85236145389.wav" | paddlespeech vector - rm 85236145389.wav rm vec.job +# shell pipeline +paddlespeech asr --input ./zh.wav | paddlespeech text --task punc +# stats +paddlespeech stats --task asr +paddlespeech stats --task tts +paddlespeech stats --task cls +paddlespeech stats --task text +paddlespeech stats --task vector +paddlespeech stats --task st diff --git a/tools/Makefile b/tools/Makefile index 285f85c86caf7a2f112fa96175514d888518dcdb..a5a4485da79795ed8f2abec1a24abb2f6d9a98b0 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -25,7 +25,7 @@ clean: apt.done: apt update -y - apt install -y bc flac jq vim tig tree pkg-config libsndfile1 libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev + apt install -y bc flac jq vim tig tree sox pkg-config libsndfile1 libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev echo "check_certificate = off" >> ~/.wgetrc touch apt.done @@ -50,7 +50,7 @@ openblas.done: bash extras/install_openblas.sh touch openblas.done -kaldi.done: openblas.done +kaldi.done: apt.done openblas.done bash extras/install_kaldi.sh touch kaldi.done @@ -58,6 +58,11 @@ sctk.done: ./extras/install_sclite.sh touch sctk.done +srilm.done: + ./extras/install_liblbfgs.sh + extras/install_srilm.sh + touch srilm.done + ###################### dev: python conda_packages.done sctk.done @@ -96,4 +101,4 @@ conda_packages.done: bc.done cmake.done flac.done ffmpeg.done sox.done sndfile.d else conda_packages.done: endif - touch conda_packages.done \ No newline at end of file + touch conda_packages.done diff --git a/tools/extras/install_openfst.sh b/tools/extras/install_openfst.sh index 54ddef6a7119f285fbec4512fd6d113381268e9b..5e97bc81fb4662f96154f8c561d73aa20a36a837 100755 --- a/tools/extras/install_openfst.sh +++ b/tools/extras/install_openfst.sh @@ -7,8 +7,9 @@ set -x # openfst openfst=openfst-1.8.1 shared=true +WGET="wget -c --no-check-certificate" -test -e ${openfst}.tar.gz || wget http://www.openfst.org/twiki/pub/FST/FstDownload/${openfst}.tar.gz +test -e ${openfst}.tar.gz || $WGET http://www.openfst.org/twiki/pub/FST/FstDownload/${openfst}.tar.gz test -d ${openfst} || tar -xvf ${openfst}.tar.gz && chown -R root:root ${openfst} diff --git a/utils/compute-wer.py b/utils/compute-wer.py index 560349a02cf8e0eaab32777adc09021aa15b7abd..2d7cc8e13ed703a3ebb41033f2e6b13435543204 100755 --- a/utils/compute-wer.py +++ b/utils/compute-wer.py @@ -1,62 +1,67 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - # CopyRight WeNet Apache-2.0 License - -import re, sys, unicodedata import codecs +import sys +import unicodedata remove_tag = True -spacelist= [' ', '\t', '\r', '\n'] -puncts = ['!', ',', '?', - '、', '。', '!', ',', ';', '?', - ':', '「', '」', '︰', '『', '』', '《', '》'] +spacelist = [' ', '\t', '\r', '\n'] +puncts = [ + '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』', + '《', '》' +] + + +def characterize(string): + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + #https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. + sep = ' ' + if char == '<': + sep = '>' + j = i + 1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c == sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res -def characterize(string) : - res = [] - i = 0 - while i < len(string): - char = string[i] - if char in puncts: - i += 1 - continue - cat1 = unicodedata.category(char) - #https://unicodebook.readthedocs.io/unicode.html#unicode-categories - if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned - i += 1 - continue - if cat1 == 'Lo': # letter-other - res.append(char) - i += 1 - else: - # some input looks like: , we want to separate it to two words. - sep = ' ' - if char == '<': sep = '>' - j = i+1 - while j < len(string): - c = string[j] - if ord(c) >= 128 or (c in spacelist) or (c==sep): - break - j += 1 - if j < len(string) and string[j] == '>': - j += 1 - res.append(string[i:j]) - i = j - return res def stripoff_tags(x): - if not x: return '' - chars = [] - i = 0; T=len(x) - while i < T: - if x[i] == '<': - while i < T and x[i] != '>': - i += 1 - i += 1 - else: - chars.append(x[i]) - i += 1 - return ''.join(chars) + if not x: + return '' + chars = [] + i = 0 + T = len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) def normalize(sentence, ignore_words, cs, split=None): @@ -66,436 +71,487 @@ def normalize(sentence, ignore_words, cs, split=None): for token in sentence: x = token if not cs: - x = x.upper() + x = x.upper() if x in ignore_words: - continue + continue if remove_tag: - x = stripoff_tags(x) + x = stripoff_tags(x) if not x: - continue + continue if split and x in split: - new_sentence += split[x] + new_sentence += split[x] else: - new_sentence.append(x) + new_sentence.append(x) return new_sentence -class Calculator : - def __init__(self) : - self.data = {} - self.space = [] - self.cost = {} - self.cost['cor'] = 0 - self.cost['sub'] = 1 - self.cost['del'] = 1 - self.cost['ins'] = 1 - def calculate(self, lab, rec) : - # Initialization - lab.insert(0, '') - rec.insert(0, '') - while len(self.space) < len(lab) : - self.space.append([]) - for row in self.space : - for element in row : - element['dist'] = 0 - element['error'] = 'non' - while len(row) < len(rec) : - row.append({'dist' : 0, 'error' : 'non'}) - for i in range(len(lab)) : - self.space[i][0]['dist'] = i - self.space[i][0]['error'] = 'del' - for j in range(len(rec)) : - self.space[0][j]['dist'] = j - self.space[0][j]['error'] = 'ins' - self.space[0][0]['error'] = 'non' - for token in lab : - if token not in self.data and len(token) > 0 : - self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} - for token in rec : - if token not in self.data and len(token) > 0 : - self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} - # Computing edit distance - for i, lab_token in enumerate(lab) : - for j, rec_token in enumerate(rec) : - if i == 0 or j == 0 : - continue - min_dist = sys.maxsize - min_error = 'none' - dist = self.space[i-1][j]['dist'] + self.cost['del'] - error = 'del' - if dist < min_dist : - min_dist = dist - min_error = error - dist = self.space[i][j-1]['dist'] + self.cost['ins'] - error = 'ins' - if dist < min_dist : - min_dist = dist - min_error = error - if lab_token == rec_token : - dist = self.space[i-1][j-1]['dist'] + self.cost['cor'] - error = 'cor' - else : - dist = self.space[i-1][j-1]['dist'] + self.cost['sub'] - error = 'sub' - if dist < min_dist : - min_dist = dist - min_error = error - self.space[i][j]['dist'] = min_dist - self.space[i][j]['error'] = min_error - # Tracing back - result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} - i = len(lab) - 1 - j = len(rec) - 1 - while True : - if self.space[i][j]['error'] == 'cor' : # correct - if len(lab[i]) > 0 : - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 - result['all'] = result['all'] + 1 - result['cor'] = result['cor'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, rec[j]) - i = i - 1 - j = j - 1 - elif self.space[i][j]['error'] == 'sub' : # substitution - if len(lab[i]) > 0 : - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 - result['all'] = result['all'] + 1 - result['sub'] = result['sub'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, rec[j]) - i = i - 1 - j = j - 1 - elif self.space[i][j]['error'] == 'del' : # deletion - if len(lab[i]) > 0 : - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 - result['all'] = result['all'] + 1 - result['del'] = result['del'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, "") - i = i - 1 - elif self.space[i][j]['error'] == 'ins' : # insertion - if len(rec[j]) > 0 : - self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 - result['ins'] = result['ins'] + 1 - result['lab'].insert(0, "") - result['rec'].insert(0, rec[j]) - j = j - 1 - elif self.space[i][j]['error'] == 'non' : # starting point - break - else : # shouldn't reach here - print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error'])) - return result - def overall(self) : - result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} - for token in self.data : - result['all'] = result['all'] + self.data[token]['all'] - result['cor'] = result['cor'] + self.data[token]['cor'] - result['sub'] = result['sub'] + self.data[token]['sub'] - result['ins'] = result['ins'] + self.data[token]['ins'] - result['del'] = result['del'] + self.data[token]['del'] - return result - def cluster(self, data) : - result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} - for token in data : - if token in self.data : - result['all'] = result['all'] + self.data[token]['all'] - result['cor'] = result['cor'] + self.data[token]['cor'] - result['sub'] = result['sub'] + self.data[token]['sub'] - result['ins'] = result['ins'] + self.data[token]['ins'] - result['del'] = result['del'] + self.data[token]['del'] - return result - def keys(self) : - return list(self.data.keys()) + +class Calculator: + def __init__(self): + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + + def calculate(self, lab, rec): + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab): + self.space.append([]) + for row in self.space: + for element in row: + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec): + row.append({'dist': 0, 'error': 'non'}) + for i in range(len(lab)): + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)): + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + for token in rec: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + # Computing edit distance + for i, lab_token in enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist: + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist: + min_dist = dist + min_error = error + if lab_token == rec_token: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist: + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = { + 'lab': [], + 'rec': [], + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + i = len(lab) - 1 + j = len(rec) - 1 + while True: + if self.space[i][j]['error'] == 'cor': # correct + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub': # substitution + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del': # deletion + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins': # insertion + if len(rec[j]) > 0: + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non': # starting point + break + else: # shouldn't reach here + print( + 'this should not happen , i = {i} , j = {j} , error = {error}'. + format(i=i, j=j, error=self.space[i][j]['error'])) + return result + + def overall(self): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def cluster(self, data): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in data: + if token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def keys(self): + return list(self.data.keys()) + def width(string): - return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) -def default_cluster(word) : - unicode_names = [ unicodedata.name(char) for char in word ] - for i in reversed(range(len(unicode_names))) : - if unicode_names[i].startswith('DIGIT') : # 1 - unicode_names[i] = 'Number' # 'DIGIT' - elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or - unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) : - # 明 / 郎 - unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' - elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or - unicode_names[i].startswith('LATIN SMALL LETTER')) : - # A / a - unicode_names[i] = 'English' # 'LATIN LETTER' - elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め - unicode_names[i] = 'Japanese' # 'GANA LETTER' - elif (unicode_names[i].startswith('AMPERSAND') or - unicode_names[i].startswith('APOSTROPHE') or - unicode_names[i].startswith('COMMERCIAL AT') or - unicode_names[i].startswith('DEGREE CELSIUS') or - unicode_names[i].startswith('EQUALS SIGN') or - unicode_names[i].startswith('FULL STOP') or - unicode_names[i].startswith('HYPHEN-MINUS') or - unicode_names[i].startswith('LOW LINE') or - unicode_names[i].startswith('NUMBER SIGN') or - unicode_names[i].startswith('PLUS SIGN') or - unicode_names[i].startswith('SEMICOLON')) : - # & / ' / @ / ℃ / = / . / - / _ / # / + / ; - del unicode_names[i] - else : - return 'Other' - if len(unicode_names) == 0 : - return 'Other' - if len(unicode_names) == 1 : - return unicode_names[0] - for i in range(len(unicode_names)-1) : - if unicode_names[i] != unicode_names[i+1] : - return 'Other' - return unicode_names[0] -def usage() : - print("compute-wer.py : compute word error rate (WER) and align recognition results and references.") - print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer") +def default_cluster(word): + unicode_names = [unicodedata.name(char) for char in word] + for i in reversed(range(len(unicode_names))): + if unicode_names[i].startswith('DIGIT'): # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or + unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')): + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or + unicode_names[i].startswith('LATIN SMALL LETTER')): + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') or + unicode_names[i].startswith('APOSTROPHE') or + unicode_names[i].startswith('COMMERCIAL AT') or + unicode_names[i].startswith('DEGREE CELSIUS') or + unicode_names[i].startswith('EQUALS SIGN') or + unicode_names[i].startswith('FULL STOP') or + unicode_names[i].startswith('HYPHEN-MINUS') or + unicode_names[i].startswith('LOW LINE') or + unicode_names[i].startswith('NUMBER SIGN') or + unicode_names[i].startswith('PLUS SIGN') or + unicode_names[i].startswith('SEMICOLON')): + # & / ' / @ / ℃ / = / . / - / _ / # / + / ; + del unicode_names[i] + else: + return 'Other' + if len(unicode_names) == 0: + return 'Other' + if len(unicode_names) == 1: + return unicode_names[0] + for i in range(len(unicode_names) - 1): + if unicode_names[i] != unicode_names[i + 1]: + return 'Other' + return unicode_names[0] + + +def usage(): + print( + "compute-wer.py : compute word error rate (WER) and align recognition results and references." + ) + print( + " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer" + ) + if __name__ == '__main__': - if len(sys.argv) == 1 : - usage() - sys.exit(0) - calculator = Calculator() - cluster_file = '' - ignore_words = set() - tochar = False - verbose= 1 - padding_symbol= ' ' - case_sensitive = False - max_words_per_line = sys.maxsize - split = None - while len(sys.argv) > 3: - a = '--maxw=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):] - del sys.argv[1] - max_words_per_line = int(b) - continue - a = '--rt=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - remove_tag = (b == 'true') or (b != '0') - continue - a = '--cs=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - case_sensitive = (b == 'true') or (b != '0') - continue - a = '--cluster=' - if sys.argv[1].startswith(a): - cluster_file = sys.argv[1][len(a):] - del sys.argv[1] - continue - a = '--splitfile=' - if sys.argv[1].startswith(a): - split_file = sys.argv[1][len(a):] - del sys.argv[1] - split = dict() - with codecs.open(split_file, 'r', 'utf-8') as fh: - for line in fh: # line in unicode - words = line.strip().split() - if len(words) >= 2: - split[words[0]] = words[1:] - continue - a = '--ig=' - if sys.argv[1].startswith(a): - ignore_file = sys.argv[1][len(a):] - del sys.argv[1] - with codecs.open(ignore_file, 'r', 'utf-8') as fh: - for line in fh: # line in unicode - line = line.strip() - if len(line) > 0: - ignore_words.add(line) - continue - a = '--char=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - tochar = (b == 'true') or (b != '0') - continue - a = '--v=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - verbose=0 - try: - verbose=int(b) - except: - if b == 'true' or b != '0': - verbose = 1 - continue - a = '--padding-symbol=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - if b == 'space': - padding_symbol= ' ' - elif b == 'underline': - padding_symbol= '_' - continue - if True or sys.argv[1].startswith('-'): - #ignore invalid switch - del sys.argv[1] - continue + if len(sys.argv) == 1: + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose = 1 + padding_symbol = ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose = 0 + try: + verbose = int(b) + except Exception as e: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol = ' ' + elif b == 'underline': + padding_symbol = '_' + continue + if True or sys.argv[1].startswith('-'): + #ignore invalid switch + del sys.argv[1] + continue - if not case_sensitive: - ig=set([w.upper() for w in ignore_words]) - ignore_words = ig + if not case_sensitive: + ig = set([w.upper() for w in ignore_words]) + ignore_words = ig - default_clusters = {} - default_words = {} + default_clusters = {} + default_words = {} - ref_file = sys.argv[1] - hyp_file = sys.argv[2] - rec_set = {} - if split and not case_sensitive: - newsplit = dict() - for w in split: - words = split[w] - for i in range(len(words)): - words[i] = words[i].upper() - newsplit[w.upper()] = words - split = newsplit + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit - with codecs.open(hyp_file, 'r', 'utf-8') as fh: - for line in fh: + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, + split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8'): if tochar: array = characterize(line) else: - array = line.strip().split() - if len(array)==0: continue + array = line.rstrip('\n').split() + if len(array) == 0: + continue fid = array[0] - rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) - # compute error rate on the interaction of reference file and hyp file - for line in open(ref_file, 'r', encoding='utf-8') : - if tochar: - array = characterize(line) - else: - array = line.rstrip('\n').split() - if len(array)==0: continue - fid = array[0] - if fid not in rec_set: - continue - lab = normalize(array[1:], ignore_words, case_sensitive, split) - rec = rec_set[fid] - if verbose: - print('\nutt: %s' % fid) + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name - for word in rec + lab : - if word not in default_words : - default_cluster_name = default_cluster(word) - if default_cluster_name not in default_clusters : - default_clusters[default_cluster_name] = {} - if word not in default_clusters[default_cluster_name] : - default_clusters[default_cluster_name][word] = 1 - default_words[word] = default_cluster_name + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('WER: %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])): + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length - len_lab) + space['rec'].append(length - len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end=' ') + else: + print('lab:', end=' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['lab'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end=' ') + else: + print('rec:', end=' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['rec'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 - result = calculator.calculate(lab, rec) if verbose: - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : - wer = 0.0 - print('WER: %4.2f %%' % wer, end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - space = {} - space['lab'] = [] - space['rec'] = [] - for idx in range(len(result['lab'])) : - len_lab = width(result['lab'][idx]) - len_rec = width(result['rec'][idx]) - length = max(len_lab, len_rec) - space['lab'].append(length-len_lab) - space['rec'].append(length-len_rec) - upper_lab = len(result['lab']) - upper_rec = len(result['rec']) - lab1, rec1 = 0, 0 - while lab1 < upper_lab or rec1 < upper_rec: - if verbose > 1: - print('lab(%s):' % fid.encode('utf-8'), end = ' ') - else: - print('lab:', end = ' ') - lab2 = min(upper_lab, lab1 + max_words_per_line) - for idx in range(lab1, lab2): - token = result['lab'][idx] - print('{token}'.format(token = token), end = '') - for n in range(space['lab'][idx]) : - print(padding_symbol, end = '') - print(' ',end='') - print() - if verbose > 1: - print('rec(%s):' % fid.encode('utf-8'), end = ' ') - else: - print('rec:', end = ' ') - rec2 = min(upper_rec, rec1 + max_words_per_line) - for idx in range(rec1, rec2): - token = result['rec'][idx] - print('{token}'.format(token = token), end = '') - for n in range(space['rec'][idx]) : - print(padding_symbol, end = '') - print(' ',end='') - print('\n', end='\n') - lab1 = lab2 - rec1 = rec2 - - if verbose: - print('===========================================================================') - print() - - result = calculator.overall() - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : - wer = 0.0 - print('Overall -> %4.2f %%' % wer, end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - if not verbose: - print() + print( + '===========================================================================' + ) + print() - if verbose: - for cluster_id in default_clusters : - result = calculator.cluster([ k for k in default_clusters[cluster_id] ]) - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : + result = calculator.overall() + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: wer = 0.0 - print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - if len(cluster_file) > 0 : # compute separated WERs for word clusters - cluster_id = '' - cluster = [] - for line in open(cluster_file, 'r', encoding='utf-8') : - for token in line.decode('utf-8').rstrip('\n').split() : - # end of cluster reached, like - if token[0:2] == '' and \ - token.lstrip('') == cluster_id : - result = calculator.cluster(cluster) - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : - wer = 0.0 - print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - cluster_id = '' - cluster = [] - # begin of cluster reached, like - elif token[0] == '<' and token[len(token)-1] == '>' and \ - cluster_id == '' : - cluster_id = token.lstrip('<').rstrip('>') - cluster = [] - # general terms, like WEATHER / CAR / ... - else : - cluster.append(token) - print() - print('===========================================================================') + print('Overall -> %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters: + result = calculator.cluster( + [k for k in default_clusters[cluster_id]]) + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if len(cluster_file) > 0: # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in open(cluster_file, 'r', encoding='utf-8'): + for token in line.decode('utf-8').rstrip('\n').split(): + # end of cluster reached, like + if token[0:2] == '' and \ + token.lstrip('') == cluster_id : + result = calculator.cluster(cluster) + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + cluster_id = '' + cluster = [] + # begin of cluster reached, like + elif token[0] == '<' and token[len(token) - 1] == '>' and \ + cluster_id == '' : + cluster_id = token.lstrip('<').rstrip('>') + cluster = [] + # general terms, like WEATHER / CAR / ... + else: + cluster.append(token) + print() + print( + '===========================================================================' + ) diff --git a/utils/espnet_json_to_manifest.py b/utils/espnet_json_to_manifest.py old mode 100644 new mode 100755 diff --git a/utils/fst/prepare_dict.py b/utils/fst/prepare_dict.py index f59cd3113fd04de078a9616dd61cc6fa0fabe44e..301d72fb08376cab1a552c1c377738746715bfe4 100755 --- a/utils/fst/prepare_dict.py +++ b/utils/fst/prepare_dict.py @@ -3,7 +3,8 @@ import argparse def main(args): - # load `unit` or `vocab` file + # load vocab file + # line: token unit_table = set() with open(args.unit_file, 'r') as fin: for line in fin: @@ -11,27 +12,41 @@ def main(args): unit_table.add(unit) def contain_oov(units): + """token not in vocab + + Args: + units (str): token + + Returns: + bool: True token in voca, else False. + """ for unit in units: if unit not in unit_table: return True return False - # load spm model + # load spm model, for English bpemode = args.bpemodel if bpemode: import sentencepiece as spm sp = spm.SentencePieceProcessor() sp.Load(sys.bpemodel) - # used to filter polyphone + # used to filter polyphone and invalid word lexicon_table = set() + in_n = 0 # in lexicon word count + out_n = 0 # out lexicon word cout with open(args.in_lexicon, 'r') as fin, \ open(args.out_lexicon, 'w') as fout: for line in fin: word = line.split()[0] + in_n += 1 + if word == 'SIL' and not bpemode: # `sil` might be a valid piece in bpemodel + # filter 'SIL' for mandarin, keep it in English continue elif word == '': + # filter continue else: # each word only has one pronunciation for e2e system @@ -39,12 +54,14 @@ def main(args): continue if bpemode: + # for english pieces = sp.EncodeAsPieces(word) if contain_oov(pieces): print('Ignoring words {}, which contains oov unit'. format(''.join(word).strip('▁'))) continue + # word is piece list, which not have piece, filter out by `contain_oov(pieces)` chars = ' '.join( [p if p in unit_table else '' for p in pieces]) else: @@ -58,11 +75,14 @@ def main(args): # we assume the model unit of our e2e system is char now. if word.encode('utf8').isalpha() and '▁' in unit_table: word = '▁' + word + chars = ' '.join(word) # word is a char list fout.write('{} {}\n'.format(word, chars)) lexicon_table.add(word) + out_n += 1 + print(f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}") if __name__ == '__main__': parser = argparse.ArgumentParser( diff --git a/utils/generate_infer_yaml.py b/utils/generate_infer_yaml.py old mode 100644 new mode 100755 diff --git a/utils/link_wav.py b/utils/link_wav.py old mode 100644 new mode 100755 diff --git a/utils/manifest_key_value.py b/utils/manifest_key_value.py index fb3d3aaaf47948428cd5eaf4a9ae6b0fe82b93e1..3825fb9b3b03103543549284b3801aa41fff45ca 100755 --- a/utils/manifest_key_value.py +++ b/utils/manifest_key_value.py @@ -26,23 +26,39 @@ def main(args): with wav_scp.open('w') as fwav, dur_scp.open('w') as fdur, text_scp.open( 'w') as ftxt: for line_json in manifest_jsons: + # utt:str + # utt2spk:str + # input: [{name:str, shape:[dur_in_sec, feat_dim], feat:str, filetype:str}, ] + # output: [{name:str, shape:[tokenlen, vocab_dim], text:str, token:str, tokenid:str}, ] utt = line_json['utt'] - feat = line_json['feat'] + utt2spk = line_json['utt2spk'] + + # input + assert (len(line_json['input']) == 1), "only support one input now" + input_json = line_json['input'][0] + feat = input_json['feat'] + feat_shape = input_json['shape'] + file_type = input_json['filetype'] + file_ext = Path(feat).suffix # .wav - text = line_json['text'] - feat_shape = line_json['feat_shape'] dur = feat_shape[0] feat_dim = feat_shape[1] - if 'token' in line_json: - tokens = line_json['token'] - tokenids = line_json['token_id'] - token_shape = line_json['token_shape'] - token_len = token_shape[0] - vocab_dim = token_shape[1] if file_ext == '.wav': fwav.write(f"{utt} {feat}\n") fdur.write(f"{utt} {dur}\n") + + # output + assert ( + len(line_json['output']) == 1), "only support one output now" + output_json = line_json['output'][0] + text = output_json['text'] + if 'token' in output_json: + tokens = output_json['token'] + tokenids = output_json['tokenid'] + token_shape = output_json['shape'] + token_len = token_shape[0] + vocab_dim = token_shape[1] ftxt.write(f"{utt} {text}\n") count += 1 diff --git a/utils/zh_tn.py b/utils/zh_tn.py index 4dcf27431400c14fcd2346c9f82c9b5a865d7fb8..73bb8af222a371f3a40197e7f80b382c81cb5f97 100755 --- a/utils/zh_tn.py +++ b/utils/zh_tn.py @@ -4,6 +4,7 @@ import argparse import re import string import sys +import unicodedata from typing import List from typing import Text @@ -33,6 +34,14 @@ POINT = [u'点', u'點'] # PLUS = [u'加', u'加'] # SIL = [u'杠', u'槓'] +FILLER_CHARS = ['呃', '啊'] + +ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \ + '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \ + '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \ + '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)' +ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST) + # 中文数字系统类型 NUMBERING_TYPES = ['low', 'mid', 'high'] @@ -48,15 +57,330 @@ COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘| # punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) CHINESE_PUNC_STOP = '!?。。' -CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏' -CHINESE_PUNC_OTHER = '·〈〉-' -CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP + CHINESE_PUNC_OTHER +CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-' +CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP + +# https://zh.wikipedia.org/wiki/全行和半行 +QJ2BJ = { + ' ': ' ', + '!': '!', + '"': '"', + '#': '#', + '$': '$', + '%': '%', + '&': '&', + ''': "'", + '(': '(', + ')': ')', + '*': '*', + '+': '+', + ',': ',', + '-': '-', + '.': '.', + '/': '/', + '0': '0', + '1': '1', + '2': '2', + '3': '3', + '4': '4', + '5': '5', + '6': '6', + '7': '7', + '8': '8', + '9': '9', + ':': ':', + ';': ';', + '<': '<', + '=': '=', + '>': '>', + '?': '?', + '@': '@', + 'A': 'A', + 'B': 'B', + 'C': 'C', + 'D': 'D', + 'E': 'E', + 'F': 'F', + 'G': 'G', + 'H': 'H', + 'I': 'I', + 'J': 'J', + 'K': 'K', + 'L': 'L', + 'M': 'M', + 'N': 'N', + 'O': 'O', + 'P': 'P', + 'Q': 'Q', + 'R': 'R', + 'S': 'S', + 'T': 'T', + 'U': 'U', + 'V': 'V', + 'W': 'W', + 'X': 'X', + 'Y': 'Y', + 'Z': 'Z', + '[': '[', + '\': '\\', + ']': ']', + '^': '^', + '_': '_', + '`': '`', + 'a': 'a', + 'b': 'b', + 'c': 'c', + 'd': 'd', + 'e': 'e', + 'f': 'f', + 'g': 'g', + 'h': 'h', + 'i': 'i', + 'j': 'j', + 'k': 'k', + 'l': 'l', + 'm': 'm', + 'n': 'n', + 'o': 'o', + 'p': 'p', + 'q': 'q', + 'r': 'r', + 's': 's', + 't': 't', + 'u': 'u', + 'v': 'v', + 'w': 'w', + 'x': 'x', + 'y': 'y', + 'z': 'z', + '{': '{', + '|': '|', + '}': '}', + '~': '~', +} + +QJ2BJ_transform = str.maketrans(''.join(QJ2BJ.keys()), ''.join(QJ2BJ.values()), + '') + +# char set +DIGIT_CHARS = '0123456789' + +EN_CHARS = ('abcdefghijklmnopqrstuvwxyz' 'ABCDEFGHIJKLMNOPQRSTUVWXYZ') + +# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表 +# raw resources from: https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt , with total 8105 chars +CN_CHARS = ('一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举' + '乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互' + '亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从' + '仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优' + '伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚' + '佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣' + '侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯' + '俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌' + '偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚' + '僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六' + '兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况' + '冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈' + '刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐' + '剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼' + '劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹' + '区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵' + '卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔' + '叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊' + '同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆' + '呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎' + '咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌' + '响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛' + '唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴' + '啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌' + '嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡' + '嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓' + '嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢' + '圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥' + '坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩' + '垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基' + '埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填' + '塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复' + '夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖' + '套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮' + '妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈' + '娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻' + '婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱' + '嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽' + '宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾' + '宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝' + '尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山' + '屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃' + '峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧' + '崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉' + '巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡' + '带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐' + '庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷' + '建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖' + '彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循' + '徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀' + '态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓' + '恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟' + '悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭' + '惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥' + '慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我' + '戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔' + '托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡' + '抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥' + '拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫' + '振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎' + '掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭' + '揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘' + '摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢' + '擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整' + '敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗' + '旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝' + '星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡' + '晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜' + '曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽' + '杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅' + '枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔' + '柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩' + '株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯' + '桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘' + '棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂' + '楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱' + '榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞' + '橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正' + '此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒' + '毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮' + '氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽' + '汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽' + '沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼' + '泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿' + '流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸' + '浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅' + '淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟' + '渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆' + '溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕' + '滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶' + '漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧' + '澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵' + '灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈' + '烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯' + '焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵' + '熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍' + '牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷' + '犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎' + '猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎' + '玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊' + '珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊' + '琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙' + '瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱' + '璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯' + '田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐' + '疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒' + '痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩' + '瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙' + '皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾' + '省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦' + '睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知' + '矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮' + '砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍' + '碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷' + '磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭' + '祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣' + '秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗' + '穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立' + '竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯' + '笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅' + '箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾' + '簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮' + '粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢' + '縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁' + '绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩' + '绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕' + '编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐' + '网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲' + '羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者' + '耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋' + '职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸' + '肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳' + '胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒' + '腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻' + '臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般' + '舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊' + '芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈' + '苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆' + '茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐' + '荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛' + '莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥' + '菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著' + '葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺' + '蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷' + '蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸' + '薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱' + '虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆' + '蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗' + '蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃' + '螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡' + '蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒' + '袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂' + '褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉' + '觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认' + '讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词' + '诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请' + '诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡' + '谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹' + '豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼' + '贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤' + '赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑' + '跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮' + '踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇' + '躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较' + '辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄' + '迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆' + '选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒' + '道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱' + '邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴' + '郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝' + '酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭' + '醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒' + '钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼' + '钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨' + '铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐' + '锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹' + '锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣' + '镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼' + '闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶' + '阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈' + '隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳' + '零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰' + '靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵' + '韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额' + '颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰' + '饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥' + '馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑' + '骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高' + '髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾' + '鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨' + '鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓' + '鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶' + '鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟' + '鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄' + '黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷' + '鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃' + '㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡' + '䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽' + '𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯' + '𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟' + '𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟' + '𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟' + '𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓') + +VALID_CHARS = CN_CHARS + EN_CHARS + DIGIT_CHARS + ' ' +VALID_CHARS_MAP = {c: True for c in VALID_CHARS} # ================================================================================ # # basic class # ================================================================================ # -class ChineseChar(): +class ChineseChar(object): """ 中文字符 每个字符对应简体和繁体, @@ -67,6 +391,7 @@ class ChineseChar(): def __init__(self, simplified, traditional): self.simplified = simplified self.traditional = traditional + #self.__repr__ = self.__str__ def __str__(self): return self.simplified or self.traditional or None @@ -83,7 +408,7 @@ class ChineseNumberUnit(ChineseChar): """ def __init__(self, power, simplified, traditional, big_s, big_t): - super().__init__(simplified, traditional) + super(ChineseNumberUnit, self).__init__(simplified, traditional) self.power = power self.big_s = big_s self.big_t = big_t @@ -144,7 +469,7 @@ class ChineseNumberDigit(ChineseChar): big_t, alt_s=None, alt_t=None): - super().__init__(simplified, traditional) + super(ChineseNumberDigit, self).__init__(simplified, traditional) self.value = value self.big_s = big_s self.big_t = big_t @@ -165,7 +490,7 @@ class ChineseMath(ChineseChar): """ def __init__(self, simplified, traditional, symbol, expression=None): - super().__init__(simplified, traditional) + super(ChineseMath, self).__init__(simplified, traditional) self.symbol = symbol self.expression = expression self.big_s = simplified @@ -175,14 +500,14 @@ class ChineseMath(ChineseChar): CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath -class NumberSystem(): +class NumberSystem(object): """ 中文数字系统 """ pass -class MathSymbol(): +class MathSymbol(object): """ 用于中文数字系统的数学符号 (繁/简体), e.g. positive = ['正', '正'] @@ -200,7 +525,7 @@ class MathSymbol(): yield v -# class OtherSymbol(): +# class OtherSymbol(object): # """ # 其他符号 # """ @@ -366,17 +691,20 @@ def num2chn(number_string, use_zeros=True, use_units=True): def get_value(value_string, use_zeros=True): + striped_string = value_string.lstrip('0') # record nothing if all zeros if not striped_string: return [] + # record one digits elif len(striped_string) == 1: if use_zeros and len(value_string) != len(striped_string): return [system.digits[0], system.digits[int(striped_string)]] else: return [system.digits[int(striped_string)]] + # recursively record multiple digits else: result_unit = next( @@ -403,7 +731,6 @@ def num2chn(number_string, result_symbols = get_value(int_string) else: result_symbols = [system.digits[int(c)] for c in int_string] - dec_symbols = [system.digits[int(c)] for c in dec_string] if dec_string: result_symbols += [system.math.point] + dec_symbols @@ -418,13 +745,12 @@ def num2chn(number_string, previous_symbol = result_symbols[i - 1] if i > 0 else None if isinstance(next_symbol, CNU) and isinstance( previous_symbol, (CNU, type(None))): - # yapf: disable - if next_symbol.power != 1 and ((previous_symbol is None) or - (previous_symbol.power != 1)): + if next_symbol.power != 1 and ( + (previous_symbol is None) or + (previous_symbol.power != 1)): result_symbols[i] = liang - # yapf: enable - # if big is True, '两' will not be used and `alt_two` has no impact on output + # if big is True, '两' will not be used and `alt_two` has no impact on output if big: attr_name = 'big_' if traditional: @@ -516,6 +842,7 @@ class TelePhone: # return self.telephone def telephone2chntext(self, fixed=False): + if fixed: sil_parts = self.telephone.split('-') self.raw_chntext = ''.join([ @@ -592,7 +919,6 @@ class Date: except ValueError: other = date year = '' - if other: try: month, day = other.strip().split('月', 1) @@ -600,13 +926,11 @@ class Date: except ValueError: day = date month = '' - if day: day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] else: month = '' day = '' - chntext = year + month + day self.chntext = chntext return self.chntext @@ -782,6 +1106,52 @@ class NSWNormalizer: return self.norm_text.lstrip('^').rstrip('$') +# ================================================================================ # +# misc normalization functions +# ================================================================================ # +def remove_erhua(text): + """ + 去除儿化音词中的儿: + 他女儿在那边儿 -> 他女儿在那边 + """ + + new_str = '' + while re.search('儿', text): + a = re.search('儿', text).span() + remove_er_flag = 0 + + if ER_WHITELIST_PATTERN.search(text): + b = ER_WHITELIST_PATTERN.search(text).span() + if b[0] <= a[0]: + remove_er_flag = 1 + + if remove_er_flag == 0: + new_str = new_str + text[0:a[0]] + text = text[a[1]:] + else: + new_str = new_str + text[0:b[1]] + text = text[b[1]:] + + text = new_str + text + return text + + +def check_chars(text): + for c in text: + if not VALID_CHARS_MAP.get(c): + return c + return '' + + +def quanjiao2banjiao(text): + return text.translate(QJ2BJ_transform) + + +# ================================================================================ # +# testing +# ================================================================================ # + + def nsw_test_case(raw_text): print('I:' + raw_text) print('O:' + NSWNormalizer(raw_text).normalize()) @@ -806,89 +1176,234 @@ def nsw_test(): nsw_test_case('有62%的概率') +###################################################################################### + + +## Normalize unicode characters +def remove_weird_chars(text): + # ``` + # (NFKD) will apply the compatibility decomposition, i.e. + # replace all compatibility characters with their equivalents. + # ``` + text = unicodedata.normalize('NFKD', text).encode('utf-8', 'ignore').decode( + 'utf-8', 'ignore') + return text + + +## Remove extra linebreaks +def remove_extra_linebreaks(text): + lines = text.split(r'\n+') + return '\n'.join( + [re.sub(r'[\s]+', ' ', l).strip() for l in lines if len(l) != 0]) + + +## Remove extra medial/trailing/leading spaces +def remove_extra_spaces(text): + return re.sub("\\s+", " ", text).strip() + + +## Seg the text into words +def seg(text): + text_seg = jieba.cut(text) + out = ' '.join(text_seg) + return out + + +## Remove punctuation/symbols +def remove_symbols(text): + """ + + Unicode 6.0 has 7 character categories, and each category has subcategories: + + Letter (L): lowercase (Ll), modifier (Lm), titlecase (Lt), uppercase (Lu), other (Lo) + Mark (M): spacing combining (Mc), enclosing (Me), non-spacing (Mn) + Number (N): decimal digit (Nd), letter (Nl), other (No) + Punctuation (P): connector (Pc), dash (Pd), initial quote (Pi), final quote (Pf), open (Ps), close (Pe), other (Po) + Symbol (S): currency (Sc), modifier (Sk), math (Sm), other (So) + Separator (Z): line (Zl), paragraph (Zp), space (Zs) + Other (C): control (Cc), format (Cf), not assigned (Cn), private use (Co), surrogate (Cs) + + + There are 3 ranges reserved for private use (Co subcategory): + U+E000—U+F8FF (6,400 code points), U+F0000—U+FFFFD (65,534) and U+100000—U+10FFFD (65,534). + Surrogates (Cs subcategory) use the range U+D800—U+DFFF (2,048 code points). + + + """ + ## Brute-force version: list all possible unicode ranges, but this list is not complete. + # text = re.sub('[\u0021-\u002f\u003a-\u0040\u005b-\u0060\u007b-\u007e\u00a1-\u00bf\u2000-\u206f\u2013-\u204a\u20a0-\u20bf\u2100-\u214f\u2150-\u218b\u2190-\u21ff\u2200-\u22ff\u2300-\u23ff\u2460-\u24ff\u2500-\u257f\u2580-\u259f\u25a0-\u25ff\u2600-\u26ff\u2e00-\u2e7f\u3000-\u303f\ufe50-\ufe6f\ufe30-\ufe4f\ufe10-\ufe1f\uff00-\uffef─◆╱]+','',text) + + text = ''.join( + ch for ch in text if unicodedata.category(ch)[0] not in ['P', 'S']) + return text + + +## Remove numbers +def remove_numbers(text): + return re.sub('\\d+', "", text) + + +## Remove alphabets +def remove_alphabets(text): + return re.sub('[a-zA-Z]+', '', text) + + +## Combine every step +def normalize_corpus(corpus, + is_remove_extra_linebreaks=True, + is_remove_weird_chars=True, + is_seg=True, + is_remove_symbols=True, + is_remove_numbers=True, + is_remove_alphabets=True): + + normalized_corpus = [] + # normalize each document in the corpus + for doc in corpus: + + if is_remove_extra_linebreaks: + doc = remove_extra_linebreaks(doc) + + if is_remove_weird_chars: + doc = remove_weird_chars(doc) + + if is_seg: + doc = seg(doc) + + if is_remove_symbols: + doc = remove_symbols(doc) + + if is_remove_alphabets: + doc = remove_alphabets(doc) + + if is_remove_numbers: + doc = remove_numbers(doc) + + normalized_corpus.append(remove_extra_spaces(doc)) + + return normalized_corpus + + +###################################################################################### + + def char_token(s: Text) -> List[Text]: """chinese charactor - Args: - s (Text): [description] + s (Text): "我爱中国“ Returns: - List[Text]: [description] + List[Text]: ['我', '爱', '中', '国'] """ return list(s) def word_token(s: Text) -> List[Text]: """chinese word - Args: - s (Text): [description] + s (Text): "我爱中国“ Returns: - List[Text]: [description] + List[Text]: ['我', '爱', '中国'] """ return jieba.lcut(s) -def text_process(s: Text) -> Text: +def find_chinese(file): + pattern = re.compile(r'[^\u4e00-\u9fa5]') + chinese = re.sub(pattern, '', file) + return chinese + + +def text_process(text: Text, args) -> Text: """do chinese text normaliztion + 1. remove * + 2. NWS + 3. remove puncuation + 4. remove english Args: - s (Text): [description] + text (Text): [description] Returns: Text: [description] """ - s = s.replace('*', '') + # strip + text = text.strip() + text = remove_extra_linebreaks(text) + text = remove_weird_chars(text) + text = remove_extra_spaces(text) + + # quanjiao -> banjiao + if args.to_banjiao: + text = quanjiao2banjiao(text) + + # Unify upper/lower cases + if args.to_upper: + text = text.upper() + if args.to_lower: + text = text.lower() + + # Remove filler chars + if args.remove_fillers: + for c in FILLER_CHARS: + text = text.replace(c, '') + + if args.remove_erhua: + text = remove_erhua(text) + + text = text.replace('*', '') + # NSW(Non-Standard-Word) normalization - s = NSWNormalizer(s).normalize() + text = NSWNormalizer(text).normalize() + if len(text) == 0: + exit(-1) + # Punctuations removal - s = re.sub(f'[{hanzi.punctuation}{string.punctuation}]', "", s) + text = re.sub(f'[{hanzi.punctuation}{string.punctuation}]', "", text) + + # Remove punctuations + old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations + new_chars = ' ' * len(old_chars) + del_chars = '' + text = text.translate(str.maketrans(old_chars, new_chars, del_chars)) + # rm english - s = ''.join(re.findall(hanzi.sent, s)) - return s + text = find_chinese(text) + + # Remove space + if args.remove_space: + text = text.replace(' ', '') + + return text def main(infile, outfile, args): # tokenizer token_type = args.token_type - if token_type == 'char': + if token_type.lower() == 'char': tokenizer = char_token - elif token_type == 'word': + elif token_type.lower() == 'word': tokenizer = word_token else: tokenizer = None with open(infile, 'rt') as fin, open(outfile, 'wt') as fout: - lines = fin.readlines() - n = 0 - for l in lines: - key = '' - text = '' + ndone = 0 + for line in fin: + line = line.strip() + key, text = '', '' if args.has_key: - cols = l.split(maxsplit=1) + cols = line.split(maxsplit=1) key = cols[0] - if len(cols) == 2: - text = cols[1] - else: - text = '' + text = cols[1] if len(cols) == 2 else '' else: - text = l - - # strip - text = text.strip() - # cases - if args.to_upper and args.to_lower: - sys.stderr.write('to_upper OR to_lower?') - exit(1) - if args.to_upper: - text = text.upper() - if args.to_lower: - text = text.lower() - - # Normalization - text = text_process(text) + text = line + + text = text_process(text, args) + + # word segment: chinese char/word if tokenizer: text = ' '.join(tokenizer(text)) @@ -899,29 +1414,56 @@ def main(infile, outfile, args): ) != '': # skip empty line in pure text format(without Kaldi's utt key) fout.write(text + '\n') - n += 1 - if n % args.log_interval == 0: - print(f"process {n} lines.", file=sys.stderr) + ndone += 1 + if ndone % args.log_interval == 0: + print( + f'text norm: {ndone} lines done.', + file=sys.stderr, + flush=True) + + print( + f'text norm: {ndone} lines done in total.', + file=sys.stderr, + flush=True) if __name__ == '__main__': p = argparse.ArgumentParser() - p.add_argument('token_type', default=None, help='token type. [char|word]') - p.add_argument('ifile', help='input filename, assume utf-8 encoding') - p.add_argument('ofile', help='output filename') - p.add_argument( - '--to_upper', action='store_true', help='convert to upper case') - p.add_argument( - '--to_lower', action='store_true', help='convert to lower case') + p.add_argument('--token_type', default=None, help='token type. [char|word]') p.add_argument( '--has_key', - action='store_true', + default=False, help="input text has Kaldi's key as first field.") p.add_argument( '--log_interval', type=int, - default=100000, + default=10000, help='log interval in number of processed lines') - args = p.parse_args() + p.add_argument( + '--to_banjiao', + action='store_true', + help='convert quanjiao chars to banjiao') + p.add_argument( + '--to_upper', action='store_true', help='convert to upper case') + p.add_argument( + '--to_lower', action='store_true', help='convert to lower case') + p.add_argument( + '--remove_fillers', + action='store_true', + help='remove filler chars such as "呃, 啊"') + p.add_argument( + '--remove_erhua', + action='store_true', + help='remove erhua chars such as "他女儿在那边儿 -> 他女儿在那边"') + p.add_argument( + '--check_chars', + action='store_true', + help='skip sentences containing illegal chars') + p.add_argument( + '--remove_space', action='store_true', help='remove whitespace') + p.add_argument('ifile', help='input filename, assume utf-8 encoding') + p.add_argument('ofile', help='output filename') + args = p.parse_args() + print(args) main(args.ifile, args.ofile, args)