Unverified commit 4a11257d, authored by liangym, committed by GitHub

Merge branch 'develop' into update_engine

@@ -19,7 +19,7 @@ import subprocess
 import platform
 COPYRIGHT = '''
-Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
......
@@ -24,6 +24,8 @@
 | <a href="#documents"> Documents </a>
 | <a href="#model-list"> Models List </a>
 | <a href="https://aistudio.baidu.com/aistudio/education/group/info/25130"> AIStudio Courses </a>
+| <a href="https://arxiv.org/abs/2205.12007"> Paper </a>
+| <a href="https://gitee.com/paddlepaddle/PaddleSpeech"> Gitee </a>
 </h4>
 </div>
......
@@ -25,6 +25,8 @@
 | <a href="#教程文档"> Tutorials </a>
 | <a href="#模型列表"> Models List </a>
 | <a href="https://aistudio.baidu.com/aistudio/education/group/info/25130"> AIStudio Courses </a>
+| <a href="https://arxiv.org/abs/2205.12007"> Paper </a>
+| <a href="https://gitee.com/paddlepaddle/PaddleSpeech"> Gitee </a>
 </h4>
 </div>
......
@@ -11,6 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .dtw import dtw_distance
 from .eer import compute_eer
 from .eer import compute_minDCF
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from dtaidistance import dtw_ndim
__all__ = [
    'dtw_distance',
]


def dtw_distance(xs: np.ndarray, ys: np.ndarray) -> float:
    """Dynamic Time Warping.
    This function keeps a compact matrix, not the full warping paths matrix.
    Uses dynamic programming to compute:

    Examples:
        .. code-block:: python

            wps[i, j] = (s1[i]-s2[j])**2 + min(
                            wps[i-1, j  ] + penalty,  // vertical   / insertion / expansion
                            wps[i  , j-1] + penalty,  // horizontal / deletion  / compression
                            wps[i-1, j-1])            // diagonal   / match

            dtw = sqrt(wps[-1, -1])

    Args:
        xs (np.ndarray): ref sequence, [T,D]
        ys (np.ndarray): hyp sequence, [T,D]

    Returns:
        float: dtw distance
    """
    return dtw_ndim.distance(xs, ys)
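For reference, a minimal sketch of how this metric is used, assuming the `dtaidistance` package is installed; the shapes and random data below are purely illustrative:

```python
import numpy as np
from dtaidistance import dtw_ndim

# Two multi-dimensional sequences of different lengths,
# e.g. 80-dim mel-spectrogram frames.
ref = np.random.randn(100, 80)  # reference sequence, [T1, D]
hyp = np.random.randn(120, 80)  # hypothesis sequence, [T2, D]

# dtw_ndim.distance aligns the sequences with dynamic programming
# and returns the accumulated DTW distance as a single float.
print(dtw_ndim.distance(ref, hyp))
```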
@@ -83,7 +83,7 @@ setuptools.setup(
     python_requires='>=3.6',
     install_requires=[
         'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2',
-        'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'pathos'
+        'soundfile >= 0.9.0', 'colorlog', 'pathos == 0.2.8'
     ],
     extras_require={
         'test': [
......
@@ -2,14 +2,14 @@
 ([简体中文](./README_cn.md)|English)
-The directory containes many speech applications in multi scenarios.
+This directory contains many speech applications in multiple scenarios.
 * audio searching - mass audio similarity retrieval
 * audio tagging - multi-label tagging of an audio file
-* automatic_video_subtitiles - generate subtitles from a video
+* automatic_video_subtitles - generate subtitles from a video
 * metaverse - 2D AR with TTS
 * punctuation_restoration - restore punctuation from raw text
-* speech recogintion - recognize text of an audio file
+* speech recognition - recognize text of an audio file
 * speech server - Server for Speech Task, e.g. ASR,TTS,CLS
 * streaming asr server - receive audio stream from websocket, and recognize to transcript.
 * speech translation - end to end speech translation
......
@@ -14,7 +14,7 @@
 import numpy as np
 from logs import LOGGER
-from paddlespeech.cli import VectorExecutor
+from paddlespeech.cli.vector import VectorExecutor
 vector_executor = VectorExecutor()
......
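As a quick sanity check of the new import path, a minimal sketch of extracting a speaker embedding; the file name follows the demo's sample audio and the default model is assumed:

```python
from paddlespeech.cli.vector import VectorExecutor

vector_executor = VectorExecutor()
# Returns the utterance-level speaker embedding as a numpy array.
audio_emb = vector_executor(audio_file='./85236145389.wav')
print(audio_emb.shape)
```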
@@ -57,7 +57,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe...
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import CLSExecutor
+from paddlespeech.cli.cls import CLSExecutor
 cls_executor = CLSExecutor()
 result = cls_executor(
@@ -57,7 +57,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespe...
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import CLSExecutor
+from paddlespeech.cli.cls import CLSExecutor
 cls_executor = CLSExecutor()
 result = cls_executor(
......
@@ -28,7 +28,8 @@ ffmpeg -i subtitle_demo1.mp4 -ac 1 -ar 16000 -vn input.wav
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import ASRExecutor, TextExecutor
+from paddlespeech.cli.asr import ASRExecutor
+from paddlespeech.cli.text import TextExecutor
 asr_executor = ASRExecutor()
 text_executor = TextExecutor()
......
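A completed sketch of the subtitle pipeline with the two new imports, recognizing first and then restoring punctuation; `input.wav` refers to the ffmpeg output above, and the keyword arguments are assumptions based on the CLI executors:

```python
import paddle
from paddlespeech.cli.asr import ASRExecutor
from paddlespeech.cli.text import TextExecutor

asr_executor = ASRExecutor()
text_executor = TextExecutor()

# Recognize the 16 kHz mono wav extracted by ffmpeg,
# then restore punctuation on the raw transcript.
text = asr_executor(audio_file='./input.wav')
result = text_executor(text=text, task='punc')
print(result)
```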
@@ -23,7 +23,8 @@ ffmpeg -i subtitle_demo1.mp4 -ac 1 -ar 16000 -vn input.wav
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import ASRExecutor, TextExecutor
+from paddlespeech.cli.asr import ASRExecutor
+from paddlespeech.cli.text import TextExecutor
 asr_executor = ASRExecutor()
 text_executor = TextExecutor()
......
@@ -16,8 +16,8 @@ import os
 import paddle
-from paddlespeech.cli import ASRExecutor
-from paddlespeech.cli import TextExecutor
+from paddlespeech.cli.asr import ASRExecutor
+from paddlespeech.cli.text import TextExecutor
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
......
@@ -42,7 +42,7 @@ The input of this demo should be a text of the specific language that can be pas...
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import TextExecutor
+from paddlespeech.cli.text import TextExecutor
 text_executor = TextExecutor()
 result = text_executor(
......
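For completeness, a minimal sketch of the punctuation-restoration call under the new import; the sample sentence and `task='punc'` mirror this demo's usage:

```python
import paddle
from paddlespeech.cli.text import TextExecutor

text_executor = TextExecutor()
# Restore punctuation for a raw (unpunctuated) Chinese sentence.
result = text_executor(
    text='今天的天气真不错啊你下午有空吗我想约你一起去吃饭',
    task='punc')
print(result)
```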
@@ -44,7 +44,7 @@
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import TextExecutor
+from paddlespeech.cli.text import TextExecutor
 text_executor = TextExecutor()
 result = text_executor(
......
@@ -96,7 +96,7 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
 - Python API
 ```python
-from paddlespeech.cli import VectorExecutor
+from paddlespeech.cli.vector import VectorExecutor
 vector_executor = VectorExecutor()
 audio_emb = vector_executor(
......
@@ -95,7 +95,7 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import VectorExecutor
+from paddlespeech.cli.vector import VectorExecutor
 vector_executor = VectorExecutor()
 audio_emb = vector_executor(
......
@@ -58,7 +58,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee...
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import ASRExecutor
+from paddlespeech.cli.asr import ASRExecutor
 asr_executor = ASRExecutor()
 text = asr_executor(
......
@@ -56,7 +56,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee...
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import ASRExecutor
+from paddlespeech.cli.asr import ASRExecutor
 asr_executor = ASRExecutor()
 text = asr_executor(
......
@@ -47,7 +47,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee...
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import STExecutor
+from paddlespeech.cli.st import STExecutor
 st_executor = STExecutor()
 text = st_executor(
......
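A filled-in sketch of the new speech-translation import in use; the English sample file name is an assumption based on this demo's downloads, and the default en-to-zh model is assumed:

```python
import paddle
from paddlespeech.cli.st import STExecutor

st_executor = STExecutor()
# Translate English speech into Chinese text with the default model.
text = st_executor(audio_file='./en.wav')
print(text)
```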
@@ -47,7 +47,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee...
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import STExecutor
+from paddlespeech.cli.st import STExecutor
 st_executor = STExecutor()
 text = st_executor(
......
@@ -4,7 +4,7 @@
 # SERVER SETTING #
 #################################################################################
 host: 0.0.0.0
-port: 8090
+port: 8091
 # The task format in the engin_list is: <speech task>_<engine type>
 # task choices = ['asr_online']
......
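Clients must follow the port change as well; a hypothetical check with the streaming client CLI, assuming the flag names from the client docs:

```bash
# Point the streaming ASR websocket client at the new port 8091 (was 8090).
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8091 --input ./zh.wav
```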
@@ -77,7 +77,7 @@ The input of this demo should be a text of the specific language that can be pas...
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import TTSExecutor
+from paddlespeech.cli.tts import TTSExecutor
 tts_executor = TTSExecutor()
 wav_file = tts_executor(
......
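Completed under the new import path, a minimal sketch; the sentence and output file name are illustrative only:

```python
import paddle
from paddlespeech.cli.tts import TTSExecutor

tts_executor = TTSExecutor()
# Synthesize the sentence and write a wav file; returns the output path.
wav_file = tts_executor(text='今天的天气不错啊', output='output.wav')
print(wav_file)
```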
@@ -80,7 +80,7 @@
 - Python API
 ```python
 import paddle
-from paddlespeech.cli import TTSExecutor
+from paddlespeech.cli.tts import TTSExecutor
 tts_executor = TTSExecutor()
 wav_file = tts_executor(
......
@@ -88,11 +88,3 @@ ECAPA-TDNN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/Padd...
 Model Type | Dataset| Example Link | Pretrained Models
 :-------------:| :------------:| :-----: | :-----:
 Ernie Linear | IWLST2012_zh |[iwslt2012_punc0](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/iwslt2012/punc0)|[ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_iwslt2012_zh_ckpt_0.1.1.zip)
-## Speech Recognition Model from paddle 1.8
-| Acoustic Model |Training Data| Token-based | Size | Descriptions | CER | WER | Hours of speech |
-| :-----:| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
-| [Ds2 Offline Aishell model](https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz) | Aishell Dataset | Char-based | 234 MB | 2 Conv + 3 bidirectional GRU layers | 0.0804 | — | 151 h |
-| [Ds2 Offline Librispeech model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz) | Librispeech Dataset | Word-based | 307 MB | 2 Conv + 3 bidirectional sharing weight RNN layers | — | 0.0685 | 960 h |
-| [Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz) | Baidu Internal English Dataset | Word-based | 273 MB | 2 Conv + 3 bidirectional GRU layers | — | 0.0541 | 8628 h |
@@ -113,12 +113,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p...
 ```
 ```text
 usage: synthesize.py [-h]
-                     [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}]
+                     [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}]
                      [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                      [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                      [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT]
                      [--voice-cloning VOICE_CLONING]
-                     [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}]
+                     [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}]
                      [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                      [--voc_stat VOC_STAT] [--ngpu NGPU]
                      [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
@@ -127,11 +127,10 @@ Synthesize with acoustic model & vocoder
 optional arguments:
   -h, --help            show this help message and exit
-  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}
+  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}
                         Choose acoustic model type of tts task.
   --am_config AM_CONFIG
-                        Config of acoustic model. Use deault config when it is
-                        None.
+                        Config of acoustic model.
   --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
   --am_stat AM_STAT     mean and standard deviation used to normalize
                         spectrogram when training acoustic model.
@@ -143,10 +142,10 @@ optional arguments:
                         speaker id map file.
   --voice-cloning VOICE_CLONING
                         whether training voice cloning model.
-  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}
+  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}
                         Choose vocoder type of tts task.
   --voc_config VOC_CONFIG
-                        Config of voc. Use deault config when it is None.
+                        Config of voc.
   --voc_ckpt VOC_CKPT   Checkpoint file of voc.
   --voc_stat VOC_STAT   mean and standard deviation used to normalize
                         spectrogram when training voc.
@@ -162,12 +161,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp...
 ```
 ```text
 usage: synthesize_e2e.py [-h]
-                         [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}]
+                         [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}]
                          [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                          [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                          [--tones_dict TONES_DICT]
                          [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
-                         [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}]
+                         [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}]
                          [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                          [--voc_stat VOC_STAT] [--lang LANG]
                          [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
@@ -177,11 +176,10 @@ Synthesize with acoustic model & vocoder
 optional arguments:
   -h, --help            show this help message and exit
-  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}
+  --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}
                         Choose acoustic model type of tts task.
   --am_config AM_CONFIG
-                        Config of acoustic model. Use deault config when it is
-                        None.
+                        Config of acoustic model.
   --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
   --am_stat AM_STAT     mean and standard deviation used to normalize
                         spectrogram when training acoustic model.
@@ -192,10 +190,10 @@ optional arguments:
   --speaker_dict SPEAKER_DICT
                         speaker id map file.
   --spk_id SPK_ID       spk id for multi speaker acoustic model
-  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}
+  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}
                         Choose vocoder type of tts task.
   --voc_config VOC_CONFIG
-                        Config of voc. Use deault config when it is None.
+                        Config of voc.
   --voc_ckpt VOC_CKPT   Checkpoint file of voc.
   --voc_stat VOC_STAT   mean and standard deviation used to normalize
                         spectrogram when training voc.
@@ -208,9 +206,9 @@ optional arguments:
                         output dir.
 ```
 1. `--am` is acoustic model type with the format {model_name}_{dataset}
-2. `--am_config`, `--am_checkpoint`, `--am_stat`, `--phones_dict` `--speaker_dict` are arguments for acoustic model, which correspond to the 5 files in the fastspeech2 pretrained model.
+2. `--am_config`, `--am_ckpt`, `--am_stat`, `--phones_dict` `--speaker_dict` are arguments for acoustic model, which correspond to the 5 files in the fastspeech2 pretrained model.
 3. `--voc` is vocoder type with the format {model_name}_{dataset}
-4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
+4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
 5. `--lang` is the model language, which can be `zh` or `en`.
 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
 7. `--text` is the text file, which contains sentences to synthesize.
......
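To make the argument list above concrete, a hypothetical `synthesize_e2e.py` invocation with one of the newly added model pairs; every file path here is a placeholder:

```bash
python3 synthesize_e2e.py \
    --am fastspeech2_csmsc --am_config fastspeech2_csmsc.yaml \
    --am_ckpt fastspeech2_snapshot.pdz --am_stat speech_stats.npy \
    --phones_dict phone_id_map.txt \
    --voc hifigan_csmsc --voc_config hifigan_csmsc.yaml \
    --voc_ckpt hifigan_snapshot.pdz --voc_stat feats_stats.npy \
    --lang zh --text sentences.txt --output_dir output --ngpu 1
```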
@@ -68,7 +68,7 @@ Train a ParallelWaveGAN model.
 optional arguments:
   -h, --help            show this help message and exit
-  --config CONFIG       config file to overwrite default config.
+  --config CONFIG       ParallelWaveGAN config file.
   --train-metadata TRAIN_METADATA
                         training data.
   --dev-metadata DEV_METADATA
......
@@ -59,15 +59,13 @@ Here's the complete help message.
 ```text
 usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
                 [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
-                [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
-                [--run-benchmark RUN_BENCHMARK]
-                [--profiler_options PROFILER_OPTIONS]
+                [--ngpu NGPU]
-Train a ParallelWaveGAN model.
+Train a HiFiGAN model.
 optional arguments:
   -h, --help            show this help message and exit
-  --config CONFIG       config file to overwrite default config.
+  --config CONFIG       HiFiGAN config file.
   --train-metadata TRAIN_METADATA
                         training data.
   --dev-metadata DEV_METADATA
@@ -75,19 +73,6 @@ optional arguments:
   --output-dir OUTPUT_DIR
                         output dir.
   --ngpu NGPU           if ngpu == 0, use cpu.
-benchmark:
-  arguments related to benchmark.
-  --batch-size BATCH_SIZE
-                        batch size.
-  --max-iter MAX_ITER   train max steps.
-  --run-benchmark RUN_BENCHMARK
-                        runing benchmark or not, if True, use the --batch-size
-                        and --max-iter.
-  --profiler_options PROFILER_OPTIONS
-                        The option of profiler, which should be in format
-                        "key1=value1;key2=value2;key3=value3".
 ```
 1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
......
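For the `--config` note above, a hypothetical override file: it only needs the keys you want to change, and everything else is taken from `conf/default.yaml` (the key names below are assumptions):

```yaml
# Hypothetical conf/my_override.yaml: change only what differs from default.
batch_size: 16
train_max_steps: 1000000
```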
@@ -103,12 +103,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p...
 ```
 ```text
 usage: synthesize.py [-h]
-                     [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}]
+                     [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}]
                      [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                      [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                      [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT]
                      [--voice-cloning VOICE_CLONING]
-                     [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}]
+                     [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}]
                      [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                      [--voc_stat VOC_STAT] [--ngpu NGPU]
                      [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
@@ -117,11 +117,10 @@ Synthesize with acoustic model & vocoder
 optional arguments:
   -h, --help            show this help message and exit
-  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}
+  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}
                         Choose acoustic model type of tts task.
   --am_config AM_CONFIG
-                        Config of acoustic model. Use deault config when it is
-                        None.
+                        Config of acoustic model.
   --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
   --am_stat AM_STAT     mean and standard deviation used to normalize
                         spectrogram when training acoustic model.
@@ -133,10 +132,10 @@ optional arguments:
                         speaker id map file.
   --voice-cloning VOICE_CLONING
                         whether training voice cloning model.
-  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}
+  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}
                         Choose vocoder type of tts task.
   --voc_config VOC_CONFIG
-                        Config of voc. Use deault config when it is None.
+                        Config of voc.
   --voc_ckpt VOC_CKPT   Checkpoint file of voc.
   --voc_stat VOC_STAT   mean and standard deviation used to normalize
                         spectrogram when training voc.
@@ -152,12 +151,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp...
 ```
 ```text
 usage: synthesize_e2e.py [-h]
-                         [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}]
+                         [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}]
                          [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                          [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                          [--tones_dict TONES_DICT]
                          [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
-                         [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc}]
+                         [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}]
                          [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                          [--voc_stat VOC_STAT] [--lang LANG]
                          [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
@@ -167,11 +166,10 @@ Synthesize with acoustic model & vocoder
 optional arguments:
   -h, --help            show this help message and exit
-  --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}
+  --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}
                         Choose acoustic model type of tts task.
   --am_config AM_CONFIG
-                        Config of acoustic model. Use deault config when it is
-                        None.
+                        Config of acoustic model.
   --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
   --am_stat AM_STAT     mean and standard deviation used to normalize
                         spectrogram when training acoustic model.
@@ -182,10 +180,10 @@ optional arguments:
   --speaker_dict SPEAKER_DICT
                         speaker id map file.
   --spk_id SPK_ID       spk id for multi speaker acoustic model
-  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc}
+  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}
                         Choose vocoder type of tts task.
   --voc_config VOC_CONFIG
-                        Config of voc. Use deault config when it is None.
+                        Config of voc.
   --voc_ckpt VOC_CKPT   Checkpoint file of voc.
   --voc_stat VOC_STAT   mean and standard deviation used to normalize
                         spectrogram when training voc.
@@ -198,9 +196,9 @@ optional arguments:
                         output dir.
 ```
 1. `--am` is acoustic model type with the format {model_name}_{dataset}
-2. `--am_config`, `--am_checkpoint`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the Tacotron2 pretrained model.
+2. `--am_config`, `--am_ckpt`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the Tacotron2 pretrained model.
 3. `--voc` is vocoder type with the format {model_name}_{dataset}
-4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
+4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
 5. `--lang` is the model language, which can be `zh` or `en`.
 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
 7. `--text` is the text file, which contains sentences to synthesize.
......
@@ -109,12 +109,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p...
 ```
 ```text
 usage: synthesize.py [-h]
-                     [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}]
+                     [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}]
                      [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                      [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                      [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT]
                      [--voice-cloning VOICE_CLONING]
-                     [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}]
+                     [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}]
                      [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                      [--voc_stat VOC_STAT] [--ngpu NGPU]
                      [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
@@ -123,11 +123,10 @@ Synthesize with acoustic model & vocoder
 optional arguments:
   -h, --help            show this help message and exit
-  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}
+  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}
                         Choose acoustic model type of tts task.
   --am_config AM_CONFIG
-                        Config of acoustic model. Use deault config when it is
-                        None.
+                        Config of acoustic model.
   --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
   --am_stat AM_STAT     mean and standard deviation used to normalize
                         spectrogram when training acoustic model.
@@ -139,10 +138,10 @@ optional arguments:
                         speaker id map file.
   --voice-cloning VOICE_CLONING
                         whether training voice cloning model.
-  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}
+  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}
                         Choose vocoder type of tts task.
   --voc_config VOC_CONFIG
-                        Config of voc. Use deault config when it is None.
+                        Config of voc.
   --voc_ckpt VOC_CKPT   Checkpoint file of voc.
   --voc_stat VOC_STAT   mean and standard deviation used to normalize
                         spectrogram when training voc.
@@ -160,12 +159,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp...
 ```
 ```text
 usage: synthesize_e2e.py [-h]
-                         [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}]
+                         [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}]
                          [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                          [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                          [--tones_dict TONES_DICT]
                          [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
-                         [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}]
+                         [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}]
                          [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                          [--voc_stat VOC_STAT] [--lang LANG]
                          [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
@@ -175,11 +174,10 @@ Synthesize with acoustic model & vocoder
 optional arguments:
   -h, --help            show this help message and exit
-  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}
+  --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}
                         Choose acoustic model type of tts task.
   --am_config AM_CONFIG
-                        Config of acoustic model. Use deault config when it is
-                        None.
+                        Config of acoustic model.
   --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
   --am_stat AM_STAT     mean and standard deviation used to normalize
                         spectrogram when training acoustic model.
@@ -190,10 +188,10 @@ optional arguments:
   --speaker_dict SPEAKER_DICT
                         speaker id map file.
   --spk_id SPK_ID       spk id for multi speaker acoustic model
-  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}
+  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}
                         Choose vocoder type of tts task.
   --voc_config VOC_CONFIG
-                        Config of voc. Use deault config when it is None.
+                        Config of voc.
   --voc_ckpt VOC_CKPT   Checkpoint file of voc.
   --voc_stat VOC_STAT   mean and standard deviation used to normalize
                         spectrogram when training voc.
@@ -204,9 +202,9 @@ optional arguments:
                         output dir.
 ```
 1. `--am` is acoustic model type with the format {model_name}_{dataset}
-2. `--am_config`, `--am_checkpoint`, `--am_stat`, `--phones_dict` and `--tones_dict` are arguments for acoustic model, which correspond to the 5 files in the speedyspeech pretrained model.
+2. `--am_config`, `--am_ckpt`, `--am_stat`, `--phones_dict` and `--tones_dict` are arguments for acoustic model, which correspond to the 5 files in the speedyspeech pretrained model.
 3. `--voc` is vocoder type with the format {model_name}_{dataset}
-4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
+4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
 5. `--lang` is the model language, which can be `zh` or `en`.
 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
 7. `--text` is the text file, which contains sentences to synthesize.
......
@@ -111,12 +111,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p...
 ```
 ```text
 usage: synthesize.py [-h]
-                     [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}]
+                     [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}]
                      [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                      [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                      [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT]
                      [--voice-cloning VOICE_CLONING]
-                     [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}]
+                     [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}]
                      [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                      [--voc_stat VOC_STAT] [--ngpu NGPU]
                      [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
@@ -125,11 +125,10 @@ Synthesize with acoustic model & vocoder
 optional arguments:
   -h, --help            show this help message and exit
-  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}
+  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}
                         Choose acoustic model type of tts task.
   --am_config AM_CONFIG
-                        Config of acoustic model. Use deault config when it is
-                        None.
+                        Config of acoustic model.
   --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
   --am_stat AM_STAT     mean and standard deviation used to normalize
                         spectrogram when training acoustic model.
@@ -141,10 +140,10 @@ optional arguments:
                         speaker id map file.
   --voice-cloning VOICE_CLONING
                         whether training voice cloning model.
-  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}
+  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}
                         Choose vocoder type of tts task.
   --voc_config VOC_CONFIG
-                        Config of voc. Use deault config when it is None.
+                        Config of voc.
   --voc_ckpt VOC_CKPT   Checkpoint file of voc.
   --voc_stat VOC_STAT   mean and standard deviation used to normalize
                         spectrogram when training voc.
@@ -160,12 +159,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp...
 ```
 ```text
 usage: synthesize_e2e.py [-h]
-                         [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}]
+                         [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}]
                          [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                          [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                          [--tones_dict TONES_DICT]
                          [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
-                         [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}]
+                         [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}]
                          [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                          [--voc_stat VOC_STAT] [--lang LANG]
                          [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
@@ -175,11 +174,10 @@ Synthesize with acoustic model & vocoder
 optional arguments:
   -h, --help            show this help message and exit
-  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}
+  --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}
                         Choose acoustic model type of tts task.
   --am_config AM_CONFIG
-                        Config of acoustic model. Use deault config when it is
-                        None.
+                        Config of acoustic model.
   --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
   --am_stat AM_STAT     mean and standard deviation used to normalize
                         spectrogram when training acoustic model.
@@ -190,10 +188,10 @@ optional arguments:
   --speaker_dict SPEAKER_DICT
                         speaker id map file.
   --spk_id SPK_ID       spk id for multi speaker acoustic model
-  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}
+  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}
                         Choose vocoder type of tts task.
   --voc_config VOC_CONFIG
-                        Config of voc. Use deault config when it is None.
+                        Config of voc.
   --voc_ckpt VOC_CKPT   Checkpoint file of voc.
   --voc_stat VOC_STAT   mean and standard deviation used to normalize
                         spectrogram when training voc.
@@ -204,11 +202,12 @@ optional arguments:
   --text TEXT           text to synthesize, a 'utt_id sentence' pair per line.
   --output_dir OUTPUT_DIR
                         output dir.
 ```
 1. `--am` is acoustic model type with the format {model_name}_{dataset}
-2. `--am_config`, `--am_checkpoint`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model.
+2. `--am_config`, `--am_ckpt`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model.
 3. `--voc` is vocoder type with the format {model_name}_{dataset}
-4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
+4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
 5. `--lang` is the model language, which can be `zh` or `en`.
 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
 7. `--text` is the text file, which contains sentences to synthesize.
......
...@@ -117,12 +117,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ...@@ -117,12 +117,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p
``` ```
```text ```text
usage: synthesize.py [-h] usage: synthesize.py [-h]
[--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT]
[--voice-cloning VOICE_CLONING] [--voice-cloning VOICE_CLONING]
[--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--ngpu NGPU] [--voc_stat VOC_STAT] [--ngpu NGPU]
[--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
...@@ -131,11 +131,10 @@ Synthesize with acoustic model & vocoder ...@@ -131,11 +131,10 @@ Synthesize with acoustic model & vocoder
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}
Choose acoustic model type of tts task. Choose acoustic model type of tts task.
--am_config AM_CONFIG --am_config AM_CONFIG
Config of acoustic model. Use deault config when it is Config of acoustic model.
None.
--am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize --am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model. spectrogram when training acoustic model.
...@@ -147,10 +146,10 @@ optional arguments: ...@@ -147,10 +146,10 @@ optional arguments:
speaker id map file. speaker id map file.
--voice-cloning VOICE_CLONING --voice-cloning VOICE_CLONING
whether training voice cloning model. whether training voice cloning model.
--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}
Choose vocoder type of tts task. Choose vocoder type of tts task.
--voc_config VOC_CONFIG --voc_config VOC_CONFIG
Config of voc. Use deault config when it is None. Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize --voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc. spectrogram when training voc.
...@@ -167,12 +166,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ...@@ -167,12 +166,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp
``` ```
```text ```text
usage: synthesize_e2e.py [-h] usage: synthesize_e2e.py [-h]
[--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--tones_dict TONES_DICT] [--tones_dict TONES_DICT]
[--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
[--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--lang LANG] [--voc_stat VOC_STAT] [--lang LANG]
[--inference_dir INFERENCE_DIR] [--ngpu NGPU] [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
...@@ -182,11 +181,10 @@ Synthesize with acoustic model & vocoder ...@@ -182,11 +181,10 @@ Synthesize with acoustic model & vocoder
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}
Choose acoustic model type of tts task. Choose acoustic model type of tts task.
--am_config AM_CONFIG --am_config AM_CONFIG
Config of acoustic model. Use deault config when it is Config of acoustic model.
None.
--am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize --am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model. spectrogram when training acoustic model.
...@@ -197,10 +195,10 @@ optional arguments: ...@@ -197,10 +195,10 @@ optional arguments:
--speaker_dict SPEAKER_DICT --speaker_dict SPEAKER_DICT
speaker id map file. speaker id map file.
--spk_id SPK_ID spk id for multi speaker acoustic model --spk_id SPK_ID spk id for multi speaker acoustic model
--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}
Choose vocoder type of tts task. Choose vocoder type of tts task.
--voc_config VOC_CONFIG --voc_config VOC_CONFIG
Config of voc. Use default config when it is None. Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize --voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc. spectrogram when training voc.
...@@ -213,9 +211,9 @@ optional arguments: ...@@ -213,9 +211,9 @@ optional arguments:
output dir. output dir.
``` ```
1. `--am` is the acoustic model type with the format {model_name}_{dataset}
2. `--am_config`, `--am_ckpt`, `--am_stat`, and `--phones_dict` are arguments for the acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model.
3. `--voc` is the vocoder type with the format {model_name}_{dataset}
4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for the vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
5. `--lang` is the model language, which can be `zh` or `en`.
6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
7. `--text` is the text file, which contains sentences to synthesize.
......
# VITS with CSMSC
This example contains code used to train a [VITS](https://arxiv.org/abs/2106.06103) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).
## Dataset
### Download and Extract
Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/source).
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for VITS; the durations from MFA are not needed here.
You can download [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.
## Get Started
Assume the path to the dataset is `~/datasets/BZNSYP`.
Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
- synthesize waveform from `metadata.jsonl`.
- synthesize waveform from a text file.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to run a single stage. For example, the following command only preprocesses the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
When it is done, a `dump` folder is created in the current directory. The structure of the `dump` folder is listed below.
```text
dump
├── dev
│   ├── norm
│   └── raw
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│   ├── norm
│   └── raw
└── train
├── feats_stats.npy
├── norm
└── raw
```
The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the wave and linear spectrogram of each utterance, while the `norm` folder contains the normalized ones. The statistics used to normalize features are computed from the training set, which is located in `dump/train/feats_stats.npy`.
Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, feats, feats_lengths, the path of linear spectrogram features, the path of raw waves, speaker, and the id of each utterance.
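For a quick sanity check of the preprocessing output, the metadata can be read with the `jsonlines` package. This is a minimal sketch; the exact field set follows the description above and may vary between examples.
```python
import jsonlines

# Peek at the first record of the normalized training metadata
# (a sketch, assuming the `jsonlines` package is installed).
with jsonlines.open("dump/train/norm/metadata.jsonl") as reader:
    first = next(iter(reader))

# Expect keys such as utt_id, phones, text_lengths, feats, feats_lengths, ...
print(sorted(first.keys()))
```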
### Model Training
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
`./local/train.sh` calls `${BIN_DIR}/train.py`.
Here's the complete help message.
```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
[--ngpu NGPU] [--phones-dict PHONES_DICT]
Train a VITS model.
optional arguments:
-h, --help show this help message and exit
--config CONFIG config file to overwrite default config.
--train-metadata TRAIN_METADATA
training data.
--dev-metadata DEV_METADATA
dev data.
--output-dir OUTPUT_DIR
output dir.
--ngpu NGPU if ngpu == 0, use cpu.
--phones-dict PHONES_DICT
phone vocabulary file.
```
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml` (see the loading sketch after this list).
2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
5. `--phones-dict` is the path of the phone vocabulary file.
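Under the hood, the entry points load this yaml into a `yacs` `CfgNode`; a minimal sketch of that pattern (not the exact training code):
```python
import yaml
from yacs.config import CfgNode

# Load the experiment config the way the training entry points typically do
# (a sketch; point this at the file you actually pass via --config).
with open("conf/default.yaml") as f:
    config = CfgNode(yaml.safe_load(f))

print(sorted(config.keys()))
```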
### Synthesizing
`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h] [--config CONFIG] [--ckpt CKPT]
[--phones_dict PHONES_DICT] [--ngpu NGPU]
[--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
Synthesize with VITS
optional arguments:
-h, --help show this help message and exit
--config CONFIG Config of VITS.
--ckpt CKPT Checkpoint file of VITS.
--phones_dict PHONES_DICT
phone vocabulary file.
--ngpu NGPU if ngpu == 0, use cpu.
--test_metadata TEST_METADATA
test metadata.
--output_dir OUTPUT_DIR
output dir.
```
`./local/synthesize_e2e.sh` calls `${BIN_DIR}/synthesize_e2e.py`, which can synthesize waveform from a text file.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize_e2e.py [-h] [--config CONFIG] [--ckpt CKPT]
[--phones_dict PHONES_DICT] [--lang LANG]
[--inference_dir INFERENCE_DIR] [--ngpu NGPU]
[--text TEXT] [--output_dir OUTPUT_DIR]
Synthesize with VITS
optional arguments:
-h, --help show this help message and exit
--config CONFIG Config of VITS.
--ckpt CKPT Checkpoint file of VITS.
--phones_dict PHONES_DICT
phone vocabulary file.
--lang LANG Choose model language. zh or en
--inference_dir INFERENCE_DIR
dir to save inference models
--ngpu NGPU if ngpu == 0, use cpu.
--text TEXT text to synthesize, a 'utt_id sentence' pair per line.
--output_dir OUTPUT_DIR
output dir.
```
1. `--config`, `--ckpt`, and `--phones_dict` are arguments for acoustic model, which correspond to the 3 files in the VITS pretrained model.
2. `--lang` is the model language, which can be `zh` or `en`.
3. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
4. `--text` is the text file, which contains sentences to synthesize (a sample format is shown after this list).
5. `--output_dir` is the directory to save synthesized audio files.
6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
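For reference, the `--text` input is plain text with one `utt_id sentence` pair per line; a hypothetical Chinese example (not shipped with the repo):
```text
001 欢迎使用飞桨语音合成系统。
002 今天天气真不错。
```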
## Pretrained Model
...@@ -65,7 +65,7 @@ Train a ParallelWaveGAN model. ...@@ -65,7 +65,7 @@ Train a ParallelWaveGAN model.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG ParallelWaveGAN config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
......
...@@ -63,7 +63,7 @@ Train a Multi-Band MelGAN model. ...@@ -63,7 +63,7 @@ Train a Multi-Band MelGAN model.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG Multi-Band MelGAN config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
......
...@@ -63,7 +63,7 @@ Train a Style MelGAN model. ...@@ -63,7 +63,7 @@ Train a Style MelGAN model.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG Style MelGAN config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
......
...@@ -63,7 +63,7 @@ Train a HiFiGAN model. ...@@ -63,7 +63,7 @@ Train a HiFiGAN model.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG HiFiGAN config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
......
...@@ -63,7 +63,7 @@ Train a WaveRNN model. ...@@ -63,7 +63,7 @@ Train a WaveRNN model.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG WaveRNN config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
......
...@@ -103,12 +103,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ...@@ -103,12 +103,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p
``` ```
```text ```text
usage: synthesize.py [-h] usage: synthesize.py [-h]
[--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}] [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT]
[--voice-cloning VOICE_CLONING] [--voice-cloning VOICE_CLONING]
[--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--ngpu NGPU] [--voc_stat VOC_STAT] [--ngpu NGPU]
[--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
...@@ -117,11 +117,10 @@ Synthesize with acoustic model & vocoder ...@@ -117,11 +117,10 @@ Synthesize with acoustic model & vocoder
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc} --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}
Choose acoustic model type of tts task. Choose acoustic model type of tts task.
--am_config AM_CONFIG --am_config AM_CONFIG
Config of acoustic model. Use default config when it is Config of acoustic model.
None.
--am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize --am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model. spectrogram when training acoustic model.
...@@ -133,10 +132,10 @@ optional arguments: ...@@ -133,10 +132,10 @@ optional arguments:
speaker id map file. speaker id map file.
--voice-cloning VOICE_CLONING --voice-cloning VOICE_CLONING
whether training voice cloning model. whether training voice cloning model.
--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}
Choose vocoder type of tts task. Choose vocoder type of tts task.
--voc_config VOC_CONFIG --voc_config VOC_CONFIG
Config of voc. Use default config when it is None. Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize --voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc. spectrogram when training voc.
...@@ -152,12 +151,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ...@@ -152,12 +151,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp
``` ```
```text ```text
usage: synthesize_e2e.py [-h] usage: synthesize_e2e.py [-h]
[--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc}] [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--tones_dict TONES_DICT] [--tones_dict TONES_DICT]
[--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
[--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc}] [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--lang LANG] [--voc_stat VOC_STAT] [--lang LANG]
[--inference_dir INFERENCE_DIR] [--ngpu NGPU] [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
...@@ -167,11 +166,10 @@ Synthesize with acoustic model & vocoder ...@@ -167,11 +166,10 @@ Synthesize with acoustic model & vocoder
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc} --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}
Choose acoustic model type of tts task. Choose acoustic model type of tts task.
--am_config AM_CONFIG --am_config AM_CONFIG
Config of acoustic model. Use default config when it is Config of acoustic model.
None.
--am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize --am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model. spectrogram when training acoustic model.
...@@ -182,10 +180,10 @@ optional arguments: ...@@ -182,10 +180,10 @@ optional arguments:
--speaker_dict SPEAKER_DICT --speaker_dict SPEAKER_DICT
speaker id map file. speaker id map file.
--spk_id SPK_ID spk id for multi speaker acoustic model --spk_id SPK_ID spk id for multi speaker acoustic model
--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc} --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}
Choose vocoder type of tts task. Choose vocoder type of tts task.
--voc_config VOC_CONFIG --voc_config VOC_CONFIG
Config of voc. Use default config when it is None. Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize --voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc. spectrogram when training voc.
...@@ -198,9 +196,9 @@ optional arguments: ...@@ -198,9 +196,9 @@ optional arguments:
output dir. output dir.
``` ```
1. `--am` is acoustic model type with the format {model_name}_{dataset} 1. `--am` is acoustic model type with the format {model_name}_{dataset}
2. `--am_config`, `--am_checkpoint`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the Tacotron2 pretrained model. 2. `--am_config`, `--am_ckpt`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the Tacotron2 pretrained model.
3. `--voc` is vocoder type with the format {model_name}_{dataset} 3. `--voc` is vocoder type with the format {model_name}_{dataset}
4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
5. `--lang` is the model language, which can be `zh` or `en`. 5. `--lang` is the model language, which can be `zh` or `en`.
6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
7. `--text` is the text file, which contains sentences to synthesize. 7. `--text` is the text file, which contains sentences to synthesize.
......
...@@ -58,7 +58,7 @@ Train a TransformerTTS model with LJSpeech TTS dataset. ...@@ -58,7 +58,7 @@ Train a TransformerTTS model with LJSpeech TTS dataset.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG TransformerTTS config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
......
...@@ -107,14 +107,14 @@ pwg_ljspeech_ckpt_0.5 ...@@ -107,14 +107,14 @@ pwg_ljspeech_ckpt_0.5
```bash ```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
``` ```
``text ```text
usage: synthesize.py [-h] usage: synthesize.py [-h]
[--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT]
[--voice-cloning VOICE_CLONING] [--voice-cloning VOICE_CLONING]
[--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--ngpu NGPU] [--voc_stat VOC_STAT] [--ngpu NGPU]
[--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
...@@ -123,11 +123,10 @@ Synthesize with acoustic model & vocoder ...@@ -123,11 +123,10 @@ Synthesize with acoustic model & vocoder
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}
Choose acoustic model type of tts task. Choose acoustic model type of tts task.
--am_config AM_CONFIG --am_config AM_CONFIG
Config of acoustic model. Use default config when it is Config of acoustic model.
None.
--am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize --am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model. spectrogram when training acoustic model.
...@@ -139,10 +138,10 @@ optional arguments: ...@@ -139,10 +138,10 @@ optional arguments:
speaker id map file. speaker id map file.
--voice-cloning VOICE_CLONING --voice-cloning VOICE_CLONING
whether training voice cloning model. whether training voice cloning model.
--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}
Choose vocoder type of tts task. Choose vocoder type of tts task.
--voc_config VOC_CONFIG --voc_config VOC_CONFIG
Config of voc. Use default config when it is None. Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize --voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc. spectrogram when training voc.
...@@ -158,12 +157,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ...@@ -158,12 +157,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp
``` ```
```text ```text
usage: synthesize_e2e.py [-h] usage: synthesize_e2e.py [-h]
[--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--tones_dict TONES_DICT] [--tones_dict TONES_DICT]
[--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
[--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--lang LANG] [--voc_stat VOC_STAT] [--lang LANG]
[--inference_dir INFERENCE_DIR] [--ngpu NGPU] [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
...@@ -173,11 +172,10 @@ Synthesize with acoustic model & vocoder ...@@ -173,11 +172,10 @@ Synthesize with acoustic model & vocoder
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}
Choose acoustic model type of tts task. Choose acoustic model type of tts task.
--am_config AM_CONFIG --am_config AM_CONFIG
Config of acoustic model. Use default config when it is Config of acoustic model.
None.
--am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize --am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model. spectrogram when training acoustic model.
...@@ -188,10 +186,10 @@ optional arguments: ...@@ -188,10 +186,10 @@ optional arguments:
--speaker_dict SPEAKER_DICT --speaker_dict SPEAKER_DICT
speaker id map file. speaker id map file.
--spk_id SPK_ID spk id for multi speaker acoustic model --spk_id SPK_ID spk id for multi speaker acoustic model
--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}
Choose vocoder type of tts task. Choose vocoder type of tts task.
--voc_config VOC_CONFIG --voc_config VOC_CONFIG
Config of voc. Use default config when it is None. Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize --voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc. spectrogram when training voc.
...@@ -204,9 +202,9 @@ optional arguments: ...@@ -204,9 +202,9 @@ optional arguments:
output dir. output dir.
``` ```
1. `--am` is acoustic model type with the format {model_name}_{dataset} 1. `--am` is acoustic model type with the format {model_name}_{dataset}
2. `--am_config`, `--am_checkpoint`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model. 2. `--am_config`, `--am_ckpt`, `--am_stat` and `--phones_dict` are arguments for acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model.
3. `--voc` is vocoder type with the format {model_name}_{dataset} 3. `--voc` is vocoder type with the format {model_name}_{dataset}
4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
5. `--lang` is the model language, which can be `zh` or `en`. 5. `--lang` is the model language, which can be `zh` or `en`.
6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
7. `--text` is the text file, which contains sentences to synthesize. 7. `--text` is the text file, which contains sentences to synthesize.
......
...@@ -65,7 +65,7 @@ Train a ParallelWaveGAN model. ...@@ -65,7 +65,7 @@ Train a ParallelWaveGAN model.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG ParallelWaveGAN config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
......
...@@ -57,15 +57,13 @@ Here's the complete help message. ...@@ -57,15 +57,13 @@ Here's the complete help message.
```text ```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
[--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER] [--ngpu NGPU]
[--run-benchmark RUN_BENCHMARK]
[--profiler_options PROFILER_OPTIONS]
Train a ParallelWaveGAN model. Train a HiFiGAN model.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG HiFiGAN config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
...@@ -73,19 +71,6 @@ optional arguments: ...@@ -73,19 +71,6 @@ optional arguments:
--output-dir OUTPUT_DIR --output-dir OUTPUT_DIR
output dir. output dir.
--ngpu NGPU if ngpu == 0, use cpu. --ngpu NGPU if ngpu == 0, use cpu.
benchmark:
arguments related to benchmark.
--batch-size BATCH_SIZE
batch size.
--max-iter MAX_ITER train max steps.
--run-benchmark RUN_BENCHMARK
runing benchmark or not, if True, use the --batch-size
and --max-iter.
--profiler_options PROFILER_OPTIONS
The option of profiler, which should be in format
"key1=value1;key2=value2;key3=value3".
``` ```
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. 1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
......
# 1xt2x
Convert DeepSpeech 1.8 released models to 2.x.
## Model source directory
* Deepspeech2x
## Experiment directory
* aishell
* librispeech
* baidu_en8k
## The released models
Acoustic Model | Training Data | Hours of Speech | Token-based | CER | WER
:-------------:| :------------:| :---------------: | :---------: | :---: | :----:
Ds2 Offline Aishell 1xt2x model| Aishell Dataset | 151 h | Char-based | 0.080447 |
Ds2 Offline Librispeech 1xt2x model | Librispeech Dataset | 960 h | Word-based | | 0.068548
Ds2 Offline Baidu en8k 1xt2x model | Baidu Internal English Dataset | 8628 h | Word-based | | 0.054112
# https://yaml.org/type/float.html
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
###########################################
# Dataloader #
###########################################
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
############################################
# Network Architecture #
############################################
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
use_gru: True
share_rnn_weights: False
blank_id: 4333
###########################################
# Training #
###########################################
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decode_batch_size: 32
error_rate_type: cer
decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 2.6
beta: 5.0
beam_size: 300
cutoff_prob: 0.99
cutoff_top_n: 40
num_proc_bsearch: 8
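# Note (general background on this decoder family, not part of the original
# config): in ctc_beam_search decoding, alpha weights the external language
# model and beta is a per-word insertion bonus; candidates are typically
# ranked by
#   Q(y) = log p_ctc(y|x) + alpha * log p_lm(y) + beta * word_count(y)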
#!/bin/bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
stage=-1
stop_stage=100
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
bash local/download_model.sh ${ckpt_dir}
if [ $? -ne 0 ]; then
exit 1
fi
cd ${ckpt_dir}
tar xzvf aishell_model_v1.8_to_v2.x.tar.gz
cd -
mv ${ckpt_dir}/mean_std.npz data/
mv ${ckpt_dir}/vocab.txt data/
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/aishell/aishell.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/aishell"
if [ $? -ne 0 ]; then
echo "Prepare Aishell failed. Terminated."
exit 1
fi
for dataset in train dev test; do
mv data/manifest.${dataset} data/manifest.${dataset}.raw
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--cmvn_path "data/mean_std.npz" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${dataset}.raw" \
--output_path="data/manifest.${dataset}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
} &
done
wait
fi
echo "Aishell data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm'
MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm
echo "Start downloading the language model. The language model is large, please wait for a moment ..."
download $URL $MD5 $TARGET > /dev/null 2>&1
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
else
echo "Download the language model sucessfully"
fi
exit 0
#! /usr/bin/env bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
. ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz'
MD5=87e7577d4bea737dbf3e8daab37aa808
TARGET=${ckpt_dir}/aishell_model_v1.8_to_v2.x.tar.gz
echo "Download Aishell model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download Aishell model!"
exit 1
fi
exit 0
#!/bin/bash
if [ $# != 4 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
model_type=$4
# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
model_type=offline
gpus=2
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
v18_ckpt=aishell_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
mkdir -p exp/${ckpt}/checkpoints
bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi
# https://yaml.org/type/float.html
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: .inf # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
###########################################
# Dataloader #
###########################################
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
############################################
# Network Architecture #
############################################
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
use_gru: True
share_rnn_weights: False
blank_id: 28
###########################################
# Training #
###########################################
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decode_batch_size: 32
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 1.4
beta: 0.35
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
stage=-1
stop_stage=100
unit_type=char
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
bash local/download_model.sh ${ckpt_dir}
if [ $? -ne 0 ]; then
exit 1
fi
cd ${ckpt_dir}
tar xzvf baidu_en8k_v1.8_to_v2.x.tar.gz
cd -
mv ${ckpt_dir}/mean_std.npz data/
mv ${ckpt_dir}/vocab.txt data/
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/librispeech" \
--full_download="True"
if [ $? -ne 0 ]; then
echo "Prepare LibriSpeech failed. Terminated."
exit 1
fi
for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mv data/manifest.${set} data/manifest.${set}.raw
done
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
for set in train-clean-100 train-clean-360 train-other-500; do
cat data/manifest.${set}.raw >> data/manifest.train.raw
done
for set in dev-clean dev-other; do
cat data/manifest.${set}.raw >> data/manifest.dev.raw
done
for set in test-clean test-other; do
cat data/manifest.${set}.raw >> data/manifest.test.raw
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--cmvn_path "data/mean_std.npz" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest.${set} failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "LibriSpeech Data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
echo "Start downloading the language model. The language model is large, please wait for a moment ..."
download $URL $MD5 $TARGET > /dev/null 2>&1
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
else
echo "Download the language model sucessfully"
fi
exit 0
#! /usr/bin/env bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
. ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz'
MD5=c1676be8505cee436e6f312823e9008c
TARGET=${ckpt_dir}/baidu_en8k_v1.8_to_v2.x.tar.gz
echo "Download BaiduEn8k model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download BaiduEn8k model!"
exit 1
fi
exit 0
#!/bin/bash
if [ $# != 4 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
model_type=$4
# download language model
bash local/download_lm_en.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
model_type=offline
gpus=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
v18_ckpt=baidu_en8k_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
mkdir -p exp/${ckpt}/checkpoints
bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi
# https://yaml.org/type/float.html
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: 1000.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
###########################################
# Dataloader #
###########################################
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
############################################
# Network Architecture #
############################################
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 2048
use_gru: False
share_rnn_weights: True
blank_id: 28
###########################################
# Training #
###########################################
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decode_batch_size: 32
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
#!/bin/bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
stage=-1
stop_stage=100
unit_type=char
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
bash local/download_model.sh ${ckpt_dir}
if [ $? -ne 0 ]; then
exit 1
fi
cd ${ckpt_dir}
tar xzvf librispeech_v1.8_to_v2.x.tar.gz
cd -
mv ${ckpt_dir}/mean_std.npz data/
mv ${ckpt_dir}/vocab.txt data/
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/librispeech" \
--full_download="True"
if [ $? -ne 0 ]; then
echo "Prepare LibriSpeech failed. Terminated."
exit 1
fi
for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mv data/manifest.${set} data/manifest.${set}.raw
done
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
for set in train-clean-100 train-clean-360 train-other-500; do
cat data/manifest.${set}.raw >> data/manifest.train.raw
done
for set in dev-clean dev-other; do
cat data/manifest.${set}.raw >> data/manifest.dev.raw
done
for set in test-clean test-other; do
cat data/manifest.${set}.raw >> data/manifest.test.raw
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--cmvn_path "data/mean_std.npz" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest.${set} failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "LibriSpeech Data preparation done."
exit 0
#!/bin/bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
echo "Start downloading the language model. The language model is large, please wait for a moment ..."
download $URL $MD5 $TARGET > /dev/null 2>&1
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
else
echo "Download the language model sucessfully"
fi
exit 0
#! /usr/bin/env bash
if [ $# != 1 ];then
echo "usage: ${0} ckpt_dir"
exit -1
fi
ckpt_dir=$1
. ${MAIN_ROOT}/utils/utility.sh
URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz'
MD5=a06d9aadb560ea113984dc98d67232c8
TARGET=${ckpt_dir}/librispeech_v1.8_to_v2.x.tar.gz
echo "Download LibriSpeech model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download LibriSpeech model!"
exit 1
fi
exit 0
#!/bin/bash
if [ $# != 4 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
model_type=$4
# download language model
bash local/download_lm_en.sh
if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0
export MAIN_ROOT=`realpath ${PWD}/../../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
#!/bin/bash
set -e
source path.sh
stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
model_type=offline
gpus=1
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
v18_ckpt=librispeech_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
mkdir -p exp/${ckpt}/checkpoints
bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1
fi
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import List
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
from paddle.fluid import core
from paddle.nn import functional as F
from paddlespeech.s2t.utils.log import Log
#TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog()
########### hack logging #############
logger.warn = logger.warning
########### hack paddle #############
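# Register torch-style dtype aliases on the paddle module (as dtype strings)
# so code ported from PyTorch can keep writing paddle.half, paddle.long, etc.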
paddle.half = 'float16'
paddle.float = 'float32'
paddle.double = 'float64'
paddle.short = 'int16'
paddle.int = 'int32'
paddle.long = 'int64'
paddle.uint16 = 'uint16'
paddle.cdouble = 'complex128'
def convert_dtype_to_string(tensor_dtype):
"""
Convert the data type in numpy to the data type in Paddle
Args:
tensor_dtype(core.VarDesc.VarType): the data type in numpy.
Returns:
core.VarDesc.VarType: the data type in Paddle.
"""
dtype = tensor_dtype
if dtype == core.VarDesc.VarType.FP32:
return paddle.float32
elif dtype == core.VarDesc.VarType.FP64:
return paddle.float64
elif dtype == core.VarDesc.VarType.FP16:
return paddle.float16
elif dtype == core.VarDesc.VarType.INT32:
return paddle.int32
elif dtype == core.VarDesc.VarType.INT16:
return paddle.int16
elif dtype == core.VarDesc.VarType.INT64:
return paddle.int64
elif dtype == core.VarDesc.VarType.BOOL:
return paddle.bool
elif dtype == core.VarDesc.VarType.BF16:
# since there is still no support for bfloat16 in NumPy,
# uint16 is used for casting bfloat16
return paddle.uint16
elif dtype == core.VarDesc.VarType.UINT8:
return paddle.uint8
elif dtype == core.VarDesc.VarType.INT8:
return paddle.int8
elif dtype == core.VarDesc.VarType.COMPLEX64:
return paddle.complex64
elif dtype == core.VarDesc.VarType.COMPLEX128:
return paddle.complex128
else:
raise ValueError("Not supported tensor dtype %s" % dtype)
if not hasattr(paddle, 'softmax'):
logger.warn("register user softmax to paddle, remove this when fixed!")
setattr(paddle, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle, 'log_softmax'):
logger.warn("register user log_softmax to paddle, remove this when fixed!")
setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)
if not hasattr(paddle, 'sigmoid'):
logger.warn("register user sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle, 'log_sigmoid'):
logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)
if not hasattr(paddle, 'relu'):
logger.warn("register user relu to paddle, remove this when fixed!")
setattr(paddle, 'relu', paddle.nn.functional.relu)
def cat(xs, dim=0):
return paddle.concat(xs, axis=dim)
if not hasattr(paddle, 'cat'):
logger.warn(
"override cat of paddle if exists or register, remove this when fixed!")
paddle.cat = cat
########### hack paddle.Tensor #############
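# Most shims below register a torch-style method on paddle.Tensor only when
# paddle lacks it, logging a warning so each patch can be dropped once
# upstream paddle gains the method natively.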
def item(x: paddle.Tensor):
return x.numpy().item()
if not hasattr(paddle.Tensor, 'item'):
logger.warn(
"override item of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.item = item
def func_long(x: paddle.Tensor):
return paddle.cast(x, paddle.long)
if not hasattr(paddle.Tensor, 'long'):
logger.warn(
"override long of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.long = func_long
if not hasattr(paddle.Tensor, 'numel'):
logger.warn(
"override numel of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.numel = paddle.numel
def new_full(x: paddle.Tensor,
size: Union[List[int], Tuple[int], paddle.Tensor],
fill_value: Union[float, int, bool, paddle.Tensor],
dtype=None):
# honor an explicit dtype if given, otherwise inherit from x (torch new_full semantics)
return paddle.full(size, fill_value, dtype=dtype if dtype is not None else x.dtype)
if not hasattr(paddle.Tensor, 'new_full'):
logger.warn(
"override new_full of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.new_full = new_full
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
if convert_dtype_to_string(xs.dtype) == paddle.bool:
xs = xs.astype(paddle.int)
return xs.equal(
paddle.to_tensor(
ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))
if not hasattr(paddle.Tensor, 'eq'):
logger.warn(
"override eq of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.eq = eq
if not hasattr(paddle, 'eq'):
logger.warn(
"override eq of paddle if exists or register, remove this when fixed!")
paddle.eq = eq
def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
return xs
if not hasattr(paddle.Tensor, 'contiguous'):
logger.warn(
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.contiguous = contiguous
def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
nargs = len(args)
assert (nargs <= 1)
s = paddle.shape(xs)
if nargs == 1:
return s[args[0]]
else:
return s
#`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
logger.warn(
"override size of paddle.Tensor "
"(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
)
paddle.Tensor.size = size
def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return xs.reshape(args)
if not hasattr(paddle.Tensor, 'view'):
logger.warn("register user view to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view = view
def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
return xs.reshape(ys.size())
if not hasattr(paddle.Tensor, 'view_as'):
logger.warn(
"register user view_as to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view_as = view_as
def is_broadcastable(shp1, shp2):
for a, b in zip(shp1[::-1], shp2[::-1]):
if a == 1 or b == 1 or a == b:
pass
else:
return False
return True
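# The check follows the NumPy/paddle trailing-dimension broadcast rule, e.g.:
#   is_broadcastable((8, 1, 6), (7, 6))  -> True
#   is_broadcastable((3, 4), (5, 4))     -> False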
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs)
return xs
if not hasattr(paddle.Tensor, 'masked_fill'):
logger.warn(
"register user masked_fill to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill = masked_fill
def masked_fill_(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]) -> paddle.Tensor:
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
ret = paddle.where(mask, trues, xs)
paddle.assign(ret.detach(), output=xs)
return xs
if not hasattr(paddle.Tensor, 'masked_fill_'):
logger.warn(
"register user masked_fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill_ = masked_fill_
def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
val = paddle.full_like(xs, value)
paddle.assign(val.detach(), output=xs)
return xs
if not hasattr(paddle.Tensor, 'fill_'):
logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.fill_ = fill_
def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
return paddle.tile(xs, size)
if not hasattr(paddle.Tensor, 'repeat'):
logger.warn(
"register user repeat to paddle.Tensor, remove this when fixed!")
paddle.Tensor.repeat = repeat
if not hasattr(paddle.Tensor, 'softmax'):
logger.warn(
"register user softmax to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle.Tensor, 'sigmoid'):
logger.warn(
"register user sigmoid to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle.Tensor, 'relu'):
logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
return x.astype(other.dtype)
if not hasattr(paddle.Tensor, 'type_as'):
logger.warn(
"register user type_as to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'type_as', type_as)
def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
assert len(args) == 1
if isinstance(args[0], str): # dtype
return x.astype(args[0])
elif isinstance(args[0], paddle.Tensor): #Tensor
return x.astype(args[0].dtype)
else: # Device
return x
if not hasattr(paddle.Tensor, 'to'):
logger.warn("register user to to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'to', to)
def func_float(x: paddle.Tensor) -> paddle.Tensor:
return x.astype(paddle.float)
if not hasattr(paddle.Tensor, 'float'):
logger.warn("register user float to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'float', func_float)
def func_int(x: paddle.Tensor) -> paddle.Tensor:
return x.astype(paddle.int)
if not hasattr(paddle.Tensor, 'int'):
logger.warn("register user int to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'int', func_int)
def tolist(x: paddle.Tensor) -> List[Any]:
return x.numpy().tolist()
if not hasattr(paddle.Tensor, 'tolist'):
logger.warn(
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
########### hack paddle.nn #############
class GLU(nn.Layer):
"""Gated Linear Units (GLU) Layer"""
def __init__(self, dim: int=-1):
super().__init__()
self.dim = dim
def forward(self, xs):
return F.glu(xs, axis=self.dim)
if not hasattr(paddle.nn, 'GLU'):
logger.warn("register user GLU to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'GLU', GLU)
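# --- usage sketch -------------------------------------------------------
# A minimal, illustrative demo of the Tensor helpers patched above. It
# assumes the monkey patches in this module have been applied (on newer
# paddle versions some of these methods may already exist natively).
def _demo_patched_tensor_api():
    import paddle
    x = paddle.randn([2, 3])
    mask = x > 0  # bool mask with the same shape as x
    # masked_fill: positions where mask is True are replaced by the value.
    y = x.masked_fill(mask, 0.0)
    # view / view_as are reshape aliases registered above.
    z = x.view(3, 2)
    assert z.shape == [3, 2]
    # eq compares against a scalar cast to x's dtype and place.
    flags = x.eq(0.0)
    return y, z, flags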
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from src_deepspeech2x.test_model import DeepSpeech2Tester as Tester
from yacs.config import CfgNode
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
def main_sp(config, args):
exp = Tester(config, args)
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
# save asr result to this file
parser.add_argument(
"--result_file", type=str, help="path to save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
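# Example invocation (a sketch only; config paths and the result file are
# hypothetical and depend on the experiment layout):
#   python test.py --config conf/deepspeech2.yaml \
#       --decode_cfg conf/tuning/decode.yaml \
#       --model_type offline \
#       --result_file exp/deepspeech2/decode.rsl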
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModel
from .deepspeech2 import DeepSpeech2Model
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Model"""
import paddle
from paddle import nn
from src_deepspeech2x.models.ds2.rnn import RNNStack
from paddlespeech.s2t.models.ds2.conv import ConvStack
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.checkpoint import Checkpoint
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
class CRNNEncoder(nn.Layer):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True):
super().__init__()
self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size
self.conv = ConvStack(feat_size, num_conv_layers)
i_size = self.conv.output_height # H after conv stack
self.rnn = RNNStack(
i_size=i_size,
h_size=rnn_size,
num_stacks=num_rnn_layers,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
@property
def output_size(self):
return self.rnn_size * 2
def forward(self, audio, audio_len):
"""Compute Encoder outputs
Args:
audio (Tensor): [B, Tmax, D]
text (Tensor): [B, Umax]
audio_len (Tensor): [B]
text_len (Tensor): [B]
Returns:
x (Tensor): encoder outputs, [B, T, D]
x_lens (Tensor): encoder length, [B]
"""
# [B, T, D] -> [B, D, T]
audio = audio.transpose([0, 2, 1])
# [B, D, T] -> [B, C=1, D, T]
x = audio.unsqueeze(1)
x_lens = audio_len
# convolution group
x, x_lens = self.conv(x, x_lens)
x_val = x.numpy()
# convert data from convolution feature map to sequence of vectors
#B, C, D, T = paddle.shape(x)  # does not work under jit
x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]
#x = x.reshape([B, T, C * D]) #[B, T, C*D]  # does not work under jit
x = x.reshape([0, 0, -1]) #[B, T, C*D]
# remove padding part
x, x_lens = self.rnn(x, x_lens) #[B, T, D]
return x, x_lens
class DeepSpeech2Model(nn.Layer):
"""The DeepSpeech2 network structure.
:param audio_data: Audio spectrogram data layer.
:type audio_data: Variable
:param text_data: Transcription text data layer.
:type text_data: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param masks: Masks data layer to reset padding.
:type masks: Variable
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
:type num_conv_layers: int
:param num_rnn_layers: Number of stacking RNN layers.
:type num_rnn_layers: int
:param rnn_size: RNN layer size (dimension of RNN cells).
:type rnn_size: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward direction RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True,
blank_id=0):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
assert (self.encoder.output_size == rnn_size * 2)
self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size,
blank_id=blank_id, # first token is <blank>
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
Args:
audio (Tensor): [B, T, D]
audio_len (Tensor): [B]
text (Tensor): [B, U]
text_len (Tensor): [B]
Returns:
loss (Tensor): [1]
"""
eouts, eouts_len = self.encoder(audio, audio_len)
loss = self.decoder(eouts, eouts_len, text, text_len)
return loss
@paddle.no_grad()
def decode(self, audio, audio_len):
# decoders only accept string encoded in utf-8
# Make sure the decoder has been initialized
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
batch_size = probs.shape[0]
self.decoder.reset_decoder(batch_size=batch_size)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
return trans_best
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Parameters
----------
dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
DeepSpeech2Model
The model built from pretrained result.
"""
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=len(dataloader.collate_fn.vocab_list),
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
@classmethod
def from_config(cls, config):
"""Build a DeepSpeec2Model from config
Parameters
config: yacs.config.CfgNode
config
Returns
-------
DeepSpeech2Model
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights,
blank_id=config.blank_id)
return model
class DeepSpeech2InferModel(DeepSpeech2Model):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=True,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights,
blank_id=blank_id)
def forward(self, audio, audio_len):
"""export model function
Args:
audio (Tensor): [B, T, D]
audio_len (Tensor): [B]
Returns:
probs: probs after softmax
"""
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return probs, eouts_len
def export(self):
static_model = paddle.jit.to_static(
self,
input_spec=[
paddle.static.InputSpec(
shape=[None, None, self.encoder.feat_size],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
])
return static_model
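# --- construction sketch ------------------------------------------------
# A minimal example of building the model via from_config(). The
# hyperparameter values below are illustrative only, not a recommended
# configuration.
def _demo_build_model():
    from yacs.config import CfgNode
    cfg = CfgNode(new_allowed=True)
    cfg.feat_size = 161          # linear spectrogram bins (see CRNNEncoder)
    cfg.dict_size = 29
    cfg.num_conv_layers = 2
    cfg.num_rnn_layers = 3
    cfg.rnn_layer_size = 1024
    cfg.use_gru = False
    cfg.share_rnn_weights = True
    cfg.blank_id = 0
    model = DeepSpeech2Model.from_config(cfg)
    # the encoder is bidirectional, hence 2 * rnn_size output features
    assert model.encoder.output_size == 2 * cfg.rnn_layer_size
    return model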
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from paddlespeech.s2t.modules.activation import brelu
from paddlespeech.s2t.modules.mask import make_non_pad_mask
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
r"""
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
computes the outputs and updates states.
The formula used is as follows:
.. math::
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`act` is for :attr:`activation`.
"""
def __init__(self,
hidden_size: int,
activation="tanh",
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
if activation not in ["tanh", "relu", "brelu"]:
raise ValueError(
"activation for SimpleRNNCell should be tanh, relu or brelu, "
"but got {}".format(activation))
self.activation = activation
self._activation_fn = paddle.tanh \
if activation == "tanh" \
else F.relu
if activation == 'brelu':
self._activation_fn = brelu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_h = states
i2h = inputs
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self._activation_fn(i2h + h2h)
return h, h
@property
def state_shape(self):
return (self.hidden_size, )
class GRUCell(nn.RNNCellBase):
r"""
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
it computes the outputs and updates states.
The formula for GRU used is as follows:
.. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid function, and * is the elementwise
multiplication operator.
"""
def __init__(self,
input_size: int,
hidden_size: int,
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(3 * hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = None
self.bias_hh = self.create_parameter(
(3 * hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
self._activation = paddle.relu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_hidden = states # shape [batch_size, hidden_size]
x_gates = inputs
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
bias_u, bias_r, bias_c = paddle.split(
self.bias_hh, num_or_sections=3, axis=0)
weight_hh = paddle.transpose(
self.weight_hh,
perm=[1, 0]) #weight_hh:shape[hidden_size, 3 * hidden_size]
w_u_r_c = paddle.flatten(weight_hh)
size_u_r = self.hidden_size * 2 * self.hidden_size
w_u_r = paddle.reshape(w_u_r_c[:size_u_r],
(self.hidden_size, self.hidden_size * 2))
w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1)
w_c = paddle.reshape(w_u_r_c[size_u_r:],
(self.hidden_size, self.hidden_size))
h_u = paddle.matmul(
pre_hidden, w_u,
transpose_y=False) + bias_u #shape [batch_size, hidden_size]
h_r = paddle.matmul(
pre_hidden, w_r,
transpose_y=False) + bias_r #shape [batch_size, hidden_size]
x_u, x_r, x_c = paddle.split(
x_gates, num_or_sections=3, axis=1) #shape[batch_size, hidden_size]
u = self._gate_activation(x_u + h_u) #shape [batch_size, hidden_size]
r = self._gate_activation(x_r + h_r) #shape [batch_size, hidden_size]
c = self._activation(
x_c + paddle.matmul(r * pre_hidden, w_c, transpose_y=False) +
bias_c) # [batch_size, hidden_size]
h = (1 - u) * pre_hidden + u * c
# https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
return h, h
@property
def state_shape(self):
r"""
The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
size would be automatically inserted into shape). The shape corresponds
to the shape of :math:`h_{t-1}`.
"""
return (self.hidden_size, )
class BiRNNWithBN(nn.Layer):
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param size: Dimension of RNN cells.
:type size: int
:param share_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
:type share_weights: bool
:return: Bidirectional simple rnn layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int, share_weights: bool):
super().__init__()
self.share_weights = share_weights
if self.share_weights:
#input-hidden weights shared between bi-directional rnn.
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
# batch norm is only performed on input-state projection
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = self.fw_fc
self.bw_bn = self.fw_bn
else:
self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
h_size, bias_attr=None, data_format='NLC')
self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class BiGRUWithBN(nn.Layer):
"""Bidirectonal gru layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer.
:type name: string
:param input: Input layer.
:type input: Variable
:param size: Dimension of GRU cells.
:type size: int
:param act: Activation type.
:type act: string
:return: Bidirectional GRU layer.
:rtype: Variable
"""
def __init__(self, i_size: int, h_size: int):
super().__init__()
hidden_size = h_size * 3
self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.fw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
self.bw_bn = nn.BatchNorm1D(
hidden_size, bias_attr=None, data_format='NLC')
self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class RNNStack(nn.Layer):
"""RNN group with stacked bidirectional simple RNN or GRU layers.
:param input: Input layer.
:type input: Variable
:param size: Dimension of RNN cells in each layer.
:type size: int
:param num_stacks: Number of stacked rnn layers.
:type num_stacks: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: Output layer of the RNN group.
:rtype: Variable
"""
def __init__(self,
i_size: int,
h_size: int,
num_stacks: int,
use_gru: bool,
share_rnn_weights: bool):
super().__init__()
rnn_stacks = []
for i in range(num_stacks):
if use_gru:
#note: standard GRU uses tanh, but GRUCell above uses relu for the candidate state
rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
else:
rnn_stacks.append(
BiRNNWithBN(
i_size=i_size,
h_size=h_size,
share_weights=share_rnn_weights))
i_size = h_size * 2
self.rnn_stacks = nn.LayerList(rnn_stacks)
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
"""
x: shape [B, T, D]
x_len: shape [B]
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
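# --- usage sketch -------------------------------------------------------
# A small demo of RNNStack on random data; batch size, lengths and
# dimensions are illustrative only.
def _demo_rnn_stack():
    B, T, D, H = 4, 50, 32, 64
    stack = RNNStack(
        i_size=D,
        h_size=H,
        num_stacks=2,
        use_gru=False,
        share_rnn_weights=True)
    x = paddle.randn([B, T, D])
    x_len = paddle.to_tensor([50, 40, 30, 20], dtype='int64')
    y, y_len = stack(x, x_len)
    # every stacked layer is bidirectional, so features double to 2 * H
    assert y.shape == [B, T, 2 * H]
    return y, y_len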
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import time
from collections import defaultdict
from contextlib import nullcontext
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel
from src_deepspeech2x.models.ds2 import DeepSpeech2Model
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.io.sampler import SortagradBatchSampler
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog
from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils import error_rate
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils import mp_tools
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config
start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data
loss = self.model(audio, audio_len, text, text_len)
losses_np = {
'train_loss': float(loss),
}
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.iteration += 1
iteration_time = time.time() - start
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
logger.info(msg)
if dist.get_rank() == 0 and self.visualizer:
for k, v in losses_np.items():
# `step -1` since we update `step` after optimizer.step().
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration - 1)
@paddle.no_grad()
def valid(self):
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
self.model.eval()
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, audio, audio_len, text, text_len = batch
loss = self.model(audio, audio_len, text, text_len)
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss))
if (i + 1) % self.config.log_interval == 0:
valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
valid_dump['val_history_loss'] = total_loss / num_seen_utts
# logging
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
logger.info('Rank {} Val info val_loss {}'.format(
dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts
def setup_model(self):
config = self.config.clone()
config.defrost()
config.feat_size = self.train_loader.collate_fn.feature_size
#config.dict_size = self.train_loader.collate_fn.vocab_size
config.dict_size = len(self.train_loader.collate_fn.vocab_list)
config.freeze()
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config)
elif self.args.model_type == 'online':
model = DeepSpeech2ModelOnline.from_config(config)
else:
raise Exception("wrong model type")
if self.parallel:
model = paddle.DataParallel(model)
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
grad_clip = ClipGradByGlobalNormWithLog(config.global_grad_clip)
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
learning_rate=config.lr, gamma=config.lr_decay, verbose=True)
optimizer = paddle.optimizer.Adam(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=paddle.regularizer.L2Decay(config.weight_decay),
grad_clip=grad_clip)
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
logger.info("Setup model/optimizer/lr_scheduler!")
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.keep_transcription_text = False
config.manifest = config.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.manifest = config.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
config.manifest = config.test_manifest
test_dataset = ManifestDataset.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.sortagrad,
shuffle_method=config.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.batch_size,
drop_last=True,
sortagrad=config.sortagrad,
shuffle_method=config.shuffle_method)
collate_fn_train = SpeechCollator.from_config(config)
config.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
config.keep_transcription_text = True
config.augmentation_config = ""
collate_fn_test = SpeechCollator.from_config(config)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn_train,
num_workers=config.num_workers)
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decode.decode_batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_test)
if "<eos>" in self.test_loader.collate_fn.vocab_list:
self.test_loader.collate_fn.vocab_list.remove("<eos>")
if "<eos>" in self.valid_loader.collate_fn.vocab_list:
self.valid_loader.collate_fn.vocab_list.remove("<eos>")
if "<eos>" in self.train_loader.collate_fn.vocab_list:
self.train_loader.collate_fn.vocab_list.remove("<eos>")
logger.info("Setup train/valid/test Dataloader!")
class DeepSpeech2Tester(DeepSpeech2Trainer):
def __init__(self, config, args):
self._text_featurizer = TextFeaturizer(
unit_type=config.unit_type, vocab=None)
super().__init__(config, args)
def ordid2token(self, texts, texts_len):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decode
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.compute_result_transcripts(audio, audio_len)
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
logger.info("Current error rate [%s] = %f" %
(cfg.error_rate_type, error_rate_func(target, result)))
return dict(
errors_sum=errors_sum,
len_refs=len_refs,
num_ins=num_ins,
error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type)
def compute_result_transcripts(self, audio, audio_len):
result_transcripts = self.model.decode(audio, audio_len)
result_transcripts = [
self._text_featurizer.detokenize(item)
for item in result_transcripts
]
return result_transcripts
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
self.model.eval()
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
# Initialized the decoder in model
decode_cfg = self.config.decode
vocab_list = self.test_loader.collate_fn.vocab_list
decode_batch_size = self.test_loader.batch_size
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts,
texts_len, fout)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
# logging
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
self.model.decoder.del_decoder()
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def export(self):
if self.args.model_type == 'offline':
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
elif self.args.model_type == 'online':
infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
else:
raise Exception("wrong model type")
infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size
static_model = infer_model.export()
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
exit(-1)
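# --- error-rate sketch --------------------------------------------------
# compute_metrics() above accumulates edit counts and reference lengths
# across utterances and divides once at the end (a corpus-level rate, not
# a mean of per-utterance rates). A self-contained, pure-Python sketch of
# that CER bookkeeping, for illustration only:
def _char_errors(ref: str, hyp: str):
    """Character-level Levenshtein distance and reference length."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(
                min(prev[j] + 1,                # deletion
                    cur[j - 1] + 1,             # insertion
                    prev[j - 1] + (r != h)))    # substitution / match
        prev = cur
    return prev[-1], len(ref)


def _corpus_cer(pairs):
    errors_sum, len_refs = 0, 0
    for ref, hyp in pairs:
        errors, len_ref = _char_errors(ref, hyp)
        errors_sum += errors
        len_refs += len_ref
    return errors_sum / len_refs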
...@@ -42,9 +42,6 @@ def get_transcripts(path: Union[str, Path]): ...@@ -42,9 +42,6 @@ def get_transcripts(path: Union[str, Path]):
for i in range(0, len(lines), 2): for i in range(0, len(lines), 2):
sentence_id = lines[i].split()[0] sentence_id = lines[i].split()[0]
transcription = lines[i + 1].strip() transcription = lines[i + 1].strip()
# tones are dropped here
# since the lexicon does not consider tones, too
transcription = " ".join([item[:-1] for item in transcription.split()])
transcripts[sentence_id] = transcription transcripts[sentence_id] = transcription
return transcripts return transcripts
......
...@@ -4,7 +4,7 @@ mkdir -p $EXP_DIR ...@@ -4,7 +4,7 @@ mkdir -p $EXP_DIR
LEXICON_NAME='simple' LEXICON_NAME='simple'
if [ ! -f "$EXP_DIR/$LEXICON_NAME.lexicon" ]; then if [ ! -f "$EXP_DIR/$LEXICON_NAME.lexicon" ]; then
echo "generating lexicon..." echo "generating lexicon..."
python local/generate_lexicon.py "$EXP_DIR/$LEXICON_NAME" --with-r python local/generate_lexicon.py "$EXP_DIR/$LEXICON_NAME" --with-r --with-tone
echo "lexicon done" echo "lexicon done"
fi fi
...@@ -16,6 +16,7 @@ if [ ! -d $EXP_DIR/baker_corpus ]; then ...@@ -16,6 +16,7 @@ if [ ! -d $EXP_DIR/baker_corpus ]; then
echo "transcription for each audio file is saved with the same namd in $EXP_DIR/baker_corpus " echo "transcription for each audio file is saved with the same namd in $EXP_DIR/baker_corpus "
fi fi
echo "detecting oov..." echo "detecting oov..."
python local/detect_oov.py $EXP_DIR/baker_corpus $EXP_DIR/"$LEXICON_NAME.lexicon" python local/detect_oov.py $EXP_DIR/baker_corpus $EXP_DIR/"$LEXICON_NAME.lexicon"
echo "detecting oov done. you may consider regenerate lexicon if there is unexpected OOVs." echo "detecting oov done. you may consider regenerate lexicon if there is unexpected OOVs."
...@@ -44,6 +45,3 @@ if [ ! -d "$EXP_DIR/baker_alignment" ]; then ...@@ -44,6 +45,3 @@ if [ ! -d "$EXP_DIR/baker_alignment" ]; then
echo "model: $EXP_DIR/baker_model" echo "model: $EXP_DIR/baker_model"
fi fi
...@@ -112,12 +112,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p ...@@ -112,12 +112,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p
``` ```
```text ```text
usage: synthesize.py [-h] usage: synthesize.py [-h]
[--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT] [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT]
[--voice-cloning VOICE_CLONING] [--voice-cloning VOICE_CLONING]
[--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--ngpu NGPU] [--voc_stat VOC_STAT] [--ngpu NGPU]
[--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR] [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
...@@ -126,11 +126,10 @@ Synthesize with acoustic model & vocoder ...@@ -126,11 +126,10 @@ Synthesize with acoustic model & vocoder
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech,tacotron2_aishell3}
Choose acoustic model type of tts task. Choose acoustic model type of tts task.
--am_config AM_CONFIG --am_config AM_CONFIG
Config of acoustic model. Use deault config when it is Config of acoustic model.
None.
--am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize --am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model. spectrogram when training acoustic model.
...@@ -142,10 +141,10 @@ optional arguments: ...@@ -142,10 +141,10 @@ optional arguments:
speaker id map file. speaker id map file.
--voice-cloning VOICE_CLONING --voice-cloning VOICE_CLONING
whether training voice cloning model. whether training voice cloning model.
--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,wavernn_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,style_melgan_csmsc}
Choose vocoder type of tts task. Choose vocoder type of tts task.
--voc_config VOC_CONFIG --voc_config VOC_CONFIG
Config of voc. Use deault config when it is None. Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize --voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc. spectrogram when training voc.
...@@ -161,12 +160,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp ...@@ -161,12 +160,12 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_outp
``` ```
```text ```text
usage: synthesize_e2e.py [-h] usage: synthesize_e2e.py [-h]
[--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}] [--am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT] [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT] [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--tones_dict TONES_DICT] [--tones_dict TONES_DICT]
[--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID] [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
[--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}] [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT] [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--lang LANG] [--voc_stat VOC_STAT] [--lang LANG]
[--inference_dir INFERENCE_DIR] [--ngpu NGPU] [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
...@@ -176,11 +175,10 @@ Synthesize with acoustic model & vocoder ...@@ -176,11 +175,10 @@ Synthesize with acoustic model & vocoder
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk} --am {speedyspeech_csmsc,speedyspeech_aishell3,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk,tacotron2_csmsc,tacotron2_ljspeech}
Choose acoustic model type of tts task. Choose acoustic model type of tts task.
--am_config AM_CONFIG --am_config AM_CONFIG
Config of acoustic model. Use deault config when it is Config of acoustic model.
None.
--am_ckpt AM_CKPT Checkpoint file of acoustic model. --am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize --am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model. spectrogram when training acoustic model.
...@@ -191,10 +189,10 @@ optional arguments: ...@@ -191,10 +189,10 @@ optional arguments:
--speaker_dict SPEAKER_DICT --speaker_dict SPEAKER_DICT
speaker id map file. speaker id map file.
--spk_id SPK_ID spk id for multi speaker acoustic model --spk_id SPK_ID spk id for multi speaker acoustic model
--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc} --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc,style_melgan_csmsc,hifigan_csmsc,hifigan_ljspeech,hifigan_aishell3,hifigan_vctk,wavernn_csmsc}
Choose vocoder type of tts task. Choose vocoder type of tts task.
--voc_config VOC_CONFIG --voc_config VOC_CONFIG
Config of voc. Use deault config when it is None. Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc. --voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize --voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc. spectrogram when training voc.
...@@ -207,9 +205,9 @@ optional arguments: ...@@ -207,9 +205,9 @@ optional arguments:
output dir. output dir.
``` ```
1. `--am` is acoustic model type with the format {model_name}_{dataset} 1. `--am` is acoustic model type with the format {model_name}_{dataset}
2. `--am_config`, `--am_checkpoint`, `--am_stat`, `--phones_dict` `--speaker_dict` are arguments for acoustic model, which correspond to the 5 files in the fastspeech2 pretrained model. 2. `--am_config`, `--am_ckpt`, `--am_stat`, `--phones_dict` `--speaker_dict` are arguments for acoustic model, which correspond to the 5 files in the fastspeech2 pretrained model.
3. `--voc` is vocoder type with the format {model_name}_{dataset} 3. `--voc` is vocoder type with the format {model_name}_{dataset}
4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model. 4. `--voc_config`, `--voc_ckpt`, `--voc_stat` are arguments for vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
5. `--lang` is the model language, which can be `zh` or `en`. 5. `--lang` is the model language, which can be `zh` or `en`.
6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. 6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
7. `--text` is the text file, which contains sentences to synthesize. 7. `--text` is the text file, which contains sentences to synthesize.
......
...@@ -70,7 +70,7 @@ Train a ParallelWaveGAN model. ...@@ -70,7 +70,7 @@ Train a ParallelWaveGAN model.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG ParallelWaveGAN config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
......
...@@ -62,15 +62,13 @@ Here's the complete help message. ...@@ -62,15 +62,13 @@ Here's the complete help message.
```text ```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
[--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER] [--ngpu NGPU]
[--run-benchmark RUN_BENCHMARK]
[--profiler_options PROFILER_OPTIONS]
Train a ParallelWaveGAN model. Train a HiFiGAN model.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG config file to overwrite default config. --config CONFIG HiFiGAN config file.
--train-metadata TRAIN_METADATA --train-metadata TRAIN_METADATA
training data. training data.
--dev-metadata DEV_METADATA --dev-metadata DEV_METADATA
...@@ -78,19 +76,6 @@ optional arguments: ...@@ -78,19 +76,6 @@ optional arguments:
--output-dir OUTPUT_DIR --output-dir OUTPUT_DIR
output dir. output dir.
--ngpu NGPU if ngpu == 0, use cpu. --ngpu NGPU if ngpu == 0, use cpu.
benchmark:
arguments related to benchmark.
--batch-size BATCH_SIZE
batch size.
--max-iter MAX_ITER train max steps.
--run-benchmark RUN_BENCHMARK
runing benchmark or not, if True, use the --batch-size
and --max-iter.
--profiler_options PROFILER_OPTIONS
The option of profiler, which should be in format
"key1=value1;key2=value2;key3=value3".
``` ```
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. 1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 Xiaomi Corporation (Author: Yongqing Wang) # Copyright 2021 Xiaomi Corporation (Author: Yongqing Wang)
# Mobvoi Inc(Author: Di Wu, Binbin Zhang) # Mobvoi Inc(Author: Di Wu, Binbin Zhang)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -24,6 +13,7 @@ ...@@ -24,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import json import json
import os import os
......
...@@ -13,14 +13,7 @@ ...@@ -13,14 +13,7 @@
# limitations under the License. # limitations under the License.
import _locale import _locale
from .asr import ASRExecutor
from .base_commands import BaseCommand from .base_commands import BaseCommand
from .base_commands import HelpCommand from .base_commands import HelpCommand
from .cls import CLSExecutor
from .st import STExecutor
from .stats import StatsExecutor
from .text import TextExecutor
from .tts import TTSExecutor
from .vector import VectorExecutor
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
...@@ -29,30 +29,21 @@ from yacs.config import CfgNode ...@@ -29,30 +29,21 @@ from yacs.config import CfgNode
from ..download import get_path_from_url from ..download import get_path_from_url
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import CLI_TIMER from ..utils import CLI_TIMER
from ..utils import MODEL_HOME from ..utils import MODEL_HOME
from ..utils import stats_wrapper from ..utils import stats_wrapper
from ..utils import timer_register from ..utils import timer_register
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor'] __all__ = ['ASRExecutor']
@timer_register @timer_register
@cli_register(
name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor): class ASRExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__(task='asr', inference_type='offline')
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.asr', add_help=True) prog='paddlespeech.asr', add_help=True)
self.parser.add_argument( self.parser.add_argument(
...@@ -62,7 +53,8 @@ class ASRExecutor(BaseExecutor): ...@@ -62,7 +53,8 @@ class ASRExecutor(BaseExecutor):
type=str, type=str,
default='conformer_wenetspeech', default='conformer_wenetspeech',
choices=[ choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys() tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
], ],
help='Choose model type of asr task.') help='Choose model type of asr task.')
self.parser.add_argument( self.parser.add_argument(
...@@ -144,14 +136,14 @@ class ASRExecutor(BaseExecutor): ...@@ -144,14 +136,14 @@ class ASRExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.task_resource.set_task_model(tag, version=None)
self.res_path = res_path self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join( self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path']) self.res_path, self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join( self.ckpt_path = os.path.join(
res_path, self.res_path,
self.pretrained_models[tag]['ckpt_path'] + ".pdparams") self.task_resource.res_dict['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(self.res_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
...@@ -175,8 +167,8 @@ class ASRExecutor(BaseExecutor): ...@@ -175,8 +167,8 @@ class ASRExecutor(BaseExecutor):
self.collate_fn_test = SpeechCollator.from_config(self.config) self.collate_fn_test = SpeechCollator.from_config(self.config)
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab) unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = self.pretrained_models[tag]['lm_url'] lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.pretrained_models[tag]['lm_md5'] lm_md5 = self.task_resource.res_dict['lm_md5']
self.download_lm( self.download_lm(
lm_url, lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5) os.path.dirname(self.config.decode.lang_model_path), lm_md5)
...@@ -194,7 +186,7 @@ class ASRExecutor(BaseExecutor): ...@@ -194,7 +186,7 @@ class ASRExecutor(BaseExecutor):
raise Exception("wrong type") raise Exception("wrong type")
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset} '_')] # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, self.model_alias) model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config model_conf = self.config
model = model_class.from_config(model_conf) model = model_class.from_config(model_conf)
self.model = model self.model = model
...@@ -441,7 +433,7 @@ class ASRExecutor(BaseExecutor): ...@@ -441,7 +433,7 @@ class ASRExecutor(BaseExecutor):
if not parser_args.verbose: if not parser_args.verbose:
self.disable_task_loggers() self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_input_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
has_exceptions = False has_exceptions = False
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
'md5':
'76cb19ed857e6623856b7cd7ebbfeda4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
"conformer_online_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'b8c02632b04da34aca88459835be54a6',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_10',
},
"conformer_online_multicn-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz',
'md5':
'7989b3248c898070904cf042fd656003',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/multi_cn',
},
"conformer_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz',
'md5':
'3f073eccfa7bb14e0c6867d65fc0dc3a',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/avg_30',
},
"conformer_online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz',
'md5':
'b374cfb93537761270b6224fb0bfc26a',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30',
},
"transformer_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'2c667da24922aad391eacafe37bc1660',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/transformer/checkpoints/avg_10',
},
"deepspeech2online_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'e393d4d274af0f6967db24fc146e8074',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2offline_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz',
'md5':
'98b87b171b7240b7cae6e07d8d0bc9be',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2offline_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'f5666c81ad015c8de03aac2bc92e5762',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
'lm_md5':
'099a601759d467cd0a8523ff939819c5'
},
}
model_alias = {
"deepspeech2offline":
"paddlespeech.s2t.models.ds2:DeepSpeech2Model",
"deepspeech2online":
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer":
"paddlespeech.s2t.models.u2:U2Model",
"conformer_online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer":
"paddlespeech.s2t.models.u2:U2Model",
"wenetspeech":
"paddlespeech.s2t.models.u2:U2Model",
}
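The executors depend on this tag convention when they recover the bare model name via model_type[:model_type.rindex('_')]. A small self-contained sketch of that parsing, using a tag from the table above:

    # Tag layout: "{model_name}[_{dataset}][-{lang}][-...]"
    tag = 'conformer_wenetspeech-zh-16k'
    model_type, lang, sr = tag.split('-')
    model_name = model_type[:model_type.rindex('_')]  # strip the dataset suffix
    assert (model_name, lang, sr) == ('conformer', 'zh', '16k')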
...@@ -11,16 +11,18 @@ ...@@ -11,16 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
from typing import List from typing import List
from prettytable import PrettyTable
from ..resource import CommonTaskResource
from .entry import commands from .entry import commands
from .utils import cli_register from .utils import cli_register
from .utils import explicit_command_register
from .utils import get_command from .utils import get_command
__all__ = [ __all__ = ['BaseCommand', 'HelpCommand', 'StatsCommand']
'BaseCommand',
'HelpCommand',
]
@cli_register(name='paddlespeech') @cli_register(name='paddlespeech')
...@@ -73,3 +75,74 @@ class VersionCommand: ...@@ -73,3 +75,74 @@ class VersionCommand:
print(msg) print(msg)
return True return True
model_name_format = {
'asr': 'Model-Language-Sample Rate',
'cls': 'Model-Sample Rate',
'st': 'Model-Source language-Target language',
'text': 'Model-Task-Language',
'tts': 'Model-Language',
'vector': 'Model-Sample Rate'
}
@cli_register(
name='paddlespeech.stats',
description='Get speech tasks support models list.')
class StatsCommand:
def __init__(self):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.stats', add_help=True)
self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector']
self.parser.add_argument(
'--task',
type=str,
default='asr',
choices=self.task_choices,
help='Choose speech task.',
required=True)
def show_support_models(self, pretrained_models: dict):
fields = model_name_format[self.task].split("-")
table = PrettyTable(fields)
for key in pretrained_models:
table.add_row(key.split("-"))
print(table)
def execute(self, argv: List[str]) -> bool:
parser_args = self.parser.parse_args(argv)
self.task = parser_args.task
if self.task not in self.task_choices:
print("Please input correct speech task, choices = " + str(
self.task_choices))
return
pretrained_models = CommonTaskResource(task=self.task).pretrained_models
try:
print(
"Here is the list of {} pretrained models released by PaddleSpeech that can be used by command line and python API"
.format(self.task.upper()))
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of {} pretrained models.".format(
self.task.upper()))
# Dynamic import when running a specific command
_commands = {
'asr': ['Speech to text infer command.', 'ASRExecutor'],
'cls': ['Audio classification infer command.', 'CLSExecutor'],
'st': ['Speech translation infer command.', 'STExecutor'],
'text': ['Text command.', 'TextExecutor'],
'tts': ['Text to Speech infer command.', 'TTSExecutor'],
'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'],
}
for com, info in _commands.items():
explicit_command_register(
name='paddlespeech.{}'.format(com),
description=info[0],
cls='paddlespeech.cli.{}.{}'.format(com, info[1]))
\ No newline at end of file
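Registering subcommands by import-path string means `paddlespeech --help` no longer imports every executor (and its paddle dependencies) up front. A runnable sketch of the registry this loop builds, assuming a defaultdict tree like the commands object in entry.py:

    from collections import defaultdict

    def tree():
        return defaultdict(tree)

    commands = tree()

    def explicit_command_register(name, description='', cls=''):
        com = commands
        for item in name.split('.'):
            com = com[item]
        com['_entry'] = cls              # stored as a string, resolved lazily
        if description:
            com['_description'] = description

    explicit_command_register(
        name='paddlespeech.asr',
        description='Speech to text infer command.',
        cls='paddlespeech.cli.asr.ASRExecutor')
    assert commands['paddlespeech']['asr']['_entry'].endswith('ASRExecutor')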
...@@ -21,29 +21,19 @@ from typing import Union ...@@ -21,29 +21,19 @@ from typing import Union
import numpy as np import numpy as np
import paddle import paddle
import yaml import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.utils.dynamic_import import dynamic_import
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from paddleaudio import load
from .pretrained_models import pretrained_models from paddleaudio.features import LogMelSpectrogram
__all__ = ['CLSExecutor'] __all__ = ['CLSExecutor']
@cli_register(
name='paddlespeech.cls', description='Audio classification infer command.')
class CLSExecutor(BaseExecutor): class CLSExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__(task='cls')
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.cls', add_help=True) prog='paddlespeech.cls', add_help=True)
self.parser.add_argument( self.parser.add_argument(
...@@ -53,7 +43,8 @@ class CLSExecutor(BaseExecutor): ...@@ -53,7 +43,8 @@ class CLSExecutor(BaseExecutor):
type=str, type=str,
default='panns_cnn14', default='panns_cnn14',
choices=[ choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys() tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
], ],
help='Choose model type of cls task.') help='Choose model type of cls task.')
self.parser.add_argument( self.parser.add_argument(
...@@ -106,13 +97,16 @@ class CLSExecutor(BaseExecutor): ...@@ -106,13 +97,16 @@ class CLSExecutor(BaseExecutor):
if label_file is None or ckpt_path is None: if label_file is None or ckpt_path is None:
tag = model_type + '-' + '32k' # panns_cnn14-32k tag = model_type + '-' + '32k' # panns_cnn14-32k
self.res_path = self._get_pretrained_path(tag) self.task_resource.set_task_model(tag, version=None)
self.cfg_path = os.path.join( self.cfg_path = os.path.join(
self.res_path, self.pretrained_models[tag]['cfg_path']) self.task_resource.res_dir,
self.task_resource.res_dict['cfg_path'])
self.label_file = os.path.join( self.label_file = os.path.join(
self.res_path, self.pretrained_models[tag]['label_file']) self.task_resource.res_dir,
self.task_resource.res_dict['label_file'])
self.ckpt_path = os.path.join( self.ckpt_path = os.path.join(
self.res_path, self.pretrained_models[tag]['ckpt_path']) self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'])
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.label_file = os.path.abspath(label_file) self.label_file = os.path.abspath(label_file)
...@@ -129,7 +123,7 @@ class CLSExecutor(BaseExecutor): ...@@ -129,7 +123,7 @@ class CLSExecutor(BaseExecutor):
self._label_list.append(line.strip()) self._label_list.append(line.strip())
# model # model
model_class = dynamic_import(model_type, self.model_alias) model_class = self.task_resource.get_model_class(model_type)
model_dict = paddle.load(self.ckpt_path) model_dict = paddle.load(self.ckpt_path)
self.model = model_class(extract_embedding=False) self.model = model_class(extract_embedding=False)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
...@@ -206,7 +200,7 @@ class CLSExecutor(BaseExecutor): ...@@ -206,7 +200,7 @@ class CLSExecutor(BaseExecutor):
if not parser_args.verbose: if not parser_args.verbose:
self.disable_task_loggers() self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_input_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
has_exceptions = False has_exceptions = False
...@@ -246,4 +240,4 @@ class CLSExecutor(BaseExecutor): ...@@ -246,4 +240,4 @@ class CLSExecutor(BaseExecutor):
self.infer() self.infer()
res = self.postprocess(topk) # Retrieve result of cls. res = self.postprocess(topk) # Retrieve result of cls.
return res return res
\ No newline at end of file
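A brief usage sketch of the refactored CLS executor; topk matches the postprocess call above, while audio_file is an assumed parameter name for the Python API:

    # Hedged sketch; assumes paddlespeech is installed and the
    # panns_cnn14 checkpoint can be downloaded on first use.
    from paddlespeech.cli.cls import CLSExecutor

    cls_executor = CLSExecutor()
    result = cls_executor(
        audio_file='./cat.wav',   # assumed parameter name
        model='panns_cnn14',
        topk=5)
    print(result)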
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"panns_cnn6-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
'md5': '4cf09194a95df024fd12f84712cf0f9c',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn6.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn10-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
'md5': 'cb8427b22176cc2116367d14847f5413',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn10.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn14-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
'md5': 'e3b9b5614a1595001161d0ab95edee97',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn14.pdparams',
'label_file': 'audioset_labels.txt',
},
}
model_alias = {
"panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
"panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
"panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
}
...@@ -34,6 +34,11 @@ def _execute(): ...@@ -34,6 +34,11 @@ def _execute():
# The method 'execute' of a command instance returns 'True' for a success # The method 'execute' of a command instance returns 'True' for a success
# while 'False' for a failure. This converts the result into an exit status # while 'False' for a failure. This converts the result into an exit status
# in bash: 0 for a success and 1 for a failure. # in bash: 0 for a success and 1 for a failure.
if not callable(com['_entry']):
i = com['_entry'].rindex('.')
module, cls = com['_entry'][:i], com['_entry'][i + 1:]
exec("from {} import {}".format(module, cls))
com['_entry'] = locals()[cls]
status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1 status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1
return status return status
......
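The exec/locals() step above materialises the stored 'package.module.Class' string into a class the first time its command runs. An equivalent importlib-based sketch of the same resolution (a behavioural illustration, not the shipped code):

    import importlib

    def load_entry(entry):
        """Pass callables through; resolve 'package.module.Class' strings lazily."""
        if callable(entry):
            return entry
        module_name, _, cls_name = entry.rpartition('.')
        return getattr(importlib.import_module(module_name), cls_name)

    # load_entry('collections.OrderedDict') -> <class 'collections.OrderedDict'>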
...@@ -24,9 +24,8 @@ from typing import Union ...@@ -24,9 +24,8 @@ from typing import Union
import paddle import paddle
from ..resource import CommonTaskResource
from .log import logger from .log import logger
from .utils import download_and_decompress
from .utils import MODEL_HOME
class BaseExecutor(ABC): class BaseExecutor(ABC):
...@@ -34,11 +33,10 @@ class BaseExecutor(ABC): ...@@ -34,11 +33,10 @@ class BaseExecutor(ABC):
An abstract executor of paddlespeech tasks. An abstract executor of paddlespeech tasks.
""" """
def __init__(self): def __init__(self, task: str, **kwargs):
self._inputs = OrderedDict() self._inputs = OrderedDict()
self._outputs = OrderedDict() self._outputs = OrderedDict()
self.pretrained_models = OrderedDict() self.task_resource = CommonTaskResource(task=task, **kwargs)
self.model_alias = OrderedDict()
@abstractmethod @abstractmethod
def _init_from_path(self, *args, **kwargs): def _init_from_path(self, *args, **kwargs):
...@@ -98,8 +96,8 @@ class BaseExecutor(ABC): ...@@ -98,8 +96,8 @@ class BaseExecutor(ABC):
""" """
pass pass
def get_task_source(self, input_: Union[str, os.PathLike, None] def get_input_source(self, input_: Union[str, os.PathLike, None]
) -> Dict[str, Union[str, os.PathLike]]: ) -> Dict[str, Union[str, os.PathLike]]:
""" """
Get task input source from command line input. Get task input source from command line input.
...@@ -115,15 +113,17 @@ class BaseExecutor(ABC): ...@@ -115,15 +113,17 @@ class BaseExecutor(ABC):
ret = OrderedDict() ret = OrderedDict()
if input_ is None: # Take input from stdin if input_ is None: # Take input from stdin
for i, line in enumerate(sys.stdin): if not sys.stdin.isatty(
line = line.strip() ): # Avoid getting stuck when stdin is empty.
if len(line.split(' ')) == 1: for i, line in enumerate(sys.stdin):
ret[str(i + 1)] = line line = line.strip()
elif len(line.split(' ')) == 2: if len(line.split(' ')) == 1:
id_, info = line.split(' ') ret[str(i + 1)] = line
ret[id_] = info elif len(line.split(' ')) == 2:
else: # No valid input info from one line. id_, info = line.split(' ')
continue ret[id_] = info
else: # No valid input info from one line.
continue
else: else:
ret[1] = input_ ret[1] = input_
return ret return ret
...@@ -219,23 +219,6 @@ class BaseExecutor(ABC): ...@@ -219,23 +219,6 @@ class BaseExecutor(ABC):
for l in loggers: for l in loggers:
l.disabled = True l.disabled = True
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(self.pretrained_models.keys())
assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(self.pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def show_rtf(self, info: Dict[str, List[float]]): def show_rtf(self, info: Dict[str, List[float]]):
""" """
Calculate the RTF (real time factor) of the current task and show results. Calculate the RTF (real time factor) of the current task and show results.
......
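get_input_source accepts either a single path argument or 'id path' / bare-path lines on stdin, and the new isatty() guard keeps an interactive shell from hanging when nothing is piped in. A self-contained sketch of the line parsing, with illustrative inputs:

    from collections import OrderedDict

    def parse_job_lines(lines):
        """Mirror of the stdin parsing: 'utt1 a.wav' keys by id, 'b.wav' by line number."""
        ret = OrderedDict()
        for i, line in enumerate(lines):
            parts = line.strip().split(' ')
            if len(parts) == 1:
                ret[str(i + 1)] = parts[0]
            elif len(parts) == 2:
                ret[parts[0]] = parts[1]
            # anything else: no valid input info on this line
        return ret

    print(parse_job_lines(['utt1 ./zh.wav', './en.wav']))
    # OrderedDict([('utt1', './zh.wav'), ('2', './en.wav')])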
...@@ -28,27 +28,25 @@ from yacs.config import CfgNode ...@@ -28,27 +28,25 @@ from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress from ..utils import download_and_decompress
from ..utils import MODEL_HOME from ..utils import MODEL_HOME
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import kaldi_bins
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ["STExecutor"] __all__ = ["STExecutor"]
kaldi_bins = {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
"md5":
"c0682303b3f3393dbf6ed4c4e35a53eb",
}
@cli_register(
name="paddlespeech.st", description="Speech translation infer command.")
class STExecutor(BaseExecutor): class STExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__(task='st')
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.kaldi_bins = kaldi_bins self.kaldi_bins = kaldi_bins
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
...@@ -60,7 +58,8 @@ class STExecutor(BaseExecutor): ...@@ -60,7 +58,8 @@ class STExecutor(BaseExecutor):
type=str, type=str,
default="fat_st_ted", default="fat_st_ted",
choices=[ choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys() tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
], ],
help="Choose model type of st task.") help="Choose model type of st task.")
self.parser.add_argument( self.parser.add_argument(
...@@ -134,14 +133,16 @@ class STExecutor(BaseExecutor): ...@@ -134,14 +133,16 @@ class STExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
tag = model_type + "-" + src_lang + "-" + tgt_lang tag = model_type + "-" + src_lang + "-" + tgt_lang
res_path = self._get_pretrained_path(tag) self.task_resource.set_task_model(tag, version=None)
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(
pretrained_models[tag]["cfg_path"]) self.task_resource.res_dir,
self.ckpt_path = os.path.join(res_path, self.task_resource.res_dict['cfg_path'])
pretrained_models[tag]["ckpt_path"]) self.ckpt_path = os.path.join(
logger.info(res_path) self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'])
logger.info(self.cfg_path) logger.info(self.cfg_path)
logger.info(self.ckpt_path) logger.info(self.ckpt_path)
res_path = self.task_resource.res_dir
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path) self.ckpt_path = os.path.abspath(ckpt_path)
...@@ -166,7 +167,7 @@ class STExecutor(BaseExecutor): ...@@ -166,7 +167,7 @@ class STExecutor(BaseExecutor):
model_conf = self.config model_conf = self.config
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset} '_')] # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, self.model_alias) model_class = self.task_resource.get_model_class(model_name)
self.model = model_class.from_config(model_conf) self.model = model_class.from_config(model_conf)
self.model.eval() self.model.eval()
...@@ -304,7 +305,7 @@ class STExecutor(BaseExecutor): ...@@ -304,7 +305,7 @@ class STExecutor(BaseExecutor):
if not parser_args.verbose: if not parser_args.verbose:
self.disable_task_loggers() self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_input_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
has_exceptions = False has_exceptions = False
......
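Every executor in this diff now follows the same wiring: build a tag, call set_task_model(tag, version=None), then join res_dir with entries from res_dict. A hedged sketch of that shared pattern; wire_paths is an illustrative helper, not part of the codebase:

    import os

    def wire_paths(task_resource, tag, keys):
        """Select (and download if needed) the tagged model, then return resource paths."""
        task_resource.set_task_model(tag, version=None)
        return {key: os.path.join(task_resource.res_dir,
                                  task_resource.res_dict[key])
                for key in keys}

    # e.g. wire_paths(self.task_resource, 'fat_st_ted-en-zh',
    #                 ['cfg_path', 'ckpt_path'])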
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
"fat_st_ted-en-zh": {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
"md5":
"d62063f35a16d91210a71081bd2dd557",
"cfg_path":
"model.yaml",
"ckpt_path":
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
}
}
model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"}
kaldi_bins = {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
"md5":
"c0682303b3f3393dbf6ed4c4e35a53eb",
}
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List
from prettytable import PrettyTable
from ..utils import cli_register
from ..utils import stats_wrapper
__all__ = ['StatsExecutor']
model_name_format = {
'asr': 'Model-Language-Sample Rate',
'cls': 'Model-Sample Rate',
'st': 'Model-Source language-Target language',
'text': 'Model-Task-Language',
'tts': 'Model-Language',
'vector': 'Model-Sample Rate'
}
@cli_register(
name='paddlespeech.stats',
description='Get speech tasks support models list.')
class StatsExecutor():
def __init__(self):
super().__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech.stats', add_help=True)
self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector']
self.parser.add_argument(
'--task',
type=str,
default='asr',
choices=self.task_choices,
help='Choose speech task.',
required=True)
def show_support_models(self, pretrained_models: dict):
fields = model_name_format[self.task].split("-")
table = PrettyTable(fields)
for key in pretrained_models:
table.add_row(key.split("-"))
print(table)
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
parser_args = self.parser.parse_args(argv)
has_exceptions = False
try:
self(parser_args.task)
except Exception as e:
has_exceptions = True
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(
self,
task: str=None, ):
"""
Python API to call an executor.
"""
self.task = task
if self.task not in self.task_choices:
print("Please input correct speech task, choices = " + str(
self.task_choices))
elif self.task == 'asr':
try:
from ..asr.pretrained_models import pretrained_models
print(
"Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of ASR pretrained models.")
elif self.task == 'cls':
try:
from ..cls.pretrained_models import pretrained_models
print(
"Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of CLS pretrained models.")
elif self.task == 'st':
try:
from ..st.pretrained_models import pretrained_models
print(
"Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of ST pretrained models.")
elif self.task == 'text':
try:
from ..text.pretrained_models import pretrained_models
print(
"Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of TEXT pretrained models.")
elif self.task == 'tts':
try:
from ..tts.pretrained_models import pretrained_models
print(
"Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of TTS pretrained models.")
elif self.task == 'vector':
try:
from ..vector.pretrained_models import pretrained_models
print(
"Here is the list of Speaker Recognition pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print(
"Failed to get the list of Speaker Recognition pretrained models."
)
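Both the removed StatsExecutor and its replacement StatsCommand in cli/base.py render the model list by splitting each tag on '-' into the per-task field names. A minimal runnable sketch of that table (requires the prettytable package; tags are examples from this diff):

    from prettytable import PrettyTable

    fields = 'Model-Language-Sample Rate'.split('-')  # the 'asr' format
    table = PrettyTable(fields)
    for tag in ('conformer_wenetspeech-zh-16k',
                'transformer_librispeech-en-16k'):
        table.add_row(tag.split('-'))
    print(table)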
...@@ -23,24 +23,14 @@ import paddle ...@@ -23,24 +23,14 @@ import paddle
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from .pretrained_models import tokenizer_alias
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TextExecutor'] __all__ = ['TextExecutor']
@cli_register(name='paddlespeech.text', description='Text infer command.')
class TextExecutor(BaseExecutor): class TextExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__(task='text')
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.tokenizer_alias = tokenizer_alias
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.text', add_help=True) prog='paddlespeech.text', add_help=True)
self.parser.add_argument( self.parser.add_argument(
...@@ -56,7 +46,8 @@ class TextExecutor(BaseExecutor): ...@@ -56,7 +46,8 @@ class TextExecutor(BaseExecutor):
type=str, type=str,
default='ernie_linear_p7_wudao', default='ernie_linear_p7_wudao',
choices=[ choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys() tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
], ],
help='Choose model type of text task.') help='Choose model type of text task.')
self.parser.add_argument( self.parser.add_argument(
...@@ -114,13 +105,16 @@ class TextExecutor(BaseExecutor): ...@@ -114,13 +105,16 @@ class TextExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None or vocab_file is None: if cfg_path is None or ckpt_path is None or vocab_file is None:
tag = '-'.join([model_type, task, lang]) tag = '-'.join([model_type, task, lang])
self.res_path = self._get_pretrained_path(tag) self.task_resource.set_task_model(tag, version=None)
self.cfg_path = os.path.join( self.cfg_path = os.path.join(
self.res_path, self.pretrained_models[tag]['cfg_path']) self.task_resource.res_dir,
self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join( self.ckpt_path = os.path.join(
self.res_path, self.pretrained_models[tag]['ckpt_path']) self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'])
self.vocab_file = os.path.join( self.vocab_file = os.path.join(
self.res_path, self.pretrained_models[tag]['vocab_file']) self.task_resource.res_dir,
self.task_resource.res_dict['vocab_file'])
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path) self.ckpt_path = os.path.abspath(ckpt_path)
...@@ -135,8 +129,8 @@ class TextExecutor(BaseExecutor): ...@@ -135,8 +129,8 @@ class TextExecutor(BaseExecutor):
self._punc_list.append(line.strip()) self._punc_list.append(line.strip())
# model # model
model_class = dynamic_import(model_name, self.model_alias) model_class, tokenizer_class = self.task_resource.get_model_class(
tokenizer_class = dynamic_import(model_name, self.tokenizer_alias) model_name)
self.model = model_class( self.model = model_class(
cfg_path=self.cfg_path, ckpt_path=self.ckpt_path) cfg_path=self.cfg_path, ckpt_path=self.ckpt_path)
self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0') self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
...@@ -226,7 +220,7 @@ class TextExecutor(BaseExecutor): ...@@ -226,7 +220,7 @@ class TextExecutor(BaseExecutor):
if not parser_args.verbose: if not parser_args.verbose:
self.disable_task_loggers() self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input) task_source = self.get_input_source(parser_args.input)
task_results = OrderedDict() task_results = OrderedDict()
has_exceptions = False has_exceptions = False
......
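For the text task, get_model_class hands back two classes at once (model and tokenizer), matching the two-element lists under 'ernie_linear_*' in resource/model_alias.py later in this diff. A hedged sketch of that double resolution, reusing the "module:Class" convention:

    import importlib

    def resolve(path):
        module_name, cls_name = path.split(':')
        return getattr(importlib.import_module(module_name), cls_name)

    alias = ["paddlespeech.text.models:ErnieLinear",
             "paddlenlp.transformers:ErnieTokenizer"]
    # model_class, tokenizer_class = [resolve(p) for p in alias]
    # (left commented out: requires paddlespeech/paddlenlp to be installed)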
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"ernie_linear_p7_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz',
'md5':
'12283e2ddde1797c5d1e57036b512746',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
"ernie_linear_p3_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz',
'md5':
'448eb2fdf85b6a997e7e652e80c51dd2',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
}
model_alias = {
"ernie_linear_p7": "paddlespeech.text.models:ErnieLinear",
"ernie_linear_p3": "paddlespeech.text.models:ErnieLinear",
}
tokenizer_alias = {
"ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer",
"ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer",
}
...@@ -28,26 +28,17 @@ from yacs.config import CfgNode ...@@ -28,26 +28,17 @@ from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TTSExecutor'] __all__ = ['TTSExecutor']
@cli_register(
name='paddlespeech.tts', description='Text to Speech infer command.')
class TTSExecutor(BaseExecutor): class TTSExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__('tts')
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True) prog='paddlespeech.tts', add_help=True)
self.parser.add_argument( self.parser.add_argument(
...@@ -186,19 +177,23 @@ class TTSExecutor(BaseExecutor): ...@@ -186,19 +177,23 @@ class TTSExecutor(BaseExecutor):
return return
# am # am
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag) self.am_res_path = self.task_resource.res_dir
self.am_res_path = am_res_path self.am_config = os.path.join(self.am_res_path,
self.am_config = os.path.join( self.task_resource.res_dict['config'])
am_res_path, self.pretrained_models[am_tag]['config']) self.am_ckpt = os.path.join(self.am_res_path,
self.am_ckpt = os.path.join(am_res_path, self.task_resource.res_dict['ckpt'])
self.pretrained_models[am_tag]['ckpt'])
self.am_stat = os.path.join( self.am_stat = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speech_stats']) self.am_res_path, self.task_resource.res_dict['speech_stats'])
# must have phones_dict in acoustic # must have phones_dict in acoustic
self.phones_dict = os.path.join( self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict']) self.am_res_path, self.task_resource.res_dict['phones_dict'])
logger.info(am_res_path) logger.info(self.am_res_path)
logger.info(self.am_config) logger.info(self.am_config)
logger.info(self.am_ckpt) logger.info(self.am_ckpt)
else: else:
...@@ -210,32 +205,37 @@ class TTSExecutor(BaseExecutor): ...@@ -210,32 +205,37 @@ class TTSExecutor(BaseExecutor):
# for speedyspeech # for speedyspeech
self.tones_dict = None self.tones_dict = None
if 'tones_dict' in self.pretrained_models[am_tag]: if 'tones_dict' in self.task_resource.res_dict:
self.tones_dict = os.path.join( self.tones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['tones_dict']) self.am_res_path, self.task_resource.res_dict['tones_dict'])
if tones_dict: if tones_dict:
self.tones_dict = tones_dict self.tones_dict = tones_dict
# for multi speaker fastspeech2 # for multi speaker fastspeech2
self.speaker_dict = None self.speaker_dict = None
if 'speaker_dict' in self.pretrained_models[am_tag]: if 'speaker_dict' in self.task_resource.res_dict:
self.speaker_dict = os.path.join( self.speaker_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speaker_dict']) self.am_res_path, self.task_resource.res_dict['speaker_dict'])
if speaker_dict: if speaker_dict:
self.speaker_dict = speaker_dict self.speaker_dict = speaker_dict
# voc # voc
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
if voc_ckpt is None or voc_config is None or voc_stat is None: if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag) self.voc_res_path = self.task_resource.voc_res_dir
self.voc_res_path = voc_res_path
self.voc_config = os.path.join( self.voc_config = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['config']) self.voc_res_path, self.task_resource.voc_res_dict['config'])
self.voc_ckpt = os.path.join( self.voc_ckpt = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['ckpt']) self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
self.voc_stat = os.path.join( self.voc_stat = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['speech_stats']) self.voc_res_path,
logger.info(voc_res_path) self.task_resource.voc_res_dict['speech_stats'])
logger.info(self.voc_res_path)
logger.info(self.voc_config) logger.info(self.voc_config)
logger.info(self.voc_ckpt) logger.info(self.voc_ckpt)
else: else:
...@@ -285,9 +285,9 @@ class TTSExecutor(BaseExecutor): ...@@ -285,9 +285,9 @@ class TTSExecutor(BaseExecutor):
# model: {model_name}_{dataset} # model: {model_name}_{dataset}
am_name = am[:am.rindex('_')] am_name = am[:am.rindex('_')]
am_class = dynamic_import(am_name, self.model_alias) am_class = self.task_resource.get_model_class(am_name)
am_inference_class = dynamic_import(am_name + '_inference', am_inference_class = self.task_resource.get_model_class(am_name +
self.model_alias) '_inference')
if am_name == 'fastspeech2': if am_name == 'fastspeech2':
am = am_class( am = am_class(
...@@ -316,9 +316,9 @@ class TTSExecutor(BaseExecutor): ...@@ -316,9 +316,9 @@ class TTSExecutor(BaseExecutor):
# vocoder # vocoder
# model: {model_name}_{dataset} # model: {model_name}_{dataset}
voc_name = voc[:voc.rindex('_')] voc_name = voc[:voc.rindex('_')]
voc_class = dynamic_import(voc_name, self.model_alias) voc_class = self.task_resource.get_model_class(voc_name)
voc_inference_class = dynamic_import(voc_name + '_inference', voc_inference_class = self.task_resource.get_model_class(voc_name +
self.model_alias) '_inference')
if voc_name != 'wavernn': if voc_name != 'wavernn':
voc = voc_class(**self.voc_config["generator_params"]) voc = voc_class(**self.voc_config["generator_params"])
voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"])
...@@ -446,7 +446,7 @@ class TTSExecutor(BaseExecutor): ...@@ -446,7 +446,7 @@ class TTSExecutor(BaseExecutor):
if not args.verbose: if not args.verbose:
self.disable_task_loggers() self.disable_task_loggers()
task_source = self.get_task_source(args.input) task_source = self.get_input_source(args.input)
task_results = OrderedDict() task_results = OrderedDict()
has_exceptions = False has_exceptions = False
......
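TTS is the one task with two resources per run, so set_task_model is called twice: model_type=0 selects the acoustic model (res_dir/res_dict) and model_type=1 the vocoder (voc_res_dir/voc_res_dict). A hedged end-to-end usage sketch of the refactored executor; the keyword names are assumptions mirroring the CLI flags, and downloading of the default checkpoints is assumed:

    from paddlespeech.cli.tts import TTSExecutor

    tts = TTSExecutor()
    wav = tts(
        text='今天天气十分不错。',
        am='fastspeech2_csmsc',   # acoustic model: tag prefix before '-zh'
        voc='pwgan_csmsc',        # vocoder: tag prefix before '-zh'
        lang='zh',
        output='output.wav')      # assumed parameter name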
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip',
'md5':
'6f6fa967b408454b6662c8c00c0027cb',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'feats_stats.npy',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip',
'md5':
'ffed800c93deaf16ca9b3af89bfcd747',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_100000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip',
'md5':
'f4dd4a5f49a4552b77981f544ab3392e',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_96400.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
"fastspeech2_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip',
'md5':
'743e5024ca1e17a88c5c271db9779ba4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_66200.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
# tacotron2
"tacotron2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip',
'md5':
'0df4b6f0bcbe0d73c5ed6df8867ab91a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"tacotron2_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip',
'md5':
'6a5eddd81ae0e81d16959b97481135f3',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_60300.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
# pwgan
"pwgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip',
'md5':
'2e481633325b5bdf0a3823c714d2c117',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip',
'md5':
'53610ba9708fd3008ccaf8e99dacbaf0',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip',
'md5':
'd7598fa41ad362d62f85ffc0f07e3d84',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
"pwgan_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip',
'md5':
'b3da1defcde3e578be71eb284cb89f2c',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
# style_melgan
"style_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'5de2d5348f396de0c966926b8c462755',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip',
'md5':
'70e9131695decbca06a65fe51ed38a72',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip',
'md5':
'3bb49bc75032ed12f79c00c8cc79a09a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip',
'md5':
'7da8f88359bca2457e705d924cf27bd4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# wavernn
"wavernn_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip',
'md5':
'ee37b752f09bcba8f2af3b777ca38e13',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_400000.pdz',
'speech_stats':
'feats_stats.npy',
}
}
model_alias = {
# acoustic model
"speedyspeech":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
"speedyspeech_inference":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
"fastspeech2":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
"fastspeech2_inference":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
"tacotron2":
"paddlespeech.t2s.models.tacotron2:Tacotron2",
"tacotron2_inference":
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
# voc
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
"pwgan_inference":
"paddlespeech.t2s.models.parallel_wavegan:PWGInference",
"mb_melgan":
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"style_melgan":
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
"style_melgan_inference":
"paddlespeech.t2s.models.melgan:StyleMelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
"wavernn":
"paddlespeech.t2s.models.wavernn:WaveRNN",
"wavernn_inference":
"paddlespeech.t2s.models.wavernn:WaveRNNInference",
}
...@@ -41,6 +41,7 @@ requests.adapters.DEFAULT_RETRIES = 3 ...@@ -41,6 +41,7 @@ requests.adapters.DEFAULT_RETRIES = 3
__all__ = [ __all__ = [
'timer_register', 'timer_register',
'cli_register', 'cli_register',
'explicit_command_register',
'get_command', 'get_command',
'download_and_decompress', 'download_and_decompress',
'load_state_dict_from_url', 'load_state_dict_from_url',
...@@ -70,6 +71,16 @@ def cli_register(name: str, description: str='') -> Any: ...@@ -70,6 +71,16 @@ def cli_register(name: str, description: str='') -> Any:
return _warpper return _warpper
def explicit_command_register(name: str, description: str='', cls: str=''):
items = name.split('.')
com = commands
for item in items:
com = com[item]
com['_entry'] = cls
if description:
com['_description'] = description
def get_command(name: str) -> Any: def get_command(name: str) -> Any:
items = name.split('.') items = name.split('.')
com = commands com = commands
......
...@@ -22,30 +22,20 @@ from typing import Union ...@@ -22,30 +22,20 @@ from typing import Union
import paddle import paddle
import soundfile import soundfile
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper from ..utils import stats_wrapper
from .pretrained_models import model_alias from paddleaudio.backends import load as load_audio
from .pretrained_models import pretrained_models from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.utils.dynamic_import import dynamic_import
from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.modules.sid_model import SpeakerIdetification
@cli_register(
name="paddlespeech.vector",
description="Speech to vector embedding infer command.")
class VectorExecutor(BaseExecutor): class VectorExecutor(BaseExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__('vector')
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog="paddlespeech.vector", add_help=True) prog="paddlespeech.vector", add_help=True)
...@@ -53,7 +43,10 @@ class VectorExecutor(BaseExecutor): ...@@ -53,7 +43,10 @@ class VectorExecutor(BaseExecutor):
"--model", "--model",
type=str, type=str,
default="ecapatdnn_voxceleb12", default="ecapatdnn_voxceleb12",
choices=["ecapatdnn_voxceleb12"], choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
help="Choose model type of vector task.") help="Choose model type of vector task.")
self.parser.add_argument( self.parser.add_argument(
"--task", "--task",
...@@ -123,7 +116,7 @@ class VectorExecutor(BaseExecutor): ...@@ -123,7 +116,7 @@ class VectorExecutor(BaseExecutor):
self.disable_task_loggers() self.disable_task_loggers()
# stage 2: read the input data and store them as a list # stage 2: read the input data and store them as a list
task_source = self.get_task_source(parser_args.input) task_source = self.get_input_source(parser_args.input)
logger.info(f"task source: {task_source}") logger.info(f"task source: {task_source}")
# stage 3: process the audio one by one # stage 3: process the audio one by one
...@@ -300,17 +293,18 @@ class VectorExecutor(BaseExecutor): ...@@ -300,17 +293,18 @@ class VectorExecutor(BaseExecutor):
# get the model tag from the pretrained list # get the model tag from the pretrained list
sample_rate_str = "16k" if sample_rate == 16000 else "8k" sample_rate_str = "16k" if sample_rate == 16000 else "8k"
tag = model_type + "-" + sample_rate_str tag = model_type + "-" + sample_rate_str
self.task_resource.set_task_model(tag, version=None)
logger.info(f"load the pretrained model: {tag}") logger.info(f"load the pretrained model: {tag}")
# get the model from the pretrained list # get the model from the pretrained list
# we download the pretrained model and store it in the res_path # we download the pretrained model and store it in the res_path
res_path = self._get_pretrained_path(tag) self.res_path = self.task_resource.res_dir
self.res_path = res_path
self.cfg_path = os.path.join( self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path']) self.task_resource.res_dir,
self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join( self.ckpt_path = os.path.join(
res_path, self.task_resource.res_dir,
self.pretrained_models[tag]['ckpt_path'] + '.pdparams') self.task_resource.res_dict['ckpt_path'] + '.pdparams')
else: else:
# get the model from disk # get the model from disk
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
...@@ -329,8 +323,8 @@ class VectorExecutor(BaseExecutor): ...@@ -329,8 +323,8 @@ class VectorExecutor(BaseExecutor):
# stage 3: get the model name to instantiate the model network with dynamic_import # stage 3: get the model name to instantiate the model network with dynamic_import
logger.info("start to dynamic import the model class") logger.info("start to dynamic import the model class")
model_name = model_type[:model_type.rindex('_')] model_name = model_type[:model_type.rindex('_')]
model_class = self.task_resource.get_model_class(model_name)
logger.info(f"model name {model_name}") logger.info(f"model name {model_name}")
model_class = dynamic_import(model_name, self.model_alias)
model_conf = self.config.model model_conf = self.config.model
backbone = model_class(**model_conf) backbone = model_class(**model_conf)
model = SpeakerIdetification( model = SpeakerIdetification(
...@@ -476,4 +470,4 @@ class VectorExecutor(BaseExecutor): ...@@ -476,4 +470,4 @@ class VectorExecutor(BaseExecutor):
else: else:
logger.info("The audio file format is right") logger.info("The audio file format is right")
return True return True
\ No newline at end of file
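A matching usage sketch for the vector task; audio_file and sample_rate are assumed parameter names for the Python API (their CLI counterparts --input and --sr appear in the comment below):

    # Hedged sketch: extracting a speaker embedding with the refactored executor.
    from paddlespeech.cli.vector import VectorExecutor

    vector = VectorExecutor()
    embedding = vector(
        audio_file='./input.wav',     # assumed parameter name
        model='ecapatdnn_voxceleb12',
        sample_rate=16000)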
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{sr}][-...]".
# e.g. "ecapatdnn_voxceleb12-16k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech vector --task spk --model ecapatdnn_voxceleb12 --sr 16000 --input ./input.wav"
"ecapatdnn_voxceleb12-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz',
'md5':
'67c7ff8885d5246bd16e0f5ac1cba99f',
'cfg_path':
'conf/model.yaml', # the yaml config path
'ckpt_path':
'model/model', # the format is {dir}/{model_name};
# the first 'model' is the dir and the second is the file name,
# i.e. the checkpoint is stored as model/model.pdparams
},
}
model_alias = {
"ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn",
}
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(954793264@qq.com)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# Menglong Xu
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(954793264@qq.com)
# 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2021 Binbin Zhang
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,4 +11,4 @@ ...@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .infer import StatsExecutor from .resource import CommonTaskResource
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
'model_alias',
]
# Records of model name to import class
model_alias = {
# ---------------------------------
# -------------- ASR --------------
# ---------------------------------
"deepspeech2offline": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"],
"deepspeech2online":
["paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline"],
"conformer": ["paddlespeech.s2t.models.u2:U2Model"],
"conformer_online": ["paddlespeech.s2t.models.u2:U2Model"],
"transformer": ["paddlespeech.s2t.models.u2:U2Model"],
"wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"],
# ---------------------------------
# -------------- CLS --------------
# ---------------------------------
"panns_cnn6": ["paddlespeech.cls.models.panns:CNN6"],
"panns_cnn10": ["paddlespeech.cls.models.panns:CNN10"],
"panns_cnn14": ["paddlespeech.cls.models.panns:CNN14"],
# ---------------------------------
# -------------- ST ---------------
# ---------------------------------
"fat_st": ["paddlespeech.s2t.models.u2_st:U2STModel"],
# ---------------------------------
# -------------- TEXT -------------
# ---------------------------------
"ernie_linear_p7": [
"paddlespeech.text.models:ErnieLinear",
"paddlenlp.transformers:ErnieTokenizer"
],
"ernie_linear_p3": [
"paddlespeech.text.models:ErnieLinear",
"paddlenlp.transformers:ErnieTokenizer"
],
# ---------------------------------
# -------------- TTS --------------
# ---------------------------------
# acoustic model
"speedyspeech": ["paddlespeech.t2s.models.speedyspeech:SpeedySpeech"],
"speedyspeech_inference":
["paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference"],
"fastspeech2": ["paddlespeech.t2s.models.fastspeech2:FastSpeech2"],
"fastspeech2_inference":
["paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference"],
"tacotron2": ["paddlespeech.t2s.models.tacotron2:Tacotron2"],
"tacotron2_inference":
["paddlespeech.t2s.models.tacotron2:Tacotron2Inference"],
# voc
"pwgan": ["paddlespeech.t2s.models.parallel_wavegan:PWGGenerator"],
"pwgan_inference":
["paddlespeech.t2s.models.parallel_wavegan:PWGInference"],
"mb_melgan": ["paddlespeech.t2s.models.melgan:MelGANGenerator"],
"mb_melgan_inference": ["paddlespeech.t2s.models.melgan:MelGANInference"],
"style_melgan": ["paddlespeech.t2s.models.melgan:StyleMelGANGenerator"],
"style_melgan_inference":
["paddlespeech.t2s.models.melgan:StyleMelGANInference"],
"hifigan": ["paddlespeech.t2s.models.hifigan:HiFiGANGenerator"],
"hifigan_inference": ["paddlespeech.t2s.models.hifigan:HiFiGANInference"],
"wavernn": ["paddlespeech.t2s.models.wavernn:WaveRNN"],
"wavernn_inference": ["paddlespeech.t2s.models.wavernn:WaveRNNInference"],
# ---------------------------------
# ------------ Vector -------------
# ---------------------------------
"ecapatdnn": ["paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn"],
}
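The entries above are resolved lazily: each "module:object" string is split on the colon and imported on demand. A minimal sketch of that resolution, assuming an importlib-based helper in the spirit of `paddlespeech.s2t.utils.dynamic_import` (the helper body here is illustrative, not the library's exact code):
# [illustrative sketch, not part of the diff]
import importlib

def dynamic_import(import_path: str):
    # Resolve "pkg.module:Name" to the Name object in pkg.module.
    module_name, obj_name = import_path.split(':')
    return getattr(importlib.import_module(module_name), obj_name)

# e.g. model_class = dynamic_import(model_alias['conformer'][0])
# -> paddlespeech.s2t.models.u2.U2Model (requires paddlespeech installed)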
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
'asr_dynamic_pretrained_models',
'asr_static_pretrained_models',
'cls_dynamic_pretrained_models',
'cls_static_pretrained_models',
'st_dynamic_pretrained_models',
'st_kaldi_bins',
'text_dynamic_pretrained_models',
'tts_dynamic_pretrained_models',
'tts_static_pretrained_models',
'tts_onnx_pretrained_models',
'vector_dynamic_pretrained_models',
]
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
# ---------------------------------
# -------------- ASR --------------
# ---------------------------------
asr_dynamic_pretrained_models = {
"conformer_wenetspeech-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
'md5':
'76cb19ed857e6623856b7cd7ebbfeda4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
},
"conformer_online_wenetspeech-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'b8c02632b04da34aca88459835be54a6',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_10',
'model':
'exp/chunk_conformer/checkpoints/avg_10.pdparams',
'params':
'exp/chunk_conformer/checkpoints/avg_10.pdparams',
'lm_url':
'',
'lm_md5':
'',
},
},
"conformer_online_multicn-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz',
'md5':
'7989b3248c898070904cf042fd656003',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/multi_cn',
},
'2.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
'md5':
'0ac93d390552336f2a906aec9e33c5fa',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/multi_cn',
'model':
'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'params':
'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3',
},
},
"conformer_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz',
'md5':
'3f073eccfa7bb14e0c6867d65fc0dc3a',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/avg_30',
},
},
"conformer_online_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz',
'md5':
'b374cfb93537761270b6224fb0bfc26a',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30',
},
},
"transformer_librispeech-en-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'2c667da24922aad391eacafe37bc1660',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/transformer/checkpoints/avg_10',
},
},
"deepspeech2online_wenetspeech-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'e393d4d274af0f6967db24fc146e8074',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2offline_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2online_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz',
'md5':
'98b87b171b7240b7cae6e07d8d0bc9be',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2offline_librispeech-en-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'f5666c81ad015c8de03aac2bc92e5762',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
'lm_md5':
'099a601759d467cd0a8523ff939819c5'
},
},
}
asr_static_pretrained_models = {
"deepspeech2offline_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'model':
'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
}
},
}
# ---------------------------------
# -------------- CLS --------------
# ---------------------------------
cls_dynamic_pretrained_models = {
"panns_cnn6-32k": {
'1.0': {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
'md5': '4cf09194a95df024fd12f84712cf0f9c',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn6.pdparams',
'label_file': 'audioset_labels.txt',
},
},
"panns_cnn10-32k": {
'1.0': {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
'md5': 'cb8427b22176cc2116367d14847f5413',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn10.pdparams',
'label_file': 'audioset_labels.txt',
},
},
"panns_cnn14-32k": {
'1.0': {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
'md5': 'e3b9b5614a1595001161d0ab95edee97',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn14.pdparams',
'label_file': 'audioset_labels.txt',
},
},
}
cls_static_pretrained_models = {
"panns_cnn6-32k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz',
'md5':
'da087c31046d23281d8ec5188c1967da',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
},
"panns_cnn10-32k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz',
'md5':
'5460cc6eafbfaf0f261cc75b90284ae1',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
},
"panns_cnn14-32k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz',
'md5':
'ccc80b194821274da79466862b2ab00f',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
},
}
# ---------------------------------
# -------------- ST ---------------
# ---------------------------------
st_dynamic_pretrained_models = {
"fat_st_ted-en-zh": {
'1.0': {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
"md5":
"d62063f35a16d91210a71081bd2dd557",
"cfg_path":
"model.yaml",
"ckpt_path":
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
},
},
}
st_kaldi_bins = {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
"md5":
"c0682303b3f3393dbf6ed4c4e35a53eb",
}
# ---------------------------------
# -------------- TEXT -------------
# ---------------------------------
text_dynamic_pretrained_models = {
"ernie_linear_p7_wudao-punc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz',
'md5':
'12283e2ddde1797c5d1e57036b512746',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
},
"ernie_linear_p3_wudao-punc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz',
'md5':
'448eb2fdf85b6a997e7e652e80c51dd2',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
},
}
# ---------------------------------
# -------------- TTS --------------
# ---------------------------------
tts_dynamic_pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip',
'md5':
'6f6fa967b408454b6662c8c00c0027cb',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'feats_stats.npy',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
},
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
"fastspeech2_ljspeech-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip',
'md5':
'ffed800c93deaf16ca9b3af89bfcd747',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_100000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
"fastspeech2_aishell3-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip',
'md5':
'f4dd4a5f49a4552b77981f544ab3392e',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_96400.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
},
"fastspeech2_vctk-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip',
'md5':
'743e5024ca1e17a88c5c271db9779ba4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_66200.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
},
# tacotron2
"tacotron2_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip',
'md5':
'0df4b6f0bcbe0d73c5ed6df8867ab91a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
"tacotron2_ljspeech-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip',
'md5':
'6a5eddd81ae0e81d16959b97481135f3',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_60300.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
# pwgan
"pwgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip',
'md5':
'2e481633325b5bdf0a3823c714d2c117',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
},
"pwgan_ljspeech-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip',
'md5':
'53610ba9708fd3008ccaf8e99dacbaf0',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
},
"pwgan_aishell3-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip',
'md5':
'd7598fa41ad362d62f85ffc0f07e3d84',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"pwgan_vctk-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip',
'md5':
'b3da1defcde3e578be71eb284cb89f2c',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
# style_melgan
"style_melgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'5de2d5348f396de0c966926b8c462755',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
# hifigan
"hifigan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"hifigan_ljspeech-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip',
'md5':
'70e9131695decbca06a65fe51ed38a72',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"hifigan_aishell3-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip',
'md5':
'3bb49bc75032ed12f79c00c8cc79a09a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"hifigan_vctk-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip',
'md5':
'7da8f88359bca2457e705d924cf27bd4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
# wavernn
"wavernn_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip',
'md5':
'ee37b752f09bcba8f2af3b777ca38e13',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_400000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"fastspeech2_cnndecoder_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip',
'md5':
'6eb28e22ace73e0ebe7845f86478f89f',
'config':
'cnndecoder.yaml',
'ckpt':
'snapshot_iter_153000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
}
tts_static_pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip',
'md5':
'f10cbdedf47dc7a9668d2264494e1823',
'model':
'speedyspeech_csmsc.pdmodel',
'params':
'speedyspeech_csmsc.pdiparams',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
'sample_rate':
24000,
},
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip',
'md5':
'9788cd9745e14c7a5d12d32670b2a5a7',
'model':
'fastspeech2_csmsc.pdmodel',
'params':
'fastspeech2_csmsc.pdiparams',
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
},
# pwgan
"pwgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip',
'md5':
'e3504aed9c5a290be12d1347836d2742',
'model':
'pwgan_csmsc.pdmodel',
'params':
'pwgan_csmsc.pdiparams',
'sample_rate':
24000,
},
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip',
'md5':
'ac6eee94ba483421d750433f4c3b8d36',
'model':
'mb_melgan_csmsc.pdmodel',
'params':
'mb_melgan_csmsc.pdiparams',
'sample_rate':
24000,
},
},
# hifigan
"hifigan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip',
'md5':
'7edd8c436b3a5546b3a7cb8cff9d5a0c',
'model':
'hifigan_csmsc.pdmodel',
'params':
'hifigan_csmsc.pdiparams',
'sample_rate':
24000,
},
},
}
tts_onnx_pretrained_models = {
# fastspeech2
"fastspeech2_csmsc_onnx-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip',
'md5':
'fd3ad38d83273ad51f0ea4f4abf3ab4e',
'ckpt': ['fastspeech2_csmsc.onnx'],
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
},
"fastspeech2_cnndecoder_csmsc_onnx-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip',
'md5':
'5f70e1a6bcd29d72d54e7931aa86f266',
'ckpt': [
'fastspeech2_csmsc_am_encoder_infer.onnx',
'fastspeech2_csmsc_am_decoder.onnx',
'fastspeech2_csmsc_am_postnet.onnx',
],
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
},
# mb_melgan
"mb_melgan_csmsc_onnx-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip',
'md5':
'5b83ec746e8414bc29032d954ffd07ec',
'ckpt':
'mb_melgan_csmsc.onnx',
'sample_rate':
24000,
},
},
# hifigan
"hifigan_csmsc_onnx-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip',
'md5':
'1a7dc0385875889e46952e50c0994a6b',
'ckpt':
'hifigan_csmsc.onnx',
'sample_rate':
24000,
},
},
}
# ---------------------------------
# ------------ Vector -------------
# ---------------------------------
vector_dynamic_pretrained_models = {
"ecapatdnn_voxceleb12-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz',
'md5':
'cc33023c54ab346cd318408f43fcaf95',
'cfg_path':
'conf/model.yaml', # the yaml config path
'ckpt_path':
'model/model', # the format is ${dir}/{model_name},
# so the first 'model' is dir, the second 'model' is the name
# this means we have a model stored as model/model.pdparams
},
},
}
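Each tag above maps versions to a downloadable archive plus in-archive paths. A hedged sketch of how a consumer reads one entry (the cache layout mirrors what `CommonTaskResource._get_model_dir` in the next file computes; 'MODEL_HOME' is a placeholder for the cache root from `paddlespeech.cli.utils`):
# [illustrative sketch, not part of the diff]
import os

tag, version = 'conformer_wenetspeech-zh-16k', '1.0'
res = asr_dynamic_pretrained_models[tag][version]
archive_url, checksum = res['url'], res['md5']        # what to fetch and verify
model_dir = os.path.join('MODEL_HOME', tag, version)  # placeholder cache root
cfg_path = os.path.join(model_dir, res['cfg_path'])   # e.g. .../model.yaml
ckpt_path = os.path.join(model_dir, res['ckpt_path'])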
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict
from typing import Dict
from typing import List
from typing import Optional
from ..cli.utils import download_and_decompress
from ..cli.utils import MODEL_HOME
from ..utils.dynamic_import import dynamic_import
from .model_alias import model_alias
task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector']
model_format_supported = ['dynamic', 'static', 'onnx']
inference_mode_supported = ['online', 'offline']
class CommonTaskResource:
def __init__(self, task: str, model_format: str='dynamic', **kwargs):
assert task in task_supported, 'Arg "task" must be one of {}.'.format(
task_supported)
assert model_format in model_format_supported, 'Arg "model_format" must be one of {}.'.format(
model_format_supported)
self.task = task
self.model_format = model_format
self.pretrained_models = self._get_pretrained_models()
if 'inference_mode' in kwargs:
assert kwargs[
'inference_mode'] in inference_mode_supported, 'Arg "inference_mode" must be one of {}.'.format(
inference_mode_supported)
self._inference_mode_filter(kwargs['inference_mode'])
# Initialized after model and version have been set.
self.model_tag = None
self.version = None
self.res_dict = None
self.res_dir = None
if self.task == 'tts':
# For vocoder
self.voc_model_tag = None
self.voc_version = None
self.voc_res_dict = None
self.voc_res_dir = None
def set_task_model(self,
model_tag: str,
model_type: int=0,
version: Optional[str]=None):
"""Set model tag and version of current task.
Args:
model_tag (str): Model tag.
model_type (int): 0 for acoustic model otherwise vocoder in tts task.
version (Optional[str], optional): Version of pretrained model. Defaults to None.
"""
assert model_tag in self.pretrained_models, \
"Can't find \"{}\" in resource. Model name must be one of {}".format(model_tag, list(self.pretrained_models.keys()))
if version is None:
version = self._get_default_version(model_tag)
assert version in self.pretrained_models[model_tag], \
"Can't find version \"{}\" in \"{}\". Model name must be one of {}".format(
version, model_tag, list(self.pretrained_models[model_tag].keys()))
if model_type == 0:
self.model_tag = model_tag
self.version = version
self.res_dict = self.pretrained_models[model_tag][version]
self.format_path(self.res_dict)
self.res_dir = self._fetch(self.res_dict,
self._get_model_dir(model_type))
else:
assert self.task == 'tts', 'Vocoder will only be used in tts task.'
self.voc_model_tag = model_tag
self.voc_version = version
self.voc_res_dict = self.pretrained_models[model_tag][version]
self.format_path(self.voc_res_dict)
self.voc_res_dir = self._fetch(self.voc_res_dict,
self._get_model_dir(model_type))
@staticmethod
def format_path(res_dict: Dict[str, str]):
for k, v in res_dict.items():
if '/' in v:
if v.startswith('https://') or v.startswith('http://'):
continue
else:
res_dict[k] = os.path.join(*(v.split('/')))
@staticmethod
def get_model_class(model_name) -> List[object]:
"""Dynamic import model class.
Args:
model_name (str): Model name.
Returns:
List[object]: Return a list of model class.
"""
assert model_name in model_alias, 'No model classes found for "{}"'.format(
model_name)
ret = []
for import_path in model_alias[model_name]:
ret.append(dynamic_import(import_path))
if len(ret) == 1:
return ret[0]
else:
return ret
def get_versions(self, model_tag: str) -> List[str]:
"""List all available versions.
Args:
model_tag (str): Model tag.
Returns:
List[str]: Version list of model.
"""
return list(self.pretrained_models[model_tag].keys())
def _get_default_version(self, model_tag: str) -> str:
"""Get default version of model.
Args:
model_tag (str): Model tag.
Returns:
str: Default version.
"""
return self.get_versions(model_tag)[-1] # get latest version
def _get_model_dir(self, model_type: int=0) -> os.PathLike:
"""Get resource directory.
Args:
model_type (int): 0 for acoustic model otherwise vocoder in tts task.
Returns:
os.PathLike: Directory of model resource.
"""
if model_type == 0:
model_tag = self.model_tag
version = self.version
else:
model_tag = self.voc_model_tag
version = self.voc_version
return os.path.join(MODEL_HOME, model_tag, version)
def _get_pretrained_models(self) -> Dict[str, str]:
"""Get all available models for current task.
Returns:
Dict[str, str]: A dictionary with model tag and resources info.
"""
try:
import_models = '{}_{}_pretrained_models'.format(self.task,
self.model_format)
exec('from .pretrained_models import {}'.format(import_models))
models = OrderedDict(locals()[import_models])
except ImportError:
models = OrderedDict({}) # no models.
finally:
return models
def _inference_mode_filter(self, inference_mode: Optional[str]):
"""Filter models dict based on inference_mode.
Args:
inference_mode (Optional[str]): 'online', 'offline' or None.
"""
if inference_mode is None:
return
if self.task == 'asr':
online_flags = [
'online' in model_tag
for model_tag in self.pretrained_models.keys()
]
for online_flag, model_tag in zip(
online_flags, list(self.pretrained_models.keys())):
if inference_mode == 'online' and online_flag:
continue
elif inference_mode == 'offline' and not online_flag:
continue
else:
del self.pretrained_models[model_tag]
elif self.task == 'tts':
# Hardcode for tts online models.
tts_online_models = [
'fastspeech2_csmsc-zh', 'fastspeech2_cnndecoder_csmsc-zh',
'mb_melgan_csmsc-zh', 'hifigan_csmsc-zh'
]
for model_tag in list(self.pretrained_models.keys()):
if inference_mode == 'online' and model_tag in tts_online_models:
continue
elif inference_mode == 'offline':
continue
else:
del self.pretrained_models[model_tag]
else:
raise NotImplementedError('Only supports asr and tts task.')
@staticmethod
def _fetch(res_dict: Dict[str, str],
target_dir: os.PathLike) -> os.PathLike:
"""Fetch archive from url.
Args:
res_dict (Dict[str, str]): Info dict of a resource.
target_dir (os.PathLike): Directory to save archives.
Returns:
os.PathLike: Directory of model resource.
"""
return download_and_decompress(res_dict, target_dir)
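A short usage sketch for `CommonTaskResource` as defined above (tag and printed values are examples; `set_task_model` triggers the download via `_fetch` on first use):
# [illustrative sketch, not part of the diff]
resource = CommonTaskResource(task='asr', model_format='dynamic')
print(resource.get_versions('conformer_wenetspeech-zh-16k'))   # e.g. ['1.0']
resource.set_task_model('conformer_wenetspeech-zh-16k')        # fetch + select
print(resource.res_dir)                 # local dir of the decompressed archive
print(resource.res_dict['cfg_path'])    # e.g. 'model.yaml'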
...@@ -189,25 +189,6 @@ if not hasattr(paddle.Tensor, 'contiguous'): ...@@ -189,25 +189,6 @@ if not hasattr(paddle.Tensor, 'contiguous'):
paddle.static.Variable.contiguous = contiguous paddle.static.Variable.contiguous = contiguous
def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
nargs = len(args)
assert (nargs <= 1)
s = paddle.shape(xs)
if nargs == 1:
return s[args[0]]
else:
return s
#`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
logger.debug(
"override size of paddle.Tensor "
"(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
)
paddle.Tensor.size = size
paddle.static.Variable.size = size
def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return xs.reshape(args) return xs.reshape(args)
...@@ -219,7 +200,7 @@ if not hasattr(paddle.Tensor, 'view'): ...@@ -219,7 +200,7 @@ if not hasattr(paddle.Tensor, 'view'):
def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
return xs.reshape(ys.size()) return xs.reshape(paddle.shape(ys))
if not hasattr(paddle.Tensor, 'view_as'): if not hasattr(paddle.Tensor, 'view_as'):
......
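The deleted override above, and the `paddle.shape` replacements throughout the hunks below, center on one point: under `paddle.jit.to_static`, `Tensor.shape` may be captured as a static Python list (with -1 for axes unknown at trace time), while `paddle.shape(x)` is a graph op whose result tracks the runtime shape. A sketch of the difference (eager-mode output shown; the `to_static` case is the motivation):
# [illustrative sketch, not part of the diff]
import paddle

x = paddle.ones([4, 7])
print(x.shape)              # Python list [4, 7]; may hold -1 under to_static
print(paddle.shape(x))      # Tensor holding [4, 7], evaluated at runtime
print(paddle.shape(x)[1])   # Tensor(7): safe replacement for x.size(1)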
...@@ -194,7 +194,7 @@ class BeamSearch(paddle.nn.Layer): ...@@ -194,7 +194,7 @@ class BeamSearch(paddle.nn.Layer):
Args: Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score hyp (Hypothesis): Hypothesis with prefix tokens to score
ids (paddle.Tensor): 1D tensor of new partial tokens to score, ids (paddle.Tensor): 1D tensor of new partial tokens to score,
len(ids) < n_vocab len(ids) < n_vocab
x (paddle.Tensor): Corresponding input feature, (T, D) x (paddle.Tensor): Corresponding input feature, (T, D)
...@@ -224,14 +224,14 @@ class BeamSearch(paddle.nn.Layer): ...@@ -224,14 +224,14 @@ class BeamSearch(paddle.nn.Layer):
ids (paddle.Tensor): The partial token ids(Global) to compute topk. ids (paddle.Tensor): The partial token ids(Global) to compute topk.
Returns: Returns:
Tuple[paddle.Tensor, paddle.Tensor]: Tuple[paddle.Tensor, paddle.Tensor]:
The topk full token ids and partial token ids. The topk full token ids and partial token ids.
Their shapes are `(self.beam_size,)`. Their shapes are `(self.beam_size,)`.
i.e. (global ids, global relative local ids). i.e. (global ids, global relative local ids).
""" """
# no pre beam performed, `ids` equal to `weighted_scores` # no pre beam performed, `ids` equal to `weighted_scores`
if weighted_scores.size(0) == ids.size(0): if paddle.shape(weighted_scores)[0] == paddle.shape(ids)[0]:
top_ids = weighted_scores.topk( top_ids = weighted_scores.topk(
self.beam_size)[1] # index in n_vocab self.beam_size)[1] # index in n_vocab
return top_ids, top_ids return top_ids, top_ids
...@@ -370,13 +370,13 @@ class BeamSearch(paddle.nn.Layer): ...@@ -370,13 +370,13 @@ class BeamSearch(paddle.nn.Layer):
""" """
# set length bounds # set length bounds
if maxlenratio == 0: if maxlenratio == 0:
maxlen = x.shape[0] maxlen = paddle.shape(x)[0]
elif maxlenratio < 0: elif maxlenratio < 0:
maxlen = -1 * int(maxlenratio) maxlen = -1 * int(maxlenratio)
else: else:
maxlen = max(1, int(maxlenratio * x.size(0))) maxlen = max(1, int(maxlenratio * paddle.shape(x)[0]))
minlen = int(minlenratio * x.size(0)) minlen = int(minlenratio * paddle.shape(x)[0])
logger.info("decoder input length: " + str(x.shape[0])) logger.info("decoder input length: " + str(paddle.shape(x)[0]))
logger.info("max output length: " + str(maxlen)) logger.info("max output length: " + str(maxlen))
logger.info("min output length: " + str(minlen)) logger.info("min output length: " + str(minlen))
......
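The `maxlenratio` branch above encodes three behaviors; a standalone restatement (hypothetical helper, same logic as the diff):
# [illustrative sketch, not part of the diff]
def length_bounds(n_frames: int, maxlenratio: float, minlenratio: float):
    if maxlenratio == 0:              # 0 -> bound by the encoder input length
        maxlen = n_frames
    elif maxlenratio < 0:             # negative -> fixed absolute cap
        maxlen = -1 * int(maxlenratio)
    else:                             # positive -> proportional to the input
        maxlen = max(1, int(maxlenratio * n_frames))
    return maxlen, int(minlenratio * n_frames)

# e.g. length_bounds(100, 0.5, 0.1) -> (50, 10); length_bounds(100, -30, 0) -> (30, 0)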
...@@ -69,7 +69,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface): ...@@ -69,7 +69,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
return sc[i], st[i] return sc[i], st[i]
else: # for CTCPrefixScorePD (need new_id > 0) else: # for CTCPrefixScorePD (need new_id > 0)
r, log_psi, f_min, f_max, scoring_idmap = state r, log_psi, f_min, f_max, scoring_idmap = state
s = log_psi[i, new_id].expand(log_psi.size(1)) s = log_psi[i, new_id].expand(paddle.shape(log_psi)[1])
if scoring_idmap is not None: if scoring_idmap is not None:
return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
else: else:
...@@ -107,7 +107,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface): ...@@ -107,7 +107,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
""" """
logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1 logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1
xlen = paddle.to_tensor([logp.size(1)]) xlen = paddle.to_tensor([paddle.shape(logp)[1]])
self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos) self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos)
return None return None
......
...@@ -33,9 +33,9 @@ class CTCPrefixScorePD(): ...@@ -33,9 +33,9 @@ class CTCPrefixScorePD():
self.logzero = -10000000000.0 self.logzero = -10000000000.0
self.blank = blank self.blank = blank
self.eos = eos self.eos = eos
self.batch = x.size(0) self.batch = paddle.shape(x)[0]
self.input_length = x.size(1) self.input_length = paddle.shape(x)[1]
self.odim = x.size(2) self.odim = paddle.shape(x)[2]
self.dtype = x.dtype self.dtype = x.dtype
# Pad the rest of posteriors in the batch # Pad the rest of posteriors in the batch
...@@ -76,8 +76,7 @@ class CTCPrefixScorePD(): ...@@ -76,8 +76,7 @@ class CTCPrefixScorePD():
last_ids = [yi[-1] for yi in y] # last output label ids last_ids = [yi[-1] for yi in y] # last output label ids
n_bh = len(last_ids) # batch * hyps n_bh = len(last_ids) # batch * hyps
n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps
self.scoring_num = scoring_ids.size( self.scoring_num = paddle.shape(scoring_ids)[-1] if scoring_ids is not None else 0
-1) if scoring_ids is not None else 0
# prepare state info # prepare state info
if state is None: if state is None:
r_prev = paddle.full( r_prev = paddle.full(
...@@ -153,7 +152,7 @@ class CTCPrefixScorePD(): ...@@ -153,7 +152,7 @@ class CTCPrefixScorePD():
# compute forward probabilities log(r_t^n(h)) and log(r_t^b(h)) # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
for t in range(start, end): for t in range(start, end):
rp = r[t - 1] # (2 x BW x O') rp = r[t - 1] # (2 x BW x O')
rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view( rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
2, 2, n_bh, snum) # (2,2,BW,O') 2, 2, n_bh, snum) # (2,2,BW,O')
r[t] = paddle.logsumexp(rr, 1) + x_[:, t] r[t] = paddle.logsumexp(rr, 1) + x_[:, t]
...@@ -227,7 +226,7 @@ class CTCPrefixScorePD(): ...@@ -227,7 +226,7 @@ class CTCPrefixScorePD():
if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O) if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O)
# Pad the rest of posteriors in the batch # Pad the rest of posteriors in the batch
# TODO(takaaki-hori): need a better way without for-loops # TODO(takaaki-hori): need a better way without for-loops
xlens = [x.size(1)] xlens = [paddle.shape(x)[1]]
for i, l in enumerate(xlens): for i, l in enumerate(xlens):
if l < self.input_length: if l < self.input_length:
x[i, l:, :] = self.logzero x[i, l:, :] = self.logzero
...@@ -237,7 +236,7 @@ class CTCPrefixScorePD(): ...@@ -237,7 +236,7 @@ class CTCPrefixScorePD():
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim) xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
self.x = paddle.stack([xn, xb]) # (2, T, B, O) self.x = paddle.stack([xn, xb]) # (2, T, B, O)
self.x[:, :tmp_x.shape[1], :, :] = tmp_x self.x[:, :tmp_x.shape[1], :, :] = tmp_x
self.input_length = x.size(1) self.input_length = paddle.shape(x)[1]
self.end_frames = paddle.to_tensor(xlens) - 1 self.end_frames = paddle.to_tensor(xlens) - 1
def extend_state(self, state): def extend_state(self, state):
...@@ -318,16 +317,16 @@ class CTCPrefixScore(): ...@@ -318,16 +317,16 @@ class CTCPrefixScore():
r[0, 0] = xs[0] r[0, 0] = xs[0]
r[0, 1] = self.logzero r[0, 1] = self.logzero
else: else:
# Although the code does not exactly follow Algorithm 2, # Although the code does not exactly follow Algorithm 2,
# we don't have to change it because we can assume # we don't have to change it because we can assume
# r_t(h)=0 for t < |h| in CTC forward computation # r_t(h)=0 for t < |h| in CTC forward computation
# (Note: we assume here that index t starts with 0). # (Note: we assume here that index t starts with 0).
# The purpose of this difference is to reduce the number of for-loops. # The purpose of this difference is to reduce the number of for-loops.
# https://github.com/espnet/espnet/pull/3655 # https://github.com/espnet/espnet/pull/3655
# where we start to accumulate r_t(h) from t=|h| # where we start to accumulate r_t(h) from t=|h|
# and iterate r_t(h) = (r_{t-1}(h) + ...) to T-1, # and iterate r_t(h) = (r_{t-1}(h) + ...) to T-1,
# avoiding accumulating zeros for t=1~|h|-1. # avoiding accumulating zeros for t=1~|h|-1.
# Thus, we need to set r_{|h|-1}(h) = 0, # Thus, we need to set r_{|h|-1}(h) = 0,
# i.e., r[output_length-1] = logzero, for initialization. # i.e., r[output_length-1] = logzero, for initialization.
# This is just for reducing the computation. # This is just for reducing the computation.
r[output_length - 1] = self.logzero r[output_length - 1] = self.logzero
......
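The comment block above compresses a subtle indexing argument; a toy restatement of just the initialization (the real recursion in `CTCPrefixScore` combines blank and non-blank paths with logsumexp, elided here):
# [illustrative sketch, not part of the diff]
import numpy as np

logzero = -1e10
T, output_length = 8, 3          # input frames, |h| of the current prefix
r = np.full((T, 2), logzero)     # log-domain forward vars (non-blank, blank)
# r_t(h) = 0 for t < |h|, so instead of accumulating zeros for t = 1..|h|-1,
# pin r[|h|-1] to logzero and start the recursion at t = |h|:
r[output_length - 1] = logzero
for t in range(output_length, T):
    ...                          # r[t] = logsumexp of r[t-1] paths + x[t]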
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -14,13 +14,15 @@ ...@@ -14,13 +14,15 @@
from .deepspeech2 import DeepSpeech2InferModel from .deepspeech2 import DeepSpeech2InferModel
from .deepspeech2 import DeepSpeech2Model from .deepspeech2 import DeepSpeech2Model
from paddlespeech.s2t.utils import dynamic_pip_install from paddlespeech.s2t.utils import dynamic_pip_install
import sys
try: try:
import paddlespeech_ctcdecoders import paddlespeech_ctcdecoders
except ImportError: except ImportError:
try: try:
package_name = 'paddlespeech_ctcdecoders' package_name = 'paddlespeech_ctcdecoders'
dynamic_pip_install.install(package_name) if sys.platform != "win32":
dynamic_pip_install.install(package_name)
except Exception: except Exception:
raise RuntimeError( raise RuntimeError(
"Can not install package paddlespeech_ctcdecoders on your system. \ "Can not install package paddlespeech_ctcdecoders on your system. \
......
...@@ -14,13 +14,15 @@ ...@@ -14,13 +14,15 @@
from .deepspeech2 import DeepSpeech2InferModelOnline from .deepspeech2 import DeepSpeech2InferModelOnline
from .deepspeech2 import DeepSpeech2ModelOnline from .deepspeech2 import DeepSpeech2ModelOnline
from paddlespeech.s2t.utils import dynamic_pip_install from paddlespeech.s2t.utils import dynamic_pip_install
import sys
try: try:
import paddlespeech_ctcdecoders import paddlespeech_ctcdecoders
except ImportError: except ImportError:
try: try:
package_name = 'paddlespeech_ctcdecoders' package_name = 'paddlespeech_ctcdecoders'
dynamic_pip_install.install(package_name) if sys.platform != "win32":
dynamic_pip_install.install(package_name)
except Exception: except Exception:
raise RuntimeError( raise RuntimeError(
"Can not install package paddlespeech_ctcdecoders on your system. \ "Can not install package paddlespeech_ctcdecoders on your system. \
......
...@@ -90,7 +90,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): ...@@ -90,7 +90,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def _target_mask(self, ys_in_pad): def _target_mask(self, ys_in_pad):
ys_mask = ys_in_pad != 0 ys_mask = ys_in_pad != 0
m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0) m = subsequent_mask(paddle.shape(ys_mask)[-1]).unsqueeze(0)
return ys_mask.unsqueeze(-2) & m return ys_mask.unsqueeze(-2) & m
def forward(self, x: paddle.Tensor, t: paddle.Tensor def forward(self, x: paddle.Tensor, t: paddle.Tensor
...@@ -112,7 +112,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): ...@@ -112,7 +112,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
in perplexity: p(t)^{-1/n} = exp(-log p(t) / n) in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
""" """
batch_size = x.size(0) batch_size = paddle.shape(x)[0]
xm = x != 0 xm = x != 0
xlen = xm.sum(axis=1) xlen = xm.sum(axis=1)
if self.embed_drop is not None: if self.embed_drop is not None:
...@@ -122,7 +122,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): ...@@ -122,7 +122,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
h, _ = self.encoder(emb, xlen) h, _ = self.encoder(emb, xlen)
y = self.decoder(h) y = self.decoder(h)
loss = F.cross_entropy( loss = F.cross_entropy(
y.view(-1, y.shape[-1]), t.view(-1), reduction="none") y.view(-1, paddle.shape(y)[-1]), t.view(-1), reduction="none")
mask = xm.to(loss.dtype) mask = xm.to(loss.dtype)
logp = loss * mask.view(-1) logp = loss * mask.view(-1)
nll = logp.view(batch_size, -1).sum(-1) nll = logp.view(batch_size, -1).sum(-1)
......
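The docstring relation above ("in perplexity: ...") with toy numbers, as a quick sanity check (values are illustrative):
# [illustrative sketch, not part of the diff]
import math

nll, n_tokens = 230.3, 100        # summed NLL over n scored tokens
ppl = math.exp(nll / n_tokens)    # exp(-log p(t) / n) = p(t)^(-1/n) ≈ 10.0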
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -775,7 +776,7 @@ class U2DecodeModel(U2BaseModel): ...@@ -775,7 +776,7 @@ class U2DecodeModel(U2BaseModel):
""" """
self.eval() self.eval()
x = paddle.to_tensor(x).unsqueeze(0) x = paddle.to_tensor(x).unsqueeze(0)
ilen = x.size(1) ilen = paddle.shape(x)[1]
enc_output, _ = self._forward_encoder(x, ilen) enc_output, _ = self._forward_encoder(x, ilen)
return enc_output.squeeze(0) return enc_output.squeeze(0)
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
from contextlib import nullcontext from contextlib import nullcontext
import paddle import paddle
......
...@@ -22,6 +22,7 @@ from paddlespeech.s2t.modules.align import Linear ...@@ -22,6 +22,7 @@ from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.loss import CTCLoss from paddlespeech.s2t.modules.loss import CTCLoss
from paddlespeech.s2t.utils import ctc_utils from paddlespeech.s2t.utils import ctc_utils
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
import sys
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
...@@ -34,7 +35,8 @@ except ImportError: ...@@ -34,7 +35,8 @@ except ImportError:
try: try:
from paddlespeech.s2t.utils import dynamic_pip_install from paddlespeech.s2t.utils import dynamic_pip_install
package_name = 'paddlespeech_ctcdecoders' package_name = 'paddlespeech_ctcdecoders'
dynamic_pip_install.install(package_name) if sys.platform != "win32":
dynamic_pip_install.install(package_name)
from paddlespeech.s2t.decoders.ctcdecoder import ctc_beam_search_decoding_batch # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import ctc_beam_search_decoding_batch # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import ctc_greedy_decoding # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import ctc_greedy_decoding # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401
......
...@@ -242,7 +242,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ...@@ -242,7 +242,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
] ]
# batch decoding # batch decoding
ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L) ys_mask = subsequent_mask(paddle.shape(ys)[-1]).unsqueeze(0) # (B,L,L)
xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T)
logp, states = self.forward_one_step( logp, states = self.forward_one_step(
xs, xs_mask, ys, ys_mask, cache=batch_state) xs, xs_mask, ys, ys_mask, cache=batch_state)
......
...@@ -115,7 +115,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface): ...@@ -115,7 +115,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
assert offset + x.shape[ assert offset + x.shape[
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format( 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len) offset, x.shape[1], self.max_len)
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor #TODO(Hui Zhang): using T = paddle.shape(x)[1], __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T] pos_emb = self.pe[:, offset:offset + T]
x = x * self.xscale + pos_emb x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb) return self.dropout(x), self.dropout(pos_emb)
...@@ -165,6 +165,6 @@ class RelPositionalEncoding(PositionalEncoding): ...@@ -165,6 +165,6 @@ class RelPositionalEncoding(PositionalEncoding):
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format( 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len) offset, x.shape[1], self.max_len)
x = x * self.xscale x = x * self.xscale
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor #TODO(Hui Zhang): using paddle.shape(x)[1], __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]] pos_emb = self.pe[:, offset:offset + x.shape[1]]
return self.dropout(x), self.dropout(pos_emb) return self.dropout(x), self.dropout(pos_emb)
...@@ -218,7 +218,7 @@ class BaseEncoder(nn.Layer): ...@@ -218,7 +218,7 @@ class BaseEncoder(nn.Layer):
assert xs.shape[0] == 1 # batch size must be one assert xs.shape[0] == 1 # batch size must be one
# tmp_masks is just for interface compatibility # tmp_masks is just for interface compatibility
# TODO(Hui Zhang): stride_slice not support bool tensor # TODO(Hui Zhang): stride_slice not support bool tensor
# tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool) # tmp_masks = paddle.ones([1, paddle.shape(xs)[1]], dtype=paddle.bool)
tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32) tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T] tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T]
......
...@@ -154,7 +154,8 @@ class SpeedPerturbationSox(): ...@@ -154,7 +154,8 @@ class SpeedPerturbationSox():
package = "sox" package = "sox"
dynamic_pip_install.install(package) dynamic_pip_install.install(package)
package = "soxbindings" package = "soxbindings"
dynamic_pip_install.install(package) if sys.platform != "win32":
dynamic_pip_install.install(package)
import soxbindings as sox import soxbindings as sox
except Exception: except Exception:
raise RuntimeError( raise RuntimeError(
......
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
...@@ -58,7 +58,7 @@ def pad_sequence(sequences: List[paddle.Tensor], ...@@ -58,7 +58,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
>>> a = paddle.ones(25, 300) >>> a = paddle.ones(25, 300)
>>> b = paddle.ones(22, 300) >>> b = paddle.ones(22, 300)
>>> c = paddle.ones(15, 300) >>> c = paddle.ones(15, 300)
>>> pad_sequence([a, b, c]).size() >>> pad_sequence([a, b, c]).shape
paddle.Tensor([25, 3, 300]) paddle.Tensor([25, 3, 300])
Note: Note:
...@@ -79,10 +79,10 @@ def pad_sequence(sequences: List[paddle.Tensor], ...@@ -79,10 +79,10 @@ def pad_sequence(sequences: List[paddle.Tensor],
# assuming trailing dimensions and type of all the Tensors # assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0] # in sequences are same and fetching those from sequences[0]
max_size = sequences[0].size() max_size = paddle.shape(sequences[0])
# (TODO Hui Zhang): slice not support `end==start` # (TODO Hui Zhang): slice not support `end==start`
# trailing_dims = max_size[1:] # trailing_dims = max_size[1:]
trailing_dims = max_size[1:] if max_size.ndim >= 2 else () trailing_dims = tuple(max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
max_len = max([s.shape[0] for s in sequences]) max_len = max([s.shape[0] for s in sequences])
if batch_first: if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims out_dims = (len(sequences), max_len) + trailing_dims
...@@ -99,7 +99,7 @@ def pad_sequence(sequences: List[paddle.Tensor], ...@@ -99,7 +99,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
if batch_first: if batch_first:
# TODO (Hui Zhang): set_value op not support `end==start` # TODO (Hui Zhang): set_value op not support `end==start`
# TODO (Hui Zhang): set_value op not support int16 # TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor # out_tensor[i, :length, ...] = tensor
if length != 0: if length != 0:
out_tensor[i, :length] = tensor out_tensor[i, :length] = tensor
...@@ -145,7 +145,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, ...@@ -145,7 +145,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
[ 4, 5, 6, 11, -1, -1], [ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]]) [ 7, 8, 9, 11, -1, -1]])
""" """
# TODO(Hui Zhang): use the commented code below, # TODO(Hui Zhang): use the commented code below,
#_sos = paddle.to_tensor( #_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor( #_eos = paddle.to_tensor(
......
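The `trailing_dims` change in the `pad_sequence` hunk above exists because `paddle.shape` returns a Tensor, and a Tensor slice cannot be concatenated with the Python tuple used to build `out_dims`. A sketch of the conversion:
# [illustrative sketch, not part of the diff]
import paddle

a = paddle.ones([25, 300])
max_size = paddle.shape(a)                            # Tensor([25, 300])
trailing_dims = tuple(max_size[1:].numpy().tolist())  # -> (300,)
out_dims = (4, 25) + trailing_dims                    # (4, 25, 300)
out = paddle.zeros(out_dims)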
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
...@@ -25,6 +25,7 @@ from ..executor import BaseExecutor ...@@ -25,6 +25,7 @@ from ..executor import BaseExecutor
from ..util import cli_server_register from ..util import cli_server_register
from ..util import stats_wrapper from ..util import stats_wrapper
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.engine_pool import init_engine_pool from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.engine.engine_warmup import warm_up from paddlespeech.server.engine.engine_warmup import warm_up
from paddlespeech.server.restful.api import setup_router as setup_http_router from paddlespeech.server.restful.api import setup_router as setup_http_router
...@@ -158,101 +159,30 @@ class ServerStatsExecutor(): ...@@ -158,101 +159,30 @@ class ServerStatsExecutor():
"Please input correct speech task, choices = ['asr', 'tts']") "Please input correct speech task, choices = ['asr', 'tts']")
return False return False
elif self.task.lower() == 'asr': try:
try: # Dynamic models
from paddlespeech.cli.asr.infer import pretrained_models dynamic_pretrained_models = CommonTaskResource(
logger.info( task=self.task, model_format='dynamic').pretrained_models
"Here is the table of ASR pretrained models supported in the service."
)
self.show_support_models(pretrained_models)
# show ASR static pretrained model
from paddlespeech.server.engine.asr.paddleinference.asr_engine import pretrained_models
logger.info(
"Here is the table of ASR static pretrained models supported in the service."
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error(
"Failed to get the table of ASR pretrained models supported in the service."
)
return False
elif self.task.lower() == 'tts':
try:
from paddlespeech.cli.tts.infer import pretrained_models
logger.info(
"Here is the table of TTS pretrained models supported in the service."
)
self.show_support_models(pretrained_models)
# show TTS static pretrained model
from paddlespeech.server.engine.tts.paddleinference.tts_engine import pretrained_models
logger.info(
"Here is the table of TTS static pretrained models supported in the service."
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error(
"Failed to get the table of TTS pretrained models supported in the service."
)
return False
elif self.task.lower() == 'cls': if len(dynamic_pretrained_models) > 0:
try:
from paddlespeech.cli.cls.infer import pretrained_models
logger.info( logger.info(
"Here is the table of CLS pretrained models supported in the service." "Here is the table of {} pretrained models supported in the service.".
) format(self.task.upper()))
self.show_support_models(pretrained_models) self.show_support_models(dynamic_pretrained_models)
# show CLS static pretrained model # Static models
from paddlespeech.server.engine.cls.paddleinference.cls_engine import pretrained_models static_pretrained_models = CommonTaskResource(
task=self.task, model_format='static').pretrained_models
if len(static_pretrained_models) > 0:
logger.info( logger.info(
"Here is the table of CLS static pretrained models supported in the service." "Here is the table of {} static pretrained models supported in the service.".
) format(self.task.upper()))
self.show_support_models(pretrained_models) self.show_support_models(static_pretrained_models)
return True return True
except BaseException:
logger.error(
"Failed to get the table of CLS pretrained models supported in the service."
)
return False
elif self.task.lower() == 'text':
try:
from paddlespeech.cli.text.infer import pretrained_models
logger.info(
"Here is the table of Text pretrained models supported in the service."
)
self.show_support_models(pretrained_models)
return True except BaseException:
except BaseException:
logger.error(
"Failed to get the table of Text pretrained models supported in the service."
)
return False
elif self.task.lower() == 'vector':
try:
from paddlespeech.cli.vector.infer import pretrained_models
logger.info(
"Here is the table of Vector pretrained models supported in the service."
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error(
"Failed to get the table of Vector pretrained models supported in the service."
)
return False
else:
logger.error( logger.error(
f"Failed to get the table of {self.task} pretrained models supported in the service." "Failed to get the table of {} pretrained models supported in the service.".
) format(self.task.upper()))
return False return False
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import os import os
import sys import sys
from typing import Optional from typing import Optional
...@@ -21,15 +20,14 @@ import paddle ...@@ -21,15 +20,14 @@ import paddle
from numpy import float32 from numpy import float32
from yacs.config import CfgNode from yacs.config import CfgNode
from .pretrained_models import pretrained_models
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
...@@ -53,7 +51,7 @@ class PaddleASRConnectionHanddler: ...@@ -53,7 +51,7 @@ class PaddleASRConnectionHanddler:
logger.info( logger.info(
"create an paddle asr connection handler to process the websocket connection" "create an paddle asr connection handler to process the websocket connection"
) )
self.config = asr_engine.config self.config = asr_engine.config # server config
self.model_config = asr_engine.executor.config self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine self.asr_engine = asr_engine
...@@ -249,10 +247,15 @@ class PaddleASRConnectionHanddler: ...@@ -249,10 +247,15 @@ class PaddleASRConnectionHanddler:
def reset(self): def reset(self):
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
# for deepspeech2 # for deepspeech2
self.chunk_state_h_box = copy.deepcopy( # init state
self.asr_engine.executor.chunk_state_h_box) self.chunk_state_h_box = np.zeros(
self.chunk_state_c_box = copy.deepcopy( (self.model_config.num_rnn_layers, 1,
self.asr_engine.executor.chunk_state_c_box) self.model_config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.model_config.num_rnn_layers, 1,
self.model_config.rnn_layer_size),
dtype=float32)
self.decoder.reset_decoder(batch_size=1) self.decoder.reset_decoder(batch_size=1)
self.device = None self.device = None
...@@ -696,7 +699,8 @@ class PaddleASRConnectionHanddler: ...@@ -696,7 +699,8 @@ class PaddleASRConnectionHanddler:
class ASRServerExecutor(ASRExecutor): class ASRServerExecutor(ASRExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.pretrained_models = pretrained_models self.task_resource = CommonTaskResource(
task='asr', model_format='dynamic', inference_mode='online')
def _init_from_path(self, def _init_from_path(self,
model_type: str=None, model_type: str=None,
...@@ -720,20 +724,19 @@ class ASRServerExecutor(ASRExecutor): ...@@ -720,20 +724,19 @@ class ASRServerExecutor(ASRExecutor):
self.sample_rate = sample_rate self.sample_rate = sample_rate
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
if cfg_path is None or am_model is None or am_params is None: if cfg_path is None or am_model is None or am_params is None:
logger.info(f"Load the pretrained model, tag = {tag}") logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = self.task_resource.res_dir
self.res_path = res_path
self.cfg_path = os.path.join( self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path']) self.res_path, self.task_resource.res_dict['cfg_path'])
self.am_model = os.path.join(res_path, self.am_model = os.path.join(self.res_path,
self.pretrained_models[tag]['model']) self.task_resource.res_dict['model'])
self.am_params = os.path.join(res_path, self.am_params = os.path.join(self.res_path,
self.pretrained_models[tag]['params']) self.task_resource.res_dict['params'])
logger.info(res_path) logger.info(self.res_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model) self.am_model = os.path.abspath(am_model)
...@@ -760,8 +763,8 @@ class ASRServerExecutor(ASRExecutor): ...@@ -760,8 +763,8 @@ class ASRServerExecutor(ASRExecutor):
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab) unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = self.pretrained_models[tag]['lm_url'] lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.pretrained_models[tag]['lm_md5'] lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}") logger.info(f"Start to load language model {lm_url}")
self.download_lm( self.download_lm(
lm_url, lm_url,
...@@ -803,41 +806,11 @@ class ASRServerExecutor(ASRExecutor): ...@@ -803,41 +806,11 @@ class ASRServerExecutor(ASRExecutor):
model_file=self.am_model, model_file=self.am_model,
params_file=self.am_params, params_file=self.am_params,
predictor_conf=self.am_predictor_conf) predictor_conf=self.am_predictor_conf)
# decoder
logger.info("ASR engine start to create the ctc decoder instance")
self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2,
blank_id=self.config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder
logger.info("ASR engine start to init the ctc decoder")
cfg = self.config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# init state box
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset} '_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}") logger.info(f"model name: {model_name}")
model_class = dynamic_import(model_name, self.model_alias) model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config model_conf = self.config
model = model_class.from_config(model_conf) model = model_class.from_config(model_conf)
self.model = model self.model = model
...@@ -847,10 +820,6 @@ class ASRServerExecutor(ASRExecutor): ...@@ -847,10 +820,6 @@ class ASRServerExecutor(ASRExecutor):
model_dict = paddle.load(self.am_model) model_dict = paddle.load(self.am_model)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
logger.info("create the transformer like model success") logger.info("create the transformer like model success")
# update the ctc decoding
self.searcher = CTCPrefixBeamSearch(self.config.decode)
self.transformer_decode_reset()
else: else:
raise ValueError(f"Not support: {model_type}") raise ValueError(f"Not support: {model_type}")
...@@ -881,8 +850,8 @@ class ASREngine(BaseEngine): ...@@ -881,8 +850,8 @@ class ASREngine(BaseEngine):
self.executor = ASRServerExecutor() self.executor = ASRServerExecutor()
try: try:
default_dev = paddle.get_device() self.device = self.config.get("device", paddle.get_device())
paddle.set_device(self.config.get("device", default_dev)) paddle.set_device(self.device)
except BaseException as e: except BaseException as e:
logger.error( logger.error(
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file" f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz',
'md5':
'98b87b171b7240b7cae6e07d8d0bc9be',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"conformer_online_multicn-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
'md5':
'0ac93d390552336f2a906aec9e33c5fa',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/multi_cn',
'model':
'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'params':
'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"conformer_online_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'b8c02632b04da34aca88459835be54a6',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_10',
'model':
'exp/chunk_conformer/checkpoints/avg_10.pdparams',
'params':
'exp/chunk_conformer/checkpoints/avg_10.pdparams',
'lm_url':
'',
'lm_md5':
'',
},
}
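For reference, the executor above resolves one of these entries by building a tag of the form `{model_type}-{lang}-{sample_rate_str}` and joining the resource directory with the relative paths stored in the entry. A small illustrative sketch follows; the resource directory is a placeholder, since in the server it comes from `task_resource.res_dir`.

```python
import os

# Tag construction mirrors the executor above: {model_type}-{lang}-{sample_rate_str}.
model_type = 'deepspeech2online_aishell'
lang = 'zh'
sample_rate = 16000
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str  # a key of the dict above

res_dict = pretrained_models[tag]
res_dir = '/path/to/downloaded/resources'  # placeholder for task_resource.res_dir
cfg_path = os.path.join(res_dir, res_dict['cfg_path'])
am_model = os.path.join(res_dir, res_dict['model'])
am_params = os.path.join(res_dir, res_dict['params'])
```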
...@@ -19,10 +19,10 @@ from typing import Optional ...@@ -19,10 +19,10 @@ from typing import Optional
import paddle import paddle
from yacs.config import CfgNode from yacs.config import CfgNode
from .pretrained_models import pretrained_models
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
...@@ -36,7 +36,8 @@ __all__ = ['ASREngine', 'PaddleASRConnectionHandler'] ...@@ -36,7 +36,8 @@ __all__ = ['ASREngine', 'PaddleASRConnectionHandler']
class ASRServerExecutor(ASRExecutor): class ASRServerExecutor(ASRExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.pretrained_models = pretrained_models self.task_resource = CommonTaskResource(
task='asr', model_format='static')
def _init_from_path(self, def _init_from_path(self,
model_type: str='wenetspeech', model_type: str='wenetspeech',
...@@ -53,17 +54,17 @@ class ASRServerExecutor(ASRExecutor): ...@@ -53,17 +54,17 @@ class ASRServerExecutor(ASRExecutor):
self.max_len = 50 self.max_len = 50
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
if cfg_path is None or am_model is None or am_params is None: if cfg_path is None or am_model is None or am_params is None:
res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = self.task_resource.res_dir
self.res_path = res_path
self.cfg_path = os.path.join( self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path']) self.res_path, self.task_resource.res_dict['cfg_path'])
self.am_model = os.path.join(res_path, self.am_model = os.path.join(self.res_path,
self.pretrained_models[tag]['model']) self.task_resource.res_dict['model'])
self.am_params = os.path.join(res_path, self.am_params = os.path.join(self.res_path,
self.pretrained_models[tag]['params']) self.task_resource.res_dict['params'])
logger.info(res_path) logger.info(self.res_path)
logger.info(self.cfg_path) logger.info(self.cfg_path)
logger.info(self.am_model) logger.info(self.am_model)
logger.info(self.am_params) logger.info(self.am_params)
...@@ -89,8 +90,8 @@ class ASRServerExecutor(ASRExecutor): ...@@ -89,8 +90,8 @@ class ASRServerExecutor(ASRExecutor):
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab) unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = self.pretrained_models[tag]['lm_url'] lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.pretrained_models[tag]['lm_md5'] lm_md5 = self.task_resource.res_dict['lm_md5']
self.download_lm( self.download_lm(
lm_url, lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5) os.path.dirname(self.config.decode.lang_model_path), lm_md5)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
"deepspeech2offline_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'model':
'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
}
...@@ -21,9 +21,9 @@ import numpy as np ...@@ -21,9 +21,9 @@ import numpy as np
import paddle import paddle
import yaml import yaml
from .pretrained_models import pretrained_models
from paddlespeech.cli.cls.infer import CLSExecutor from paddlespeech.cli.cls.infer import CLSExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model from paddlespeech.server.utils.paddle_predictor import run_model
...@@ -34,11 +34,12 @@ __all__ = ['CLSEngine', 'PaddleCLSConnectionHandler'] ...@@ -34,11 +34,12 @@ __all__ = ['CLSEngine', 'PaddleCLSConnectionHandler']
class CLSServerExecutor(CLSExecutor): class CLSServerExecutor(CLSExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.pretrained_models = pretrained_models self.task_resource = CommonTaskResource(
task='cls', model_format='static')
def _init_from_path( def _init_from_path(
self, self,
model_type: str='panns_cnn14', model_type: str='panns_cnn14_audioset',
cfg_path: Optional[os.PathLike]=None, cfg_path: Optional[os.PathLike]=None,
model_path: Optional[os.PathLike]=None, model_path: Optional[os.PathLike]=None,
params_path: Optional[os.PathLike]=None, params_path: Optional[os.PathLike]=None,
...@@ -50,15 +51,16 @@ class CLSServerExecutor(CLSExecutor): ...@@ -50,15 +51,16 @@ class CLSServerExecutor(CLSExecutor):
if cfg_path is None or model_path is None or params_path is None or label_file is None: if cfg_path is None or model_path is None or params_path is None or label_file is None:
tag = model_type + '-' + '32k' tag = model_type + '-' + '32k'
self.res_path = self._get_pretrained_path(tag) self.task_resource.set_task_model(model_tag=tag)
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join( self.cfg_path = os.path.join(
self.res_path, self.pretrained_models[tag]['cfg_path']) self.res_path, self.task_resource.res_dict['cfg_path'])
self.model_path = os.path.join( self.model_path = os.path.join(
self.res_path, self.pretrained_models[tag]['model_path']) self.res_path, self.task_resource.res_dict['model_path'])
self.params_path = os.path.join( self.params_path = os.path.join(
self.res_path, self.pretrained_models[tag]['params_path']) self.res_path, self.task_resource.res_dict['params_path'])
self.label_file = os.path.join( self.label_file = os.path.join(
self.res_path, self.pretrained_models[tag]['label_file']) self.res_path, self.task_resource.res_dict['label_file'])
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.model_path = os.path.abspath(model_path) self.model_path = os.path.abspath(model_path)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
"panns_cnn6-32k": {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz',
'md5':
'da087c31046d23281d8ec5188c1967da',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
"panns_cnn10-32k": {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz',
'md5':
'5460cc6eafbfaf0f261cc75b90284ae1',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
"panns_cnn14-32k": {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz',
'md5':
'ccc80b194821274da79466862b2ab00f',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# support online model
pretrained_models = {
# fastspeech2
"fastspeech2_csmsc_onnx-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip',
'md5':
'fd3ad38d83273ad51f0ea4f4abf3ab4e',
'ckpt': ['fastspeech2_csmsc.onnx'],
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
"fastspeech2_cnndecoder_csmsc_onnx-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip',
'md5':
'5f70e1a6bcd29d72d54e7931aa86f266',
'ckpt': [
'fastspeech2_csmsc_am_encoder_infer.onnx',
'fastspeech2_csmsc_am_decoder.onnx',
'fastspeech2_csmsc_am_postnet.onnx',
],
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
# mb_melgan
"mb_melgan_csmsc_onnx-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip',
'md5':
'5b83ec746e8414bc29032d954ffd07ec',
'ckpt':
'mb_melgan_csmsc.onnx',
'sample_rate':
24000,
},
# hifigan
"hifigan_csmsc_onnx-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip',
'md5':
'1a7dc0385875889e46952e50c0994a6b',
'ckpt':
'hifigan_csmsc.onnx',
'sample_rate':
24000,
},
}
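Note that the streaming `fastspeech2_cnndecoder_csmsc_onnx-zh` entry lists three ONNX files (encoder, decoder, postnet) so the acoustic model can run chunk by chunk. Below is a hedged sketch of loading them with plain onnxruntime; the server itself goes through `paddlespeech.server.utils.onnx_infer.get_sess`, whose exact signature is not shown in this diff.

```python
# Illustrative only: load the three streaming AM parts with onnxruntime,
# assuming the files listed in 'ckpt' have already been downloaded locally.
import onnxruntime as ort

ckpt = pretrained_models["fastspeech2_cnndecoder_csmsc_onnx-zh"]['ckpt']
encoder = ort.InferenceSession(ckpt[0])  # fastspeech2_csmsc_am_encoder_infer.onnx
decoder = ort.InferenceSession(ckpt[1])  # fastspeech2_csmsc_am_decoder.onnx
postnet = ort.InferenceSession(ckpt[2])  # fastspeech2_csmsc_am_postnet.onnx
```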
...@@ -20,9 +20,9 @@ from typing import Optional ...@@ -20,9 +20,9 @@ from typing import Optional
import numpy as np import numpy as np
import paddle import paddle
from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.onnx_infer import get_sess from paddlespeech.server.utils.onnx_infer import get_sess
...@@ -37,7 +37,7 @@ __all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] ...@@ -37,7 +37,7 @@ __all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor): class TTSServerExecutor(TTSExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.pretrained_models = pretrained_models self.task_resource = CommonTaskResource(task='tts', model_format='onnx')
def _init_from_path( def _init_from_path(
self, self,
...@@ -66,16 +66,21 @@ class TTSServerExecutor(TTSExecutor): ...@@ -66,16 +66,21 @@ class TTSServerExecutor(TTSExecutor):
return return
# am # am
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
self.am_res_path = self.task_resource.res_dir
if am == "fastspeech2_csmsc_onnx": if am == "fastspeech2_csmsc_onnx":
# get model info # get model info
if am_ckpt is None or phones_dict is None: if am_ckpt is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_ckpt = os.path.join( self.am_ckpt = os.path.join(
am_res_path, self.pretrained_models[am_tag]['ckpt'][0]) self.am_res_path, self.task_resource.res_dict['ckpt'][0])
# must have phones_dict in acoustic # must have phones_dict in acoustic
self.phones_dict = os.path.join( self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict']) self.am_res_path,
self.task_resource.res_dict['phones_dict'])
else: else:
self.am_ckpt = os.path.abspath(am_ckpt[0]) self.am_ckpt = os.path.abspath(am_ckpt[0])
...@@ -88,19 +93,19 @@ class TTSServerExecutor(TTSExecutor): ...@@ -88,19 +93,19 @@ class TTSServerExecutor(TTSExecutor):
elif am == "fastspeech2_cnndecoder_csmsc_onnx": elif am == "fastspeech2_cnndecoder_csmsc_onnx":
if am_ckpt is None or am_stat is None or phones_dict is None: if am_ckpt is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_encoder_infer = os.path.join( self.am_encoder_infer = os.path.join(
am_res_path, self.pretrained_models[am_tag]['ckpt'][0]) self.am_res_path, self.task_resource.res_dict['ckpt'][0])
self.am_decoder = os.path.join( self.am_decoder = os.path.join(
am_res_path, self.pretrained_models[am_tag]['ckpt'][1]) self.am_res_path, self.task_resource.res_dict['ckpt'][1])
self.am_postnet = os.path.join( self.am_postnet = os.path.join(
am_res_path, self.pretrained_models[am_tag]['ckpt'][2]) self.am_res_path, self.task_resource.res_dict['ckpt'][2])
# must have phones_dict in acoustic # must have phones_dict in acoustic
self.phones_dict = os.path.join( self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict']) self.am_res_path,
self.task_resource.res_dict['phones_dict'])
self.am_stat = os.path.join( self.am_stat = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speech_stats']) self.am_res_path,
self.task_resource.res_dict['speech_stats'])
else: else:
self.am_encoder_infer = os.path.abspath(am_ckpt[0]) self.am_encoder_infer = os.path.abspath(am_ckpt[0])
...@@ -125,11 +130,15 @@ class TTSServerExecutor(TTSExecutor): ...@@ -125,11 +130,15 @@ class TTSServerExecutor(TTSExecutor):
# voc model info # voc model info
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
if voc_ckpt is None: if voc_ckpt is None:
voc_res_path = self._get_pretrained_path(voc_tag) self.voc_res_path = self.task_resource.voc_res_dir
self.voc_res_path = voc_res_path
self.voc_ckpt = os.path.join( self.voc_ckpt = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['ckpt']) self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
else: else:
self.voc_ckpt = os.path.abspath(voc_ckpt) self.voc_ckpt = os.path.abspath(voc_ckpt)
self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt)) self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# support online model
pretrained_models = {
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_cnndecoder_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip',
'md5':
'6eb28e22ace73e0ebe7845f86478f89f',
'config':
'cnndecoder.yaml',
'ckpt':
'snapshot_iter_153000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
}
...@@ -22,9 +22,9 @@ import paddle ...@@ -22,9 +22,9 @@ import paddle
import yaml import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import denorm from paddlespeech.server.utils.util import denorm
...@@ -32,7 +32,6 @@ from paddlespeech.server.utils.util import get_chunks ...@@ -32,7 +32,6 @@ from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] __all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
...@@ -40,7 +39,9 @@ __all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] ...@@ -40,7 +39,9 @@ __all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor): class TTSServerExecutor(TTSExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.pretrained_models = pretrained_models self.task_resource = CommonTaskResource(
task='tts', model_format='static', inference_mode='online')
def get_model_info(self, def get_model_info(self,
field: str, field: str,
...@@ -61,7 +62,7 @@ class TTSServerExecutor(TTSExecutor): ...@@ -61,7 +62,7 @@ class TTSServerExecutor(TTSExecutor):
[Tensor]: standard deviation [Tensor]: standard deviation
""" """
model_class = dynamic_import(model_name, self.model_alias) model_class = self.task_resource.get_model_class(model_name)
if field == "am": if field == "am":
odim = self.am_config.n_mels odim = self.am_config.n_mels
...@@ -106,20 +107,24 @@ class TTSServerExecutor(TTSExecutor): ...@@ -106,20 +107,24 @@ class TTSServerExecutor(TTSExecutor):
return return
# am model info # am model info
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag) self.am_res_path = self.task_resource.res_dir
self.am_res_path = am_res_path self.am_config = os.path.join(self.am_res_path,
self.am_config = os.path.join( self.task_resource.res_dict['config'])
am_res_path, self.pretrained_models[am_tag]['config']) self.am_ckpt = os.path.join(self.am_res_path,
self.am_ckpt = os.path.join(am_res_path, self.task_resource.res_dict['ckpt'])
self.pretrained_models[am_tag]['ckpt'])
self.am_stat = os.path.join( self.am_stat = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speech_stats']) self.am_res_path, self.task_resource.res_dict['speech_stats'])
# must have phones_dict in acoustic # must have phones_dict in acoustic
self.phones_dict = os.path.join( self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict']) self.am_res_path, self.task_resource.res_dict['phones_dict'])
print("self.phones_dict:", self.phones_dict) print("self.phones_dict:", self.phones_dict)
logger.info(am_res_path) logger.info(self.am_res_path)
logger.info(self.am_config) logger.info(self.am_config)
logger.info(self.am_ckpt) logger.info(self.am_ckpt)
else: else:
...@@ -135,16 +140,21 @@ class TTSServerExecutor(TTSExecutor): ...@@ -135,16 +140,21 @@ class TTSServerExecutor(TTSExecutor):
# voc model info # voc model info
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
if voc_ckpt is None or voc_config is None or voc_stat is None: if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag) self.voc_res_path = self.task_resource.voc_res_dir
self.voc_res_path = voc_res_path
self.voc_config = os.path.join( self.voc_config = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['config']) self.voc_res_path, self.task_resource.voc_res_dict['config'])
self.voc_ckpt = os.path.join( self.voc_ckpt = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['ckpt']) self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
self.voc_stat = os.path.join( self.voc_stat = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['speech_stats']) self.voc_res_path,
logger.info(voc_res_path) self.task_resource.voc_res_dict['speech_stats'])
logger.info(self.voc_res_path)
logger.info(self.voc_config) logger.info(self.voc_config)
logger.info(self.voc_ckpt) logger.info(self.voc_ckpt)
else: else:
...@@ -184,8 +194,8 @@ class TTSServerExecutor(TTSExecutor): ...@@ -184,8 +194,8 @@ class TTSServerExecutor(TTSExecutor):
am, am_mu, am_std = self.get_model_info("am", self.am_name, am, am_mu, am_std = self.get_model_info("am", self.am_name,
self.am_ckpt, self.am_stat) self.am_ckpt, self.am_stat)
am_normalizer = ZScore(am_mu, am_std) am_normalizer = ZScore(am_mu, am_std)
am_inference_class = dynamic_import(self.am_name + '_inference', am_inference_class = self.task_resource.get_model_class(
self.model_alias) self.am_name + '_inference')
self.am_inference = am_inference_class(am_normalizer, am) self.am_inference = am_inference_class(am_normalizer, am)
self.am_inference.eval() self.am_inference.eval()
print("acoustic model done!") print("acoustic model done!")
...@@ -195,8 +205,8 @@ class TTSServerExecutor(TTSExecutor): ...@@ -195,8 +205,8 @@ class TTSServerExecutor(TTSExecutor):
voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name, voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name,
self.voc_ckpt, self.voc_stat) self.voc_ckpt, self.voc_stat)
voc_normalizer = ZScore(voc_mu, voc_std) voc_normalizer = ZScore(voc_mu, voc_std)
voc_inference_class = dynamic_import(self.voc_name + '_inference', voc_inference_class = self.task_resource.get_model_class(self.voc_name +
self.model_alias) '_inference')
self.voc_inference = voc_inference_class(voc_normalizer, voc) self.voc_inference = voc_inference_class(voc_normalizer, voc)
self.voc_inference.eval() self.voc_inference.eval()
print("voc done!") print("voc done!")
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Static model applied on paddle inference
pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip',
'md5':
'f10cbdedf47dc7a9668d2264494e1823',
'model':
'speedyspeech_csmsc.pdmodel',
'params':
'speedyspeech_csmsc.pdiparams',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
'sample_rate':
24000,
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip',
'md5':
'9788cd9745e14c7a5d12d32670b2a5a7',
'model':
'fastspeech2_csmsc.pdmodel',
'params':
'fastspeech2_csmsc.pdiparams',
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
# pwgan
"pwgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip',
'md5':
'e3504aed9c5a290be12d1347836d2742',
'model':
'pwgan_csmsc.pdmodel',
'params':
'pwgan_csmsc.pdiparams',
'sample_rate':
24000,
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip',
'md5':
'ac6eee94ba483421d750433f4c3b8d36',
'model':
'mb_melgan_csmsc.pdmodel',
'params':
'mb_melgan_csmsc.pdiparams',
'sample_rate':
24000,
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip',
'md5':
'7edd8c436b3a5546b3a7cb8cff9d5a0c',
'model':
'hifigan_csmsc.pdmodel',
'params':
'hifigan_csmsc.pdiparams',
'sample_rate':
24000,
},
}
...@@ -24,9 +24,9 @@ import paddle ...@@ -24,9 +24,9 @@ import paddle
import soundfile as sf import soundfile as sf
from scipy.io import wavfile from scipy.io import wavfile
from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.errors import ErrorCode from paddlespeech.server.utils.errors import ErrorCode
...@@ -42,7 +42,8 @@ __all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] ...@@ -42,7 +42,8 @@ __all__ = ['TTSEngine', 'PaddleTTSConnectionHandler']
class TTSServerExecutor(TTSExecutor): class TTSServerExecutor(TTSExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.pretrained_models = pretrained_models self.task_resource = CommonTaskResource(
task='tts', model_format='static')
def _init_from_path( def _init_from_path(
self, self,
...@@ -68,19 +69,23 @@ class TTSServerExecutor(TTSExecutor): ...@@ -68,19 +69,23 @@ class TTSServerExecutor(TTSExecutor):
return return
# am # am
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
if am_model is None or am_params is None or phones_dict is None: if am_model is None or am_params is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag) self.am_res_path = self.task_resource.res_dir
self.am_res_path = am_res_path self.am_model = os.path.join(self.am_res_path,
self.am_model = os.path.join( self.task_resource.res_dict['model'])
am_res_path, self.pretrained_models[am_tag]['model']) self.am_params = os.path.join(self.am_res_path,
self.am_params = os.path.join( self.task_resource.res_dict['params'])
am_res_path, self.pretrained_models[am_tag]['params'])
# must have phones_dict in acoustic # must have phones_dict in acoustic
self.phones_dict = os.path.join( self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict']) self.am_res_path, self.task_resource.res_dict['phones_dict'])
self.am_sample_rate = self.pretrained_models[am_tag]['sample_rate'] self.am_sample_rate = self.task_resource.res_dict['sample_rate']
logger.info(am_res_path) logger.info(self.am_res_path)
logger.info(self.am_model) logger.info(self.am_model)
logger.info(self.am_params) logger.info(self.am_params)
else: else:
...@@ -93,32 +98,36 @@ class TTSServerExecutor(TTSExecutor): ...@@ -93,32 +98,36 @@ class TTSServerExecutor(TTSExecutor):
# for speedyspeech # for speedyspeech
self.tones_dict = None self.tones_dict = None
if 'tones_dict' in self.pretrained_models[am_tag]: if 'tones_dict' in self.task_resource.res_dict:
self.tones_dict = os.path.join( self.tones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['tones_dict']) self.am_res_path, self.task_resource.res_dict['tones_dict'])
if tones_dict: if tones_dict:
self.tones_dict = tones_dict self.tones_dict = tones_dict
# for multi speaker fastspeech2 # for multi speaker fastspeech2
self.speaker_dict = None self.speaker_dict = None
if 'speaker_dict' in self.pretrained_models[am_tag]: if 'speaker_dict' in self.task_resource.res_dict:
self.speaker_dict = os.path.join( self.speaker_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speaker_dict']) self.am_res_path, self.task_resource.res_dict['speaker_dict'])
if speaker_dict: if speaker_dict:
self.speaker_dict = speaker_dict self.speaker_dict = speaker_dict
# voc # voc
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
if voc_model is None or voc_params is None: if voc_model is None or voc_params is None:
voc_res_path = self._get_pretrained_path(voc_tag) self.voc_res_path = self.task_resource.voc_res_dir
self.voc_res_path = voc_res_path
self.voc_model = os.path.join( self.voc_model = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['model']) self.voc_res_path, self.task_resource.voc_res_dict['model'])
self.voc_params = os.path.join( self.voc_params = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['params']) self.voc_res_path, self.task_resource.voc_res_dict['params'])
self.voc_sample_rate = self.pretrained_models[voc_tag][ self.voc_sample_rate = self.task_resource.voc_res_dict[
'sample_rate'] 'sample_rate']
logger.info(voc_res_path) logger.info(self.voc_res_path)
logger.info(self.voc_model) logger.info(self.voc_model)
logger.info(self.voc_params) logger.info(self.voc_params)
else: else:
......
...@@ -243,8 +243,7 @@ def main(): ...@@ -243,8 +243,7 @@ def main():
# parse args and config and redirect to train_sp # parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="Train a HiFiGAN model.") parser = argparse.ArgumentParser(description="Train a HiFiGAN model.")
parser.add_argument( parser.add_argument("--config", type=str, help="HiFiGAN config file.")
"--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
......
...@@ -233,7 +233,7 @@ def main(): ...@@ -233,7 +233,7 @@ def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Train a Multi-Band MelGAN model.") description="Train a Multi-Band MelGAN model.")
parser.add_argument( parser.add_argument(
"--config", type=str, help="config file to overwrite default config.") "--config", type=str, help="Multi-Band MelGAN config file.")
parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
......
...@@ -208,7 +208,7 @@ def main(): ...@@ -208,7 +208,7 @@ def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Train a ParallelWaveGAN model.") description="Train a ParallelWaveGAN model.")
parser.add_argument( parser.add_argument(
"--config", type=str, help="config file to overwrite default config.") "--config", type=str, help="ParallelWaveGAN config file.")
parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
......
...@@ -224,8 +224,7 @@ def main(): ...@@ -224,8 +224,7 @@ def main():
# parse args and config and redirect to train_sp # parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="Train a Style MelGAN model.") parser = argparse.ArgumentParser(description="Train a Style MelGAN model.")
parser.add_argument( parser.add_argument("--config", type=str, help="Style MelGAN config file.")
"--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
......
...@@ -160,7 +160,7 @@ def main(): ...@@ -160,7 +160,7 @@ def main():
parser = argparse.ArgumentParser(description="Train a TransformerTTS " parser = argparse.ArgumentParser(description="Train a TransformerTTS "
"model with LJSpeech TTS dataset.") "model with LJSpeech TTS dataset.")
parser.add_argument( parser.add_argument(
"--config", type=str, help="config file to overwrite default config.") "--config", type=str, help="TransformerTTS config file.")
parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
......
...@@ -226,9 +226,8 @@ def train_sp(args, config): ...@@ -226,9 +226,8 @@ def train_sp(args, config):
def main(): def main():
# parse args and config and redirect to train_sp # parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="Train a HiFiGAN model.") parser = argparse.ArgumentParser(description="Train a VITS model.")
parser.add_argument( parser.add_argument("--config", type=str, help="VITS config file")
"--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
......
...@@ -180,8 +180,7 @@ def main(): ...@@ -180,8 +180,7 @@ def main():
# parse args and config and redirect to train_sp # parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="Train a WaveRNN model.") parser = argparse.ArgumentParser(description="Train a WaveRNN model.")
parser.add_argument( parser.add_argument("--config", type=str, help="WaveRNN config file.")
"--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
......
...@@ -44,13 +44,13 @@ More details please see `README.md` under `examples`. ...@@ -44,13 +44,13 @@ More details please see `README.md` under `examples`.
> If using docker please check `--privileged` is set when `docker run`. > If using docker please check `--privileged` is set when `docker run`.
* Fatal error at startup: `a function redirection which is mandatory for this platform-tool combination cannot be set up` * Fatal error at startup: `a function redirection which is mandatory for this platform-tool combination cannot be set up`
``` ```bash
apt-get install libc6-dbg apt-get install libc6-dbg
``` ```
* Install * Install
``` ```bash
pushd tools pushd tools
./setup_valgrind.sh ./setup_valgrind.sh
popd popd
...@@ -59,4 +59,4 @@ popd ...@@ -59,4 +59,4 @@ popd
## TODO ## TODO
### Deepspeech2 with linear feature ### Deepspeech2 with linear feature
* DecibelNormalizer: there is a little bit difference between offline and online db norm. The computation of online db norm read feature chunk by chunk, which causes the feature size is different with offline db norm. In normalizer.cc:73, the samples.size() is different, which causes the difference of result. * DecibelNormalizer: there is a small difference between the offline and online db norm. The online db norm reads features chunk by chunk, so the feature size differs from that of the offline db norm. In `normalizer.cc:73`, `samples.size()` differs, which leads to a different result.
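To make the discrepancy concrete, here is a toy numpy sketch (illustrative only, not the actual `normalizer.cc` logic): any dB statistic derived from the samples currently visible will differ between one offline pass over the whole waveform and a streaming pass over chunks.

```python
# Toy illustration: a dB statistic computed over all samples vs. per chunk.
import numpy as np

rng = np.random.default_rng(0)
wav = rng.normal(scale=0.1, size=16000).astype(np.float32)

def rms_db(samples):
    """RMS level in dB over whatever samples are visible."""
    return 20.0 * np.log10(np.sqrt(np.mean(samples ** 2)))

offline = rms_db(wav)                             # sees all 16000 samples
chunked = [rms_db(c) for c in np.split(wav, 10)]  # each sees only 1600 samples
print(offline, chunked[0])                        # close, but not identical
```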
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(ds2_ol)
add_subdirectory(dev)
\ No newline at end of file
...@@ -22,14 +22,7 @@ netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host ...@@ -22,14 +22,7 @@ netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host
## For Developer ## For Developer
> Warning: Only for developer, make sure you know what's it. > Reminder: Only for developers; make sure you know what it is.
* dev - for speechx developer, using for test. * codelab - for speechx developers, used for testing.
## Build WFST
> Warning: Using below example when you know what's it.
* text_lm - process text for build lm
* ngram - using to build NGram ARPA lm.
* wfst - build wfst for TLG.
# Codelab
## Introduction
> The below is for development and offline testing. Do not run it unless you know what it is.
* nnet
* feat
* decoder
# This contains the locations of the binaries built for running the examples. # This contains the locations of the binaries built for running the examples.
SPEECHX_ROOT=$PWD/../../../ SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project is built successfully"; }
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/feat SPEECHX_BIN=$SPEECHX_ROOT/build/speechx/decoder:$SPEECHX_ROOT/build/speechx/frontend/audio
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
...@@ -54,7 +54,7 @@ cmvn=$exp_dir/cmvn.ark ...@@ -54,7 +54,7 @@ cmvn=$exp_dir/cmvn.ark
export GLOG_logtostderr=1 export GLOG_logtostderr=1
# dump json cmvn to kaldi # dump json cmvn to kaldi
cmvn-json2kaldi \ cmvn_json2kaldi_main \
--json_file $ckpt_dir/data/mean_std.json \ --json_file $ckpt_dir/data/mean_std.json \
--cmvn_write_path $cmvn \ --cmvn_write_path $cmvn \
--binary=false --binary=false
...@@ -62,17 +62,17 @@ echo "convert json cmvn to kaldi ark." ...@@ -62,17 +62,17 @@ echo "convert json cmvn to kaldi ark."
# generate linear feature as streaming # generate linear feature as streaming
linear-spectrogram-wo-db-norm-ol \ compute_linear_spectrogram_main \
--wav_rspecifier=scp:$data/wav.scp \ --wav_rspecifier=scp:$data/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \ --feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_file=$cmvn --cmvn_file=$cmvn
echo "compute linear spectrogram feature." echo "compute linear spectrogram feature."
# run ctc beam search decoder as streaming # run ctc beam search decoder as streaming
ctc-prefix-beam-search-decoder-ol \ ctc_prefix_beam_search_decoder_main \
--result_wspecifier=ark,t:$exp_dir/result.txt \ --result_wspecifier=ark,t:$exp_dir/result.txt \
--feature_rspecifier=ark:$feat_wspecifier \ --feature_rspecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
--dict_file=$vocb_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm --lm_path=$lm
\ No newline at end of file
...@@ -2,6 +2,6 @@ ...@@ -2,6 +2,6 @@
ASR audio feature test bins. We use these bins to test linear/fbank/mfcc ASR features in a streaming manner. ASR audio feature test bins. We use these bins to test linear/fbank/mfcc ASR features in a streaming manner.
* linear_spectrogram_without_db_norm_main.cc * compute_linear_spectrogram_main.cc
compute linear spectrogram w/o db norm in streaming manner. computes the linear spectrogram without db norm in a streaming manner.
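As a rough picture of what "streaming manner" means here, the sketch below produces one magnitude-spectrum frame per hop as audio arrives, instead of transforming the whole utterance at once; the window and hop sizes are assumptions, not the binary's actual parameters.

```python
# Illustrative only: a linear (magnitude) spectrogram computed frame by frame.
import numpy as np

def linear_spectrogram_stream(samples, win=400, hop=160):
    """Yield one magnitude-spectrum frame per hop, as a streaming front-end would."""
    for start in range(0, len(samples) - win + 1, hop):
        frame = samples[start:start + win] * np.hanning(win)
        yield np.abs(np.fft.rfft(frame))

wav = np.zeros(16000, dtype=np.float32)  # placeholder audio
feats = list(linear_spectrogram_stream(wav))
print(len(feats), feats[0].shape)  # 98 frames, 201 bins each
```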
...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat SPEECHX_BIN=$SPEECHX_ROOT/build/speechx/decoder:$SPEECHX_ROOT/build/speechx/frontend/audio
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
...@@ -41,14 +41,14 @@ mkdir -p $exp_dir ...@@ -41,14 +41,14 @@ mkdir -p $exp_dir
# 3. run feat # 3. run feat
export GLOG_logtostderr=1 export GLOG_logtostderr=1
cmvn-json2kaldi \ cmvn_json2kaldi_main \
--json_file $model_dir/data/mean_std.json \ --json_file $model_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \ --cmvn_write_path $exp_dir/cmvn.ark \
--binary=false --binary=false
echo "convert json cmvn to kaldi ark." echo "convert json cmvn to kaldi ark."
linear-spectrogram-wo-db-norm-ol \ compute_linear_spectrogram_main \
--wav_rspecifier=scp:$data_dir/wav.scp \ --wav_rspecifier=scp:$data_dir/wav.scp \
--feature_wspecifier=ark,t:$exp_dir/feats.ark \ --feature_wspecifier=ark,t:$exp_dir/feats.ark \
--cmvn_file=$exp_dir/cmvn.ark --cmvn_file=$exp_dir/cmvn.ark
......
...@@ -17,7 +17,7 @@ feat_wspecifier=./feats.ark ...@@ -17,7 +17,7 @@ feat_wspecifier=./feats.ark
cmvn=./cmvn.ark cmvn=./cmvn.ark
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \ valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
linear_spectrogram_main \ compute_linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \ --wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \ --feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn --cmvn_write_path=$cmvn
......
# This contains the locations of the binaries built for running the examples. # This contains the locations of the binaries built for running the examples.
SPEECHX_ROOT=$PWD/../../../ SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project is built successfully"; }
SPEECHX_BIN=$SPEECHX_EXAMPLES/dev/glog
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_BUILD/codelab/nnet
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
...@@ -20,19 +20,10 @@ if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; ...@@ -20,19 +20,10 @@ if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ];
popd popd
fi fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
ckpt_dir=./data/model ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
ds2-model-ol-test \ ds2_model_test_main \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams --param_path=$model_dir/avg_1.jit.pdiparams
...@@ -12,9 +12,10 @@ if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then ...@@ -12,9 +12,10 @@ if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
exit 1 exit 1
fi fi
model_dir=../paddle_asr_model ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \ valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
pp-model-test \ ds2_model_test_main \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams --param_path=$model_dir/avg_1.jit.pdparams
\ No newline at end of file
...@@ -7,7 +7,7 @@ export GLOG_logtostderr=1 ...@@ -7,7 +7,7 @@ export GLOG_logtostderr=1
. ./path.sh || exit 1; . ./path.sh || exit 1;
# ds2 means deepspeech2 (acoustic model type) # ds2 means deepspeech2 (acoustic model type)
dir=$PWD/ds2_graph_with_slot dir=$PWD/exp/ds2_graph_with_slot
data=$PWD/data data=$PWD/data
stage=0 stage=0
stop_stage=10 stop_stage=10
...@@ -80,9 +80,9 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then ...@@ -80,9 +80,9 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--word_symbol_table=$graph/words.txt \ --word_symbol_table=$graph/words.txt \
--graph_path=$graph/TLG.fst --max_active=7500 \ --graph_path=$graph/TLG.fst --max_active=7500 \
--acoustic_scale=12 \ --acoustic_scale=12 \
--result_wspecifier=ark,t:./result_run.txt --result_wspecifier=ark,t:./exp/result_run.txt
# data/wav.trans is the label. # data/wav.trans is the label.
utils/compute-wer.py --char=1 --v=1 data/wav.trans result_run.txt > wer_run utils/compute-wer.py --char=1 --v=1 data/wav.trans exp/result_run.txt > exp/wer_run
tail -n 7 wer_run tail -n 7 exp/wer_run
fi fi
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(glog_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_test.cc)
target_link_libraries(glog_test glog)
add_executable(glog_logtostderr_test ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_test.cc)
target_link_libraries(glog_logtostderr_test glog)
\ No newline at end of file
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. run
glog_test
echo "------"
export FLAGS_logtostderr=1
glog_test
echo "------"
glog_logtostderr_test
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
add_subdirectory(websocket)
...@@ -42,3 +42,40 @@ Overall -> 10.93 % N=104765 C=93410 S=9780 D=1575 I=95 ...@@ -42,3 +42,40 @@ Overall -> 10.93 % N=104765 C=93410 S=9780 D=1575 I=95
Mandarin -> 10.93 % N=104762 C=93410 S=9779 D=1573 I=95 Mandarin -> 10.93 % N=104762 C=93410 S=9779 D=1573 I=95
Other -> 100.00 % N=3 C=0 S=1 D=2 I=0 Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
``` ```
## fbank
```
bash run_fbank.sh
```
### CTC Prefix Beam Search w/o LM
```
Overall -> 10.44 % N=104765 C=94194 S=10174 D=397 I=369
Mandarin -> 10.44 % N=104762 C=94194 S=10171 D=397 I=369
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
```
### CTC Prefix Beam Search w/ LM
LM: zh_giga.no_cna_cmn.prune01244.klm
```
Overall -> 5.82 % N=104765 C=99386 S=4944 D=435 I=720
Mandarin -> 5.82 % N=104762 C=99386 S=4941 D=435 I=720
English -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
### CTC WFST
LM: [aishell train](https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip)
```
Overall -> 9.58 % N=104765 C=94817 S=4326 D=5622 I=84
Mandarin -> 9.57 % N=104762 C=94817 S=4325 D=5620 I=84
Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
```
## build TLG graph
```
bash run_build_tlg.sh
```
# This contains the locations of the binaries required for running the examples. # This contains the locations of the binaries required for running the examples.
SPEECHX_ROOT=$PWD/../../.. MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project was built successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project was built successfully"; }
export LC_ALL=C export LC_ALL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket # openfst bin & kaldi bin
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/
OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src
# srilm
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
export SRILM=${MAIN_ROOT}/tools/srilm
SPEECHX_BIN=$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/websocket
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin
...@@ -69,12 +69,12 @@ export GLOG_logtostderr=1 ...@@ -69,12 +69,12 @@ export GLOG_logtostderr=1
cmvn=$data/cmvn.ark cmvn=$data/cmvn.ark
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# 3. gen linear feat # 3. gen linear feat
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
linear-spectrogram-wo-db-norm-ol \ compute_linear_spectrogram_main \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \ --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
--cmvn_file=$cmvn \ --cmvn_file=$cmvn \
...@@ -85,7 +85,7 @@ fi ...@@ -85,7 +85,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# recognizer # recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
ctc-prefix-beam-search-decoder-ol \ ctc_prefix_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
...@@ -102,7 +102,7 @@ fi ...@@ -102,7 +102,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with lm # decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
ctc-prefix-beam-search-decoder-ol \ ctc_prefix_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
...@@ -132,7 +132,7 @@ fi ...@@ -132,7 +132,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder # TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
wfst-decoder-ol \ tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
...@@ -151,7 +151,7 @@ fi ...@@ -151,7 +151,7 @@ fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# TLG decoder # TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
recognizer_test_main \ recognizer_main \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--cmvn_file=$cmvn \ --cmvn_file=$cmvn \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
......
...@@ -3,11 +3,15 @@ set -eo pipefail ...@@ -3,11 +3,15 @@ set -eo pipefail
. path.sh . path.sh
# Attention: the vocab below is only for this script;
# different acoustic models have different vocabs.
ckpt_dir=data/fbank_model
unit=$ckpt_dir/data/lang_char/vocab.txt # vocab file, line: char/spm_piece
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
stage=-1 stage=-1
stop_stage=100 stop_stage=100
corpus=aishell corpus=aishell
unit=data/vocab.txt # vocab file, line: char/spm_piece
lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt
text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
...@@ -23,6 +27,14 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then ...@@ -23,6 +27,14 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
tar xvzf speech.ngram.zh.tar.gz tar xvzf speech.ngram.zh.tar.gz
popd popd
fi fi
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
mkdir -p $ckpt_dir
pushd $ckpt_dir
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
popd
fi
fi fi
if [ ! -f $unit ]; then if [ ! -f $unit ]; then
...@@ -38,12 +50,12 @@ fi ...@@ -38,12 +50,12 @@ fi
mkdir -p data/local/dict mkdir -p data/local/dict
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# 7.1 Prepare dict # Prepare dict
# line: char/spm_pieces # line: char/spm_pieces
cp $unit data/local/dict/units.txt cp $unit data/local/dict/units.txt
if [ ! -f $lexicon ];then if [ ! -f $lexicon ];then
local/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon utils/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon
echo "Generate $lexicon from $text" echo "Generate $lexicon from $text"
fi fi
...@@ -59,10 +71,71 @@ lm=data/local/lm ...@@ -59,10 +71,71 @@ lm=data/local/lm
mkdir -p $lm mkdir -p $lm
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# 7.2 Train lm # Train lm
cp $text $lm/text cp $text $lm/text
local/aishell_train_lms.sh local/aishell_train_lms.sh
echo "build LM done."
fi
# build TLG
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# build T & L
utils/fst/compile_lexicon_token_fst.sh \
data/local/dict data/local/tmp data/local/lang
# build G & TLG
utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
fi
aishell_wav_scp=aishell_test.scp
nj=40
cmvn=$data/cmvn_fbank.ark
wfst=$data/lang_test
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
fi
wer=aishell_wer
label_file=aishell_result
export GLOG_logtostderr=1
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/check_tlg.log \
recognizer_main \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_5.jit.pdmodel \
--streaming_chunk=30 \
--use_fbank=true \
--param_path=$model_dir/avg_5.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--model_cache_shapes="5-1-2048,5-1-2048" \
--graph_path=$wfst/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_check_tlg
cat $data/split${nj}/*/result_check_tlg > $exp/${label_file}_check_tlg
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_check_tlg > $exp/${wer}.check_tlg
echo "recognizer test have finished!!!"
echo "please checkout in ${exp}/${wer}.check_tlg"
fi fi
echo "build LM done."
exit 0 exit 0
...@@ -69,7 +69,7 @@ export GLOG_logtostderr=1 ...@@ -69,7 +69,7 @@ export GLOG_logtostderr=1
cmvn=$data/cmvn_fbank.ark cmvn=$data/cmvn_fbank.ark
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# 3. gen linear feat # 3. gen linear feat
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn --binary=false cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn --binary=false
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
...@@ -84,7 +84,7 @@ fi ...@@ -84,7 +84,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# recognizer # recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \
ctc-prefix-beam-search-decoder-ol \ ctc_prefix_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \
...@@ -100,12 +100,12 @@ fi ...@@ -100,12 +100,12 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with lm # decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \
ctc-prefix-beam-search-decoder-ol \ ctc_prefix_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--model_cache_shapes="5-1-2048,5-1-2048" \ --model_cache_shapes="5-1-2048,5-1-2048" \
--dict_file=$vocb_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \ --lm_path=$lm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/fbank_result_lm --result_wspecifier=ark,t:$data/split${nj}/JOB/fbank_result_lm
...@@ -129,13 +129,13 @@ fi ...@@ -129,13 +129,13 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder # TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \
wfst-decoder-ol \ tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \ --word_symbol_table=$wfst/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--model_cache_shapes="5-1-2048,5-1-2048" \ --model_cache_shapes="5-1-2048,5-1-2048" \
--graph_path=$wfst/TLG.fst --max_active=7500 \ --graph_path=$wfst/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \ --acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
...@@ -148,13 +148,12 @@ fi ...@@ -148,13 +148,12 @@ fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/fbank_recognizer.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/fbank_recognizer.log \
recognizer_test_main \ recognizer_main \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--cmvn_file=$cmvn \ --cmvn_file=$cmvn \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--streaming_chunk=30 \ --streaming_chunk=30 \
--use_fbank=true \ --use_fbank=true \
--to_float32=false \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \ --word_symbol_table=$wfst/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
......
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ctc-prefix-beam-search-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
set(bin_name wfst-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
set(bin_name nnet-logprob-decoder-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc)
target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name linear-spectrogram-wo-db-norm-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
set(bin_name compute_fbank_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
set(bin_name cmvn-json2kaldi)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
# This contains the locations of the binaries required for running the examples.
SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project was built successfully"; }
export LC_ALL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/nnet
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
# This contains the locations of the binaries required for running the examples. # This contains the locations of the binaries required for running the examples.
SPEECHX_ROOT=$PWD/../../.. SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project was built successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project was built successfully"; }
export LC_ALL=C export LC_ALL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/websocket:$SPEECHX_EXAMPLES/ds2_ol/feat SPEECHX_BIN=$SPEECHX_BUILD/protocol/websocket
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
...@@ -45,7 +45,7 @@ export GLOG_logtostderr=1 ...@@ -45,7 +45,7 @@ export GLOG_logtostderr=1
# 3. gen cmvn # 3. gen cmvn
cmvn=$data/cmvn.ark cmvn=$data/cmvn.ark
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
wfst=$data/wfst/ wfst=$data/wfst/
......
# ngram train for mandarin
Quick run:
```
bash run.sh --stage -1
```
## input
input files:
```
data/
├── lexicon.txt
├── text
└── vocab.txt
```
```
==> data/text <==
BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限 购
BAC009S0002W0123 也 成为 地方 政府 的 眼中 钉
BAC009S0002W0124 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后
BAC009S0002W0125 各地 政府 便 纷纷 跟进
BAC009S0002W0126 仅 一 个 多 月 的 时间 里
BAC009S0002W0127 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外
BAC009S0002W0128 四十六 个 限 购 城市 当中
BAC009S0002W0129 四十一 个 已 正式 取消 或 变相 放松 了 限 购
BAC009S0002W0130 财政 金融 政策 紧随 其后 而来
BAC009S0002W0131 显示 出 了 极 强 的 威力
==> data/lexicon.txt <==
SIL sil
<SPOKEN_NOISE> sil
啊 aa a1
啊 aa a2
啊 aa a4
啊 aa a5
啊啊啊 aa a2 aa a2 aa a2
啊啊啊 aa a5 aa a5 aa a5
坐地 z uo4 d i4
坐实 z uo4 sh ix2
坐视 z uo4 sh ix4
坐稳 z uo4 uu un3
坐拥 z uo4 ii iong1
坐诊 z uo4 zh en3
坐庄 z uo4 zh uang1
坐姿 z uo4 z iy1
==> data/vocab.txt <==
<blank>
<unk>
A
B
C
D
E
<eos>
```
## output
```
data/
├── local
│ ├── dict
│ │ ├── lexicon.txt
│ │ └── units.txt
│ └── lm
│ ├── heldout
│ ├── lm.arpa
│ ├── text
│ ├── text.no_oov
│ ├── train
│ ├── unigram.counts
│ ├── word.counts
│ └── wordlist
```
```
/workspace/srilm/bin/i686-m64/ngram-count
Namespace(bpemodel=None, in_lexicon='data/lexicon.txt', out_lexicon='data/local/dict/lexicon.txt', unit_file='data/vocab.txt')
Ignoring words 矽, which contains oov unit
Ignoring words 傩, which contains oov unit
Ignoring words 堀, which contains oov unit
Ignoring words 莼, which contains oov unit
Ignoring words 菰, which contains oov unit
Ignoring words 摭, which contains oov unit
Ignoring words 帙, which contains oov unit
Ignoring words 迨, which contains oov unit
Ignoring words 孥, which contains oov unit
Ignoring words 瑗, which contains oov unit
...
...
...
file data/local/lm/heldout: 10000 sentences, 89496 words, 0 OOVs
0 zeroprobs, logprob= -270337.9 ppl= 521.2819 ppl1= 1048.745
build LM done.
```
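The `word.counts` and `unigram.counts` files shown above are plain frequency tables over the training transcript after the utterance id is stripped. A rough sketch of that counting step, independent of the actual SRILM-based pipeline in `local/aishell_train_lms.sh`:

```python
from collections import Counter

counts = Counter()
with open("data/text", encoding="utf-8") as f:
    for line in f:
        words = line.split()[1:]  # drop the utterance id, e.g. BAC009S0002W0122
        counts.update(words)

with open("word.counts", "w", encoding="utf-8") as f:
    for word, n in counts.most_common():
        f.write(f"{n} {word}\n")  # matches the "57486 的" layout above
```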
#!/usr/bin/env bash
set -eo pipefail
data=$1
scp=$2
split_name=$3
numsplit=$4
# split $scp into $numsplit parts, saved under $data/split{N}/
if [[ ! $numsplit -gt 0 ]]; then
echo "Invalid num-split argument";
exit 1;
fi
directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done)
# if this mkdir fails due to argument-list being too long, iterate.
if ! mkdir -p $directories >&/dev/null; then
for n in `seq $numsplit`; do
mkdir -p $data/split${numsplit}/$n
done
fi
echo "utils/split_scp.pl $scp $scp_splits"
utils/split_scp.pl $scp $scp_splits
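`utils/split_scp.pl` is what actually deals the scp lines out across the N job directories. A simplified Python equivalent that splits an scp into roughly equal contiguous blocks (the perl script has extra options, e.g. keeping all utterances of one speaker in the same split, which this sketch ignores):

```python
import sys

def split_scp(scp_path, out_paths):
    """Split scp_path into len(out_paths) roughly equal contiguous parts."""
    with open(scp_path) as f:
        lines = f.readlines()
    n_out = len(out_paths)
    base, rem = divmod(len(lines), n_out)
    start = 0
    for i, out_path in enumerate(out_paths):
        size = base + (1 if i < rem else 0)  # spread the remainder over the first splits
        with open(out_path, "w") as g:
            g.writelines(lines[start:start + size])
        start += size

if __name__ == "__main__":
    # usage: split_scp.py wav.scp split/1/wav.scp split/2/wav.scp ...
    split_scp(sys.argv[1], sys.argv[2:])
```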
# This contains the locations of the binaries required for running the examples.
MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx`
export LC_ALL=C
# srilm
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
export SRILM=${MAIN_ROOT}/tools/srilm
export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64
../../../../utils/
\ No newline at end of file
# Build TLG wfst
## Input
```
data/local/
├── dict
│ ├── lexicon.txt
│ └── units.txt
└── lm
├── heldout
├── lm.arpa
├── text
├── text.no_oov
├── train
├── unigram.counts
├── word.counts
└── wordlist
```
```
==> data/local/dict/lexicon.txt <==
啊 啊
啊啊啊 啊 啊 啊
阿 阿
阿尔 阿 尔
阿根廷 阿 根 廷
阿九 阿 九
阿克 阿 克
阿拉伯数字 阿 拉 伯 数 字
阿拉法特 阿 拉 法 特
阿拉木图 阿 拉 木 图
==> data/local/dict/units.txt <==
<blank>
<unk>
A
B
C
D
E
F
G
H
==> data/local/lm/heldout <==
而 对 楼市 成交 抑制 作用 最 大 的 限 购
也 成为 地方 政府 的 眼中 钉
自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后
各地 政府 便 纷纷 跟进
仅 一 个 多 月 的 时间 里
除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外
四十六 个 限 购 城市 当中
四十一 个 已 正式 取消 或 变相 放松 了 限 购
财政 金融 政策 紧随 其后 而来
显示 出 了 极 强 的 威力
==> data/local/lm/lm.arpa <==
\data\
ngram 1=129356
ngram 2=504661
ngram 3=123455
\1-grams:
-1.531278 </s>
-3.828829 <SPOKEN_NOISE> -0.1600094
-6.157292 <UNK>
==> data/local/lm/text <==
BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限 购
BAC009S0002W0123 也 成为 地方 政府 的 眼中 钉
BAC009S0002W0124 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后
BAC009S0002W0125 各地 政府 便 纷纷 跟进
BAC009S0002W0126 仅 一 个 多 月 的 时间 里
BAC009S0002W0127 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外
BAC009S0002W0128 四十六 个 限 购 城市 当中
BAC009S0002W0129 四十一 个 已 正式 取消 或 变相 放松 了 限 购
BAC009S0002W0130 财政 金融 政策 紧随 其后 而来
BAC009S0002W0131 显示 出 了 极 强 的 威力
==> data/local/lm/text.no_oov <==
<SPOKEN_NOISE> 而 对 楼市 成交 抑制 作用 最 大 的 限 购
<SPOKEN_NOISE> 也 成为 地方 政府 的 眼中 钉
<SPOKEN_NOISE> 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后
<SPOKEN_NOISE> 各地 政府 便 纷纷 跟进
<SPOKEN_NOISE> 仅 一 个 多 月 的 时间 里
<SPOKEN_NOISE> 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外
<SPOKEN_NOISE> 四十六 个 限 购 城市 当中
<SPOKEN_NOISE> 四十一 个 已 正式 取消 或 变相 放松 了 限 购
<SPOKEN_NOISE> 财政 ���融 政策 紧随 其后 而来
<SPOKEN_NOISE> 显示 出 了 极 强 的 威力
==> data/local/lm/train <==
汉莎 不 得 不 通过 这样 的 方式 寻求 新 的 发展 点
并 计划 朝云 计算 方面 发展
汉莎 的 基础 设施 部门 拥有 一千四百 名 员工
媒体 就 曾 披露 这笔 交易
虽然 双方 已经 正式 签署 了 外包 协议
但是 这笔 交易 还 需要 得到 反 垄断 部门 的 批准
陈 黎明 一九八九 年 获得 美国 康乃尔 大学 硕士 学位
并 于 二零零三 年 顺利 完成 美国 哈佛 商学 院 高级 管理 课程
曾 在 多家 国际 公司 任职
拥有 业务 开发 商务 及 企业 治理
==> data/local/lm/unigram.counts <==
57487 的
13099 在
11862 一
11397 了
10998 不
9913 是
7952 有
6250 和
6152 个
5422 将
==> data/local/lm/word.counts <==
57486 的
13098 在
11861 一
11396 了
10997 不
9912 是
7951 有
6249 和
6151 个
5421 将
==> data/local/lm/wordlist <==
```
## Output
```
fstaddselfloops 'echo 4234 |' 'echo 123660 |'
Lexicon and Token FSTs compiling succeeded
arpa2fst --read-symbol-table=data/lang_test/words.txt --keep-symbols=true -
LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:94) Reading \data\ section.
LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:149) Reading \1-grams: section.
LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:149) Reading \2-grams: section.
LOG (arpa2fst[5.5.0~1-5a37]:Read():arpa-file-parser.cc:149) Reading \3-grams: section.
Checking how stochastic G is (the first of these numbers should be small):
fstisstochastic data/lang_test/G.fst
0 -1.14386
fsttablecompose data/lang_test/L.fst data/lang_test/G.fst
fstminimizeencoded
fstdeterminizestar --use-log=true
fsttablecompose data/lang_test/T.fst data/lang_test/LG.fst
Composing decoding graph TLG.fst succeeded
Aishell build TLG done.
```
```
data/
├── lang_test
│ ├── G.fst
│ ├── L.fst
│ ├── LG.fst
│ ├── T.fst
│ ├── TLG.fst
│ ├── tokens.txt
│ ├── units.txt
│ └── words.txt
└── local
├── lang
│ ├── L.fst
│ ├── T.fst
│ ├── tokens.txt
│ ├── units.txt
│ └── words.txt
└── tmp
├── disambig.list
├── lexiconp_disambig.txt
├── lexiconp.txt
└── units.list
```
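The log above corresponds to the composition pipeline TLG = T ∘ min(det(L ∘ G)). If OpenFst's Python extension `pywrapfst` is available, the same composition can be sketched as below; treat it purely as an illustration, since the real scripts also add self-loops, handle disambiguation symbols, and determinize in the log semiring (`fstdeterminizestar --use-log=true`):

```python
import pywrapfst as fst

T = fst.Fst.read("data/lang_test/T.fst")
L = fst.Fst.read("data/lang_test/L.fst")
G = fst.Fst.read("data/lang_test/G.fst")

# Composition requires compatible arc sorting on the shared tape.
L.arcsort(sort_type="olabel")
G.arcsort(sort_type="ilabel")
LG = fst.determinize(fst.compose(L, G))
LG.minimize()

LG.arcsort(sort_type="ilabel")
T.arcsort(sort_type="olabel")
TLG = fst.compose(T, LG)
TLG.write("data/lang_test/TLG.fst")
```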
# This contains the locations of the binaries required for running the examples.
MAIN_ROOT=`realpath $PWD/../../../`
SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx`
export LC_ALL=C
# srilm
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
export SRILM=${MAIN_ROOT}/tools/srilm
export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64
# Kaldi
export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi
[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present; cannot use Kaldi!"
[ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh
#!/bin/bash
set -eo pipefail
. path.sh
stage=-1
stop_stage=100
. utils/parse_options.sh
if ! which fstprint ; then
pushd $MAIN_ROOT/tools
make kaldi.done
popd
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# build T & L
# utils/fst/compile_lexicon_token_fst.sh <dict-src-dir> <tmp-dir> <lang-dir>
utils/fst/compile_lexicon_token_fst.sh \
data/local/dict data/local/tmp data/local/lang
# build G & LG & TLG
# utils/fst/make_tlg.sh <lm_dir> <src_lang> <tgt_lang>
utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
fi
echo "build TLG done."
exit 0
../../../utils/
\ No newline at end of file
...@@ -34,6 +34,12 @@ add_subdirectory(decoder) ...@@ -34,6 +34,12 @@ add_subdirectory(decoder)
include_directories( include_directories(
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/websocket ${CMAKE_CURRENT_SOURCE_DIR}/protocol
) )
add_subdirectory(websocket) add_subdirectory(protocol)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/codelab
)
add_subdirectory(codelab)
cmake_minimum_required(VERSION 3.14 FATAL_ERROR) cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(glog) add_subdirectory(glog)
add_subdirectory(nnet)
## For Developers
> Reminder: for developers only.
* codelab - for speechx developers, used for testing.
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc)
target_link_libraries(glog_main glog)
add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc)
target_link_libraries(glog_logtostderr_main glog)
...@@ -23,3 +23,16 @@ You can also modify flag values in your program by modifying global variables `F ...@@ -23,3 +23,16 @@ You can also modify flag values in your program by modifying global variables `F
FLAGS_log_dir = "/some/log/directory"; FLAGS_log_dir = "/some/log/directory";
LOG(INFO) << "the same file"; LOG(INFO) << "the same file";
``` ```
* This is the test script (binary names per the CMakeLists above):
```
# run
glog_main
echo "------"
export FLAGS_logtostderr=1
glog_main
echo "------"
glog_logtostderr_main
```
cmake_minimum_required(VERSION 3.14 FATAL_ERROR) cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ds2-model-ol-test) set(bin_name ds2_model_test_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS}) target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
\ No newline at end of file
...@@ -10,3 +10,16 @@ add_library(decoder STATIC ...@@ -10,3 +10,16 @@ add_library(decoder STATIC
recognizer.cc recognizer.cc
) )
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder)
set(BINS
ctc_prefix_beam_search_decoder_main
nnet_logprob_decoder_main
recognizer_main
tlg_decoder_main
)
foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
endforeach()
...@@ -47,6 +47,26 @@ void TLGDecoder::Reset() { ...@@ -47,6 +47,26 @@ void TLGDecoder::Reset() {
return; return;
} }
std::string TLGDecoder::GetPartialResult() {
if (frame_decoded_size_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");
}
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
std::vector<int> words_id;
decoder_->GetBestPath(&lat, false);
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
std::string words;
for (int32 idx = 0; idx < words_id.size(); ++idx) {
std::string word = word_symbol_table_->Find(words_id[idx]);
words += word;
}
return words;
}
std::string TLGDecoder::GetFinalBestPath() { std::string TLGDecoder::GetFinalBestPath() {
if (frame_decoded_size_ == 0) { if (frame_decoded_size_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
......
...@@ -38,6 +38,7 @@ class TLGDecoder { ...@@ -38,6 +38,7 @@ class TLGDecoder {
std::string GetBestPath(); std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(); std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath(); std::string GetFinalBestPath();
std::string GetPartialResult();
int NumFrameDecoded(); int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words); std::vector<std::string>& nbest_words);
......
...@@ -81,8 +81,8 @@ FeaturePipelineOptions InitFeaturePipelineOptions() { ...@@ -81,8 +81,8 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
frame_opts.preemph_coeff = 0.0; frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts; opts.linear_spectrogram_opts.frame_opts = frame_opts;
} }
opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length; opts.assembler_opts.frame_chunk_size = FLAGS_receptive_field_length;
opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate; opts.assembler_opts.frame_chunk_stride = FLAGS_downsampling_rate;
return opts; return opts;
} }
...@@ -115,4 +115,4 @@ RecognizerResource InitRecognizerResoure() { ...@@ -115,4 +115,4 @@ RecognizerResource InitRecognizerResoure() {
resource.tlg_opts = InitDecoderOptions(); resource.tlg_opts = InitDecoderOptions();
return resource; return resource;
} }
} }
\ No newline at end of file
...@@ -44,6 +44,10 @@ std::string Recognizer::GetFinalResult() { ...@@ -44,6 +44,10 @@ std::string Recognizer::GetFinalResult() {
return decoder_->GetFinalBestPath(); return decoder_->GetFinalBestPath();
} }
std::string Recognizer::GetPartialResult() {
return decoder_->GetPartialResult();
}
void Recognizer::SetFinished() { void Recognizer::SetFinished() {
feature_pipeline_->SetFinished(); feature_pipeline_->SetFinished();
input_finished_ = true; input_finished_ = true;
......
...@@ -43,6 +43,7 @@ class Recognizer { ...@@ -43,6 +43,7 @@ class Recognizer {
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves); void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
void Decode(); void Decode();
std::string GetFinalResult(); std::string GetFinalResult();
std::string GetPartialResult();
void SetFinished(); void SetFinished();
bool IsFinished(); bool IsFinished();
void Reset(); void Reset();
......
...@@ -8,6 +8,24 @@ add_library(frontend STATIC ...@@ -8,6 +8,24 @@ add_library(frontend STATIC
feature_cache.cc feature_cache.cc
feature_pipeline.cc feature_pipeline.cc
fbank.cc fbank.cc
assembler.cc
) )
target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common kaldi-fbank) target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common kaldi-fbank)
set(bin_name cmvn_json2kaldi_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
set(BINS
compute_linear_spectrogram_main
compute_fbank_main
)
foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC frontend utils kaldi-util gflags glog)
endforeach()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/assembler.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::unique_ptr;
Assembler::Assembler(AssemblerOptions opts,
unique_ptr<FrontendInterface> base_extractor) {
frame_chunk_stride_ = opts.frame_chunk_stride;
frame_chunk_size_ = opts.frame_chunk_size;
base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim();
}
void Assembler::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
// read inputs
base_extractor_->Accept(inputs);
}
// pop feature chunk
bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
feats->Resize(dim_ * frame_chunk_size_);
bool result = Compute(feats);
return result;
}
// read all data from base_feature_extractor_ into cache_
bool Assembler::Compute(Vector<BaseFloat>* feats) {
// compute and feed
bool result = false;
while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature;
result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
feature_cache_.push(feature);
}
int32 counter = 0;
int32 cache_size = frame_chunk_size_ - frame_chunk_stride_;
int32 elem_dim = base_extractor_->Dim();
while (counter < frame_chunk_size_) {
Vector<BaseFloat>& val = feature_cache_.front();
int32 start = counter * elem_dim;
feats->Range(start, elem_dim).CopyFromVec(val);
if (frame_chunk_size_ - counter <= cache_size ) {
feature_cache_.push(val);
}
feature_cache_.pop();
counter++;
}
return result;
}
} // namespace ppspeech
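For intuition, `Assembler::Compute` above pops `frame_chunk_size_` frames per read and pushes the trailing `frame_chunk_size_ - frame_chunk_stride_` frames back onto the queue, so consecutive chunks overlap. The same windowing in a few lines of Python:

```python
import numpy as np

def assemble_chunks(frames, chunk_size, chunk_stride):
    """Yield overlapping [chunk_size, dim] windows, advancing by chunk_stride."""
    frames = np.asarray(frames)
    start = 0
    while start + chunk_size <= len(frames):
        yield frames[start:start + chunk_size]
        start += chunk_stride

feats = np.arange(12).reshape(12, 1)  # 12 frames with feature dim 1
for chunk in assemble_chunks(feats, chunk_size=7, chunk_stride=4):
    print(chunk.ravel())  # frames 0..6, then frames 4..10 (3-frame overlap)
```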
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/audio/frontend_itf.h"
namespace ppspeech {
struct AssemblerOptions {
int32 frame_chunk_size;
int32 frame_chunk_stride;
AssemblerOptions()
: frame_chunk_size(1),
frame_chunk_stride(1) {}
};
class Assembler : public FrontendInterface {
public:
explicit Assembler(
AssemblerOptions opts,
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// feats size = num_frames * feat_dim
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// feat dim
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() {
base_extractor_->SetFinished();
}
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
}
private:
bool Compute(kaldi::Vector<kaldi::BaseFloat>* feats);
int32 dim_;
int32 frame_chunk_size_; // window
int32 frame_chunk_stride_; // stride
std::queue<kaldi::Vector<kaldi::BaseFloat>> feature_cache_;
std::unique_ptr<FrontendInterface> base_extractor_;
DISALLOW_COPY_AND_ASSIGN(Assembler);
};
} // namespace ppspeech
...@@ -64,10 +64,6 @@ int main(int argc, char* argv[]) { ...@@ -64,10 +64,6 @@ int main(int argc, char* argv[]) {
ppspeech::FeatureCacheOptions feat_cache_opts; ppspeech::FeatureCacheOptions feat_cache_opts;
// the feature cache output feature chunk by chunk. // the feature cache output feature chunk by chunk.
// frame_chunk_size : num frame of a chunk.
// frame_chunk_stride: chunk sliding window stride.
feat_cache_opts.frame_chunk_stride = 1;
feat_cache_opts.frame_chunk_size = 1;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "fbank: " << true; LOG(INFO) << "fbank: " << true;
LOG(INFO) << "feat dim: " << feature_cache.Dim(); LOG(INFO) << "feat dim: " << feature_cache.Dim();
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/feat/wave-reader.h" #include "kaldi/feat/wave-reader.h"
...@@ -68,10 +66,6 @@ int main(int argc, char* argv[]) { ...@@ -68,10 +66,6 @@ int main(int argc, char* argv[]) {
ppspeech::FeatureCacheOptions feat_cache_opts; ppspeech::FeatureCacheOptions feat_cache_opts;
// the feature cache output feature chunk by chunk. // the feature cache output feature chunk by chunk.
// frame_chunk_size : num frame of a chunk.
// frame_chunk_stride: chunk sliding window stride.
feat_cache_opts.frame_chunk_stride = 1;
feat_cache_opts.frame_chunk_size = 1;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim(); LOG(INFO) << "feat dim: " << feature_cache.Dim();
......
...@@ -26,8 +26,6 @@ using std::unique_ptr; ...@@ -26,8 +26,6 @@ using std::unique_ptr;
FeatureCache::FeatureCache(FeatureCacheOptions opts, FeatureCache::FeatureCache(FeatureCacheOptions opts,
unique_ptr<FrontendInterface> base_extractor) { unique_ptr<FrontendInterface> base_extractor) {
max_size_ = opts.max_size; max_size_ = opts.max_size;
frame_chunk_stride_ = opts.frame_chunk_stride;
frame_chunk_size_ = opts.frame_chunk_size;
timeout_ = opts.timeout; // ms timeout_ = opts.timeout; // ms
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim(); dim_ = base_extractor_->Dim();
...@@ -74,24 +72,11 @@ bool FeatureCache::Compute() { ...@@ -74,24 +72,11 @@ bool FeatureCache::Compute() {
bool result = base_extractor_->Read(&feature); bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false; if (result == false || feature.Dim() == 0) return false;
// join with remained int32 num_chunk = feature.Dim() / dim_ ;
int32 joint_len = feature.Dim() + remained_feature_.Dim();
Vector<BaseFloat> joint_feature(joint_len);
joint_feature.Range(0, remained_feature_.Dim())
.CopyFromVec(remained_feature_);
joint_feature.Range(remained_feature_.Dim(), feature.Dim())
.CopyFromVec(feature);
// one by one, or stride with window
// controlled by frame_chunk_stride_ and frame_chunk_size_
int32 num_chunk =
((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * frame_chunk_stride_ * dim_; int32 start = chunk_idx * dim_;
Vector<BaseFloat> feature_chunk(dim_);
Vector<BaseFloat> feature_chunk(frame_chunk_size_ * dim_); SubVector<BaseFloat> tmp(feature.Data() + start, dim_);
SubVector<BaseFloat> tmp(joint_feature.Data() + start,
frame_chunk_size_ * dim_);
feature_chunk.CopyFromVec(tmp); feature_chunk.CopyFromVec(tmp);
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
...@@ -104,13 +89,6 @@ bool FeatureCache::Compute() { ...@@ -104,13 +89,6 @@ bool FeatureCache::Compute() {
cache_.push(feature_chunk); cache_.push(feature_chunk);
ready_read_condition_.notify_one(); ready_read_condition_.notify_one();
} }
// cache remained feats
int32 remained_feature_len =
joint_len - num_chunk * frame_chunk_stride_ * dim_;
remained_feature_.Resize(remained_feature_len);
remained_feature_.CopyFromVec(joint_feature.Range(
frame_chunk_stride_ * num_chunk * dim_, remained_feature_len));
return result; return result;
} }
......
...@@ -21,13 +21,9 @@ namespace ppspeech { ...@@ -21,13 +21,9 @@ namespace ppspeech {
struct FeatureCacheOptions { struct FeatureCacheOptions {
int32 max_size; int32 max_size;
int32 frame_chunk_size;
int32 frame_chunk_stride;
int32 timeout; // ms int32 timeout; // ms
FeatureCacheOptions() FeatureCacheOptions()
: max_size(kint16max), : max_size(kint16max),
frame_chunk_size(1),
frame_chunk_stride(1),
timeout(1) {} timeout(1) {}
}; };
...@@ -80,7 +76,7 @@ class FeatureCache : public FrontendInterface { ...@@ -80,7 +76,7 @@ class FeatureCache : public FrontendInterface {
std::condition_variable ready_feed_condition_; std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_; std::condition_variable ready_read_condition_;
// DISALLOW_COPY_AND_ASSGIN(FeatureCache); DISALLOW_COPY_AND_ASSIGN(FeatureCache);
}; };
} // namespace ppspeech } // namespace ppspeech
...@@ -35,8 +35,11 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) { ...@@ -35,8 +35,11 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
unique_ptr<FrontendInterface> cmvn( unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));
base_extractor_.reset( unique_ptr<FrontendInterface> cache(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
base_extractor_.reset(
new ppspeech::Assembler(opts.assembler_opts, std::move(cache)));
} }
} // ppspeech } // ppspeech
\ No newline at end of file
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "frontend/audio/frontend_itf.h" #include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
#include "frontend/audio/assembler.h"
namespace ppspeech { namespace ppspeech {
...@@ -33,13 +34,16 @@ struct FeaturePipelineOptions { ...@@ -33,13 +34,16 @@ struct FeaturePipelineOptions {
LinearSpectrogramOptions linear_spectrogram_opts; LinearSpectrogramOptions linear_spectrogram_opts;
FbankOptions fbank_opts; FbankOptions fbank_opts;
FeatureCacheOptions feature_cache_opts; FeatureCacheOptions feature_cache_opts;
AssemblerOptions assembler_opts;
FeaturePipelineOptions() FeaturePipelineOptions()
: cmvn_file(""), : cmvn_file(""),
to_float32(false), // true, only for linear feature to_float32(false), // true, only for linear feature
use_fbank(true), use_fbank(true),
linear_spectrogram_opts(), linear_spectrogram_opts(),
fbank_opts(), fbank_opts(),
feature_cache_opts() {} feature_cache_opts(),
assembler_opts() {}
}; };
class FeaturePipeline : public FrontendInterface { class FeaturePipeline : public FrontendInterface {
...@@ -59,4 +63,4 @@ class FeaturePipeline : public FrontendInterface { ...@@ -59,4 +63,4 @@ class FeaturePipeline : public FrontendInterface {
private: private:
std::unique_ptr<FrontendInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
}; };
} }
\ No newline at end of file
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(websocket)
cmake_minimum_required(VERSION 3.14 FATAL_ERROR) project(websocket)
add_library(websocket STATIC
websocket_server.cc
websocket_client.cc
)
target_link_libraries(websocket PUBLIC frontend decoder nnet)
add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc) add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS}) target_link_libraries(websocket_server_main PUBLIC fst websocket ${DEPS})
add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc) add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS}) target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS})
\ No newline at end of file
...@@ -67,6 +67,9 @@ void WebSocketClient::ReadLoopFunc() { ...@@ -67,6 +67,9 @@ void WebSocketClient::ReadLoopFunc() {
if (obj["type"] == "final_result") { if (obj["type"] == "final_result") {
result_ = obj["result"].as_string().c_str(); result_ = obj["result"].as_string().c_str();
} }
if (obj["type"] == "partial_result") {
partial_result_ = obj["result"].as_string().c_str();
}
if (obj["type"] == "speech_end") { if (obj["type"] == "speech_end") {
done_ = true; done_ = true;
break; break;
......
...@@ -40,12 +40,14 @@ class WebSocketClient { ...@@ -40,12 +40,14 @@ class WebSocketClient {
void SendEndSignal(); void SendEndSignal();
void SendDataEnd(); void SendDataEnd();
bool Done() const { return done_; } bool Done() const { return done_; }
std::string GetResult() { return result_; } std::string GetResult() const { return result_; }
std::string GetPartialResult() const { return partial_result_;}
private: private:
void Connect(); void Connect();
std::string host_; std::string host_;
std::string result_; std::string result_;
std::string partial_result_;
int port_; int port_;
bool done_ = false; bool done_ = false;
asio::io_context ioc_; asio::io_context ioc_;
......
...@@ -59,7 +59,6 @@ int main(int argc, char* argv[]) { ...@@ -59,7 +59,6 @@ int main(int argc, char* argv[]) {
client.SendBinaryData(wav_chunk.data(), client.SendBinaryData(wav_chunk.data(),
wav_chunk.size() * sizeof(int16)); wav_chunk.size() * sizeof(int16));
sample_offset += cur_chunk_size; sample_offset += cur_chunk_size;
LOG(INFO) << "Send " << cur_chunk_size << " samples"; LOG(INFO) << "Send " << cur_chunk_size << " samples";
std::this_thread::sleep_for( std::this_thread::sleep_for(
......
...@@ -75,9 +75,10 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) { ...@@ -75,9 +75,10 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
CHECK(recognizer_ != nullptr); CHECK(recognizer_ != nullptr);
recognizer_->Accept(pcm_data); recognizer_->Accept(pcm_data);
// TODO: return lpartial result std::string partial_result = recognizer_->GetPartialResult();
json::value rv = { json::value rv = {
{"status", "ok"}, {"type", "partial_result"}, {"result", "TODO"}}; {"status", "ok"}, {"type", "partial_result"}, {"result", partial_result}};
ws_.text(true); ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv))); ws_.write(asio::buffer(json::serialize(rv)));
} }
......
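With this change the server pushes a `partial_result` message after every accepted audio chunk, using the same JSON envelope as `final_result`. A hedged Python client sketch that consumes both message types (host, port, and the audio-sending side are placeholders, not taken from this diff):

```python
import asyncio
import json

import websockets  # pip install websockets

async def recv_results(uri: str = "ws://127.0.0.1:8082"):  # placeholder endpoint
    async with websockets.connect(uri) as ws:
        # Audio chunks would be streamed to the server from another task.
        async for message in ws:
            obj = json.loads(message)
            if obj.get("type") == "partial_result":
                print("partial:", obj["result"])
            elif obj.get("type") == "final_result":
                print("final:", obj["result"])
            elif obj.get("type") == "speech_end":
                break

asyncio.run(recv_results())
```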
...@@ -44,7 +44,6 @@ class ConnectionHandler { ...@@ -44,7 +44,6 @@ class ConnectionHandler {
void OnFinish(); void OnFinish();
void OnSpeechData(const beast::flat_buffer& buffer); void OnSpeechData(const beast::flat_buffer& buffer);
void OnError(const std::string& message); void OnError(const std::string& message);
void OnPartialResult(const std::string& result);
void OnFinalResult(const std::string& result); void OnFinalResult(const std::string& result);
void DecodeThreadFunc(); void DecodeThreadFunc();
std::string SerializeResult(bool finish); std::string SerializeResult(bool finish);
......
project(websocket)
add_library(websocket STATIC
websocket_server.cc
websocket_client.cc
)
target_link_libraries(websocket PUBLIC frontend decoder nnet)
...@@ -25,7 +25,7 @@ paddlespeech asr --model deepspeech2offline_librispeech --lang en --input ./en.w ...@@ -25,7 +25,7 @@ paddlespeech asr --model deepspeech2offline_librispeech --lang en --input ./en.w
# long audio restriction # long audio restriction
{ {
wget -c https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/test_long_audio_01.wav wget -c https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/test_long_audio_01.wav
paddlespeech asr --input test_long_audio_01.wav paddlespeech asr --model deepspeech2online_wenetspeech --input test_long_audio_01.wav -y
if [ $? -ne 255 ]; then if [ $? -ne 255 ]; then
echo -e "\e[1;31mTime restriction not passed\e[0m" echo -e "\e[1;31mTime restriction not passed\e[0m"
exit 1 exit 1
...@@ -54,7 +54,7 @@ paddlespeech tts --am tacotron2_ljspeech --voc pwgan_ljspeech --lang en --input ...@@ -54,7 +54,7 @@ paddlespeech tts --am tacotron2_ljspeech --voc pwgan_ljspeech --lang en --input
# Speech Translation (only support linux) # Speech Translation (only support linux)
paddlespeech st --input ./en.wav paddlespeech st --input ./en.wav
# Speaker Verification # Speaker Verification
wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
paddlespeech vector --task spk --input 85236145389.wav paddlespeech vector --task spk --input 85236145389.wav
...@@ -65,7 +65,7 @@ echo -e "demo1 85236145389.wav \n demo2 85236145389.wav" > vec.job ...@@ -65,7 +65,7 @@ echo -e "demo1 85236145389.wav \n demo2 85236145389.wav" > vec.job
paddlespeech vector --task spk --input vec.job paddlespeech vector --task spk --input vec.job
echo -e "demo3 85236145389.wav \n demo4 85236145389.wav" | paddlespeech vector --task spk echo -e "demo3 85236145389.wav \n demo4 85236145389.wav" | paddlespeech vector --task spk
rm 85236145389.wav rm 85236145389.wav
rm vec.job rm vec.job
# shell pipeline # shell pipeline
......
* [python_kaldi_features](https://github.com/ZitengWang/python_kaldi_features) * [python_kaldi_features](https://github.com/ZitengWang/python_kaldi_features)
commit: fc1bd6240c2008412ab64dc25045cd872f5e126c commit: fc1bd6240c2008412ab64dc25045cd872f5e126c
ref: https://zhuanlan.zhihu.com/p/55371926 ref: https://zhuanlan.zhihu.com/p/55371926
licence: MIT license: MIT
* [python-pinyin](https://github.com/mozillazg/python-pinyin.git) * [python-pinyin](https://github.com/mozillazg/python-pinyin.git)
commit: 55e524aa1b7b8eec3d15c5306043c6cdd5938b03 commit: 55e524aa1b7b8eec3d15c5306043c6cdd5938b03
licence: MIT license: MIT
* [zhon](https://github.com/tsroten/zhon) * [zhon](https://github.com/tsroten/zhon)
commit: 09bf543696277f71de502506984661a60d24494c commit: 09bf543696277f71de502506984661a60d24494c
licence: MIT license: MIT
* [pymmseg-cpp](https://github.com/pluskid/pymmseg-cpp.git) * [pymmseg-cpp](https://github.com/pluskid/pymmseg-cpp.git)
commit: b76465045717fbb4f118c4fbdd24ce93bab10a6d commit: b76465045717fbb4f118c4fbdd24ce93bab10a6d
licence: MIT license: MIT
* [chinese_text_normalization](https://github.com/speechio/chinese_text_normalization.git) * [chinese_text_normalization](https://github.com/speechio/chinese_text_normalization.git)
commit: 9e92c7bf2d6b5a7974305406d8e240045beac51c commit: 9e92c7bf2d6b5a7974305406d8e240045beac51c
licence: MIT license: MIT
* [phkit](https://github.com/KuangDD/phkit.git) * [phkit](https://github.com/KuangDD/phkit.git)
commit: b2100293c1e36da531d7f30bd52c9b955a649522 commit: b2100293c1e36da531d7f30bd52c9b955a649522
licence: None license: None
* [nnAudio](https://github.com/KinWaiCheuk/nnAudio.git) * [nnAudio](https://github.com/KinWaiCheuk/nnAudio.git)
licence: MIT license: MIT
...@@ -5,4 +5,4 @@ score.h and score.cpp is under the LGPL license. ...@@ -5,4 +5,4 @@ score.h and score.cpp is under the LGPL license.
The two files include the header files from KenLM project. The two files include the header files from KenLM project.
For the rest: For the rest:
The default licence of paddlespeech-ctcdecoders is Apache License 2.0. The default license of paddlespeech-ctcdecoders is Apache License 2.0.
# Utils # Utils
* [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils) * [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils)
* [espnet utils)(https://github.com/espnet/espnet/tree/master/utils) * [espnet utils](https://github.com/espnet/espnet/tree/master/utils)
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# CopyRight WeNet Apache-2.0 License # Copyright 2021 Mobvoi Inc. All Rights Reserved.
import codecs import codecs
import re import re
import sys import sys
......