Commit c76c4800, authored by Yang Zhou

Merge branch 'develop' of github.com:SmileGoat/PaddleSpeech into refactor_file_struct

([简体中文](./README_cn.md)|English)

<p align="center">
  <img src="./docs/images/PaddleSpeech_logo.png" />
</p>
<a href="https://huggingface.co/spaces"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue"></a> <a href="https://huggingface.co/spaces"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue"></a>
</p> </p>
<div align="center"> <div align="center">
<h3> <h4>
| <a href="#quick-start"> Quick Start </a> <a href="#quick-start"> Quick Start </a>
| <a href="#quick-start-server"> Quick Start Server </a> | <a href="#quick-start-server"> Quick Start Server </a>
| <a href="#quick-start-streaming-server"> Quick Start Streaming Server</a> | <a href="#quick-start-streaming-server"> Quick Start Streaming Server</a>
|
</br>
| <a href="#documents"> Documents </a> | <a href="#documents"> Documents </a>
| <a href="#model-list"> Models List </a> | <a href="#model-list"> Models List </a>
| | <a href="https://aistudio.baidu.com/aistudio/education/group/info/25130"> AIStudio Courses </a>
</h3> </h4>
</div> </div>
------------------------------------------------------------------------------------
**PaddleSpeech** is an open-source toolkit on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform for a variety of critical tasks in speech and audio, with state-of-the-art and influential models.
- 🤗 2021.12.14: [ASR](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) and [TTS](https://huggingface.co/spaces/KPatrick/PaddleSpeechTTS) Demos on Hugging Face Spaces are available!
- 👏🏻 2021.12.10: `CLI` is available for `Audio Classification`, `Automatic Speech Recognition`, `Speech Translation (English to Chinese)` and `Text-to-Speech`.
### 🔥 Hot Activities
<!---
2021.12.14: We would like to have an online courses to introduce basics and research of speech, as well as code practice with `paddlespeech`. Please pay attention to our [Calendar](https://www.paddlepaddle.org.cn/live).
--->
- 2021.12.21~12.24
  4-day live course: an in-depth interpretation of PaddleSpeech!
  **Course videos and related materials: https://aistudio.baidu.com/aistudio/education/group/info/25130**
### Community
- Scan the QR code below with WeChat to join the official technical exchange group and get the bonus (more than 20 GB of learning materials, such as papers, code and videos) and the live links of the lessons. We look forward to your participation.
<div align="center"> <div align="center">
<img src="https://raw.githubusercontent.com/yt605155624/lanceTest/main/images/wechat_4.jpg" width = "300" /> <img src="https://user-images.githubusercontent.com/23690325/169763015-cbd8e28d-602c-4723-810d-dbc6da49441e.jpg" width = "200" />
</div> </div>
## Installation
<a href="https://huggingface.co/spaces"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue"></a> <a href="https://huggingface.co/spaces"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue"></a>
</p> </p>
<div align="center"> <div align="center">
<h3> <h4>
<a href="#quick-start"> Quick Start </a> <a href="#快速开始"> 快速开始 </a>
| <a href="#quick-start-server"> Quick Start Server </a> | <a href="#快速使用服务"> 快速使用服务 </a>
| <a href="#quick-start-streaming-server"> Quick Start Streaming Server</a> | <a href="#快速使用流式服务"> 快速使用流式服务 </a>
</br> | <a href="#教程文档"> 教程文档 </a>
<a href="#documents"> Documents </a> | <a href="#模型列表"> 模型列表 </a>
| <a href="#model-list"> Models List </a> | <a href="https://aistudio.baidu.com/aistudio/education/group/info/25130"> AIStudio 课程 </a>
</h3> </h4>
</div> </div>
------------------------------------------------------------------------------------ ------------------------------------------------------------------------------------
<div align="center">
<h3>
<a href="#quick-start"> 快速开始 </a>
| <a href="#quick-start-server"> 快速使用服务 </a>
| <a href="#quick-start-streaming-server"> 快速使用流式服务 </a>
| <a href="#documents"> 教程文档 </a>
| <a href="#model-list"> 模型列表 </a>
</div>
<!---
from https://github.com/18F/open-source-guide/blob/18f-pages/pages/making-readmes-readable.md
1.What is this repo or project? (You can reuse the repo description you used earlier because this section doesn’t have to be long.)
2.How does it work?
3.Who will use this repo or project?
4.What is the goal of this project?
-->
**PaddleSpeech** is an open-source speech model library based on [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). It is used to develop a variety of key tasks in speech and audio and contains a large number of cutting-edge and influential deep-learning models. Some typical applications are shown below:
##### Speech Recognition
### Recent Update
- 👑 2022.05.13: PaddleSpeech releases [PP-ASR](./docs/source/asr/PPASR_cn.md), a streaming speech recognition system; [PP-TTS](./docs/source/tts/PPTTS_cn.md), a streaming text-to-speech system; and [PP-VPR](docs/source/vpr/PPVPR_cn.md), a full-pipeline speaker verification system.
<!---
2021.12.14: We would like to have an online courses to introduce basics and research of speech, as well as code practice with `paddlespeech`. Please pay attention to our [Calendar](https://www.paddlepaddle.org.cn/live).
--->
- 👏🏻 2022.05.06: PaddleSpeech Streaming Server is available! It covers speech recognition (with punctuation restoration and timestamps) and text-to-speech.
- 👏🏻 2022.05.06: PaddleSpeech Server is available! It covers audio classification, speech recognition, text-to-speech, speaker verification and punctuation restoration.
- 👏🏻 2022.03.28: The PaddleSpeech CLI covers audio classification, speech recognition, speech translation (English to Chinese), text-to-speech and speaker verification.
- 🤗 2021.12.14: PaddleSpeech [ASR](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) and [TTS](https://huggingface.co/spaces/KPatrick/PaddleSpeechTTS) Demos on Hugging Face Spaces are available!
### 🔥 Join the technical exchange group to get the sign-up benefits
- Links to the 3-day live course: an in-depth interpretation of the key technologies behind the PP-TTS, PP-ASR and PP-VPR speech systems
- 20 GB learning gift pack: video courses, cutting-edge papers and learning materials

Scan the QR code below with WeChat to follow the official account, click "Sign up now" and fill in the questionnaire to join the official exchange group, get more efficient answers to your questions, and exchange ideas with developers from all walks of life. We look forward to your participation.
<div align="center"> <div align="center">
<img src="https://raw.githubusercontent.com/yt605155624/lanceTest/main/images/wechat_4.jpg" width = "300" /> <img src="https://user-images.githubusercontent.com/23690325/169763015-cbd8e28d-602c-4723-810d-dbc6da49441e.jpg" width = "200" />
</div> </div>
## Installation
We strongly recommend installing PaddleSpeech on **Linux**, with *Python* version *3.7* or above.
So far, **Linux** supports four functions: audio classification, speech recognition, text-to-speech and speech translation. Speech translation is not yet supported on **Mac OSX** or **Windows**. For installation details, see the [installation documentation](./docs/source/install_cn.md).
<a name="快速开始"></a>
## Quick Start
After installation, developers can get started quickly from the command line. Change `--input` to try it with your own audio or text.
For example, `paddlespeech asr --input ./zh.wav | paddlespeech text --task punc` transcribes an audio file and then restores its punctuation. For more command-line usage, see [demos](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/demos).

> Note: for training or fine-tuning, see [Speech Recognition](./docs/source/asr/quick_start.md) and [Text-to-Speech](./docs/source/tts/quick_start.md).
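The same quick start can also be driven from Python. Below is a minimal sketch of the ASR + punctuation pipeline shown above; the `ASRExecutor`/`TextExecutor` import path and keyword arguments are assumptions modelled on the `VectorExecutor` example later in this diff, so check the quick-start docs linked above before relying on them.

```python
# Hypothetical Python equivalent of:
#   paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
from paddlespeech.cli import ASRExecutor, TextExecutor

asr = ASRExecutor()
punc = TextExecutor()

text = asr(audio_file="./zh.wav")       # speech recognition
result = punc(text=text, task="punc")   # punctuation restoration
print(result)
```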
<a name="快速使用服务"></a>
## Quick Start Server
After installation, developers can quickly use the server from the command line.
For example, `paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav` sends an audio classification request to a running server. For more server-related command-line usage, see [demos](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/demos/speech_server).
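The server can also be launched programmatically. The sketch below uses the same `ServerExecutor` import and call pattern as the launcher script included later in this diff; the config path is illustrative and should point at the YAML file of the service you want to start.

```python
# Minimal sketch for starting a PaddleSpeech server from Python
# (the config path is an assumed example, not taken from this diff).
from paddlespeech.server.bin.paddlespeech_server import ServerExecutor

server_executor = ServerExecutor()
server_executor(
    config_file="./demos/speech_server/conf/application.yaml",
    log_file="./log/paddlespeech.log")
```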
<a name="quickstartstreamingserver"></a> <a name="快速使用流式服务"></a>
## Quick Start Streaming Server
Developers can try the [streaming ASR](./demos/streaming_asr_server/README.md) and [streaming TTS](./demos/streaming_tts_server/README.md) services.

**Start the streaming ASR server**
```
paddlespeech_server start --config_file ./demos/streaming_asr_server/conf/application.yaml
```
**Access the streaming ASR service**
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```
**Start the streaming TTS server**
```
paddlespeech_server start --config_file ./demos/streaming_tts_server/conf/tts_online_application.yaml
```
**Access the streaming TTS service**
```
paddlespeech_client tts_online --server_ip 127.0.0.1 --port 8092 --protocol http --input "您好,欢迎使用百度飞桨语音合成服务。" --output output.wav
```
For more information, see [streaming ASR](./demos/streaming_asr_server/README.md) and [streaming TTS](./demos/streaming_tts_server/README.md).
<a name="modulelist"></a> <a name="模型列表"></a>
## Model List
PaddleSpeech supports many mainstream models and provides pretrained models. For details, see the [model list](./docs/source/released_model.md).
</tbody>
</table>
<a name="教程文档"></a>
## Documents
For the tasks PaddleSpeech focuses on, the following guides help developers get started quickly and understand the core ideas of speech processing.
<a name="欢迎贡献"></a> <a name="欢迎贡献"></a>
## Contributing to PaddleSpeech
You are warmly welcome to post questions in [Discussions](https://github.com/PaddlePaddle/PaddleSpeech/discussions) and to report bugs in [Issues](https://github.com/PaddlePaddle/PaddleSpeech/issues). We also look forward to your joining the development of PaddleSpeech!
### Contributors
<p align="center"> <p align="center">
You can choose either the medium or the hard way to install paddlespeech.
The dependencies are listed in requirements.txt; install them as follows:
```
pip install -r requirements.txt
```
### 2. Prepare Input File
The input of this demo should be a WAV file (`.wav`), and its sample rate must be the same as the model's.
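To verify that an input WAV actually matches the model's sample rate, you can read its header with the Python standard library, for example:

```python
# Quick sample-rate check for a WAV input (the file name is illustrative).
import wave

with wave.open("input.wav", "rb") as f:
    print(f.getframerate())  # e.g. 16000; must match the model's sample rate
```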
See the [installation documentation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md).
You can choose either the medium or the hard way to install.
The dependencies are listed in requirements.txt; install them as follows:
```
pip install -r requirements.txt
```
### 2. Prepare Input
The input of this demo should be a WAV file (`.wav`), and its sample rate must be the same as the model's.
acs_python:
  word_list: "./conf/words.txt"
  sample_rate: 16000
  device: 'cpu'        # set 'gpu:id' or 'cpu'
  ping_timeout: 100    # seconds
websocket-client
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from paddlespeech.cli.log import logger
from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
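
# This script starts the PaddleSpeech streaming ASR server: it parses the
# --config_file and --log_file options below and hands them to ServerExecutor.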
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='paddlespeech_server.start', add_help=True)
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default=None,
required=True)
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
logger.info("start to parse the args")
args = parser.parse_args()
logger.info("start to launch the streaming asr server")
streaming_asr_server = ServerExecutor()
streaming_asr_server(config_file=args.config_file, log_file=args.log_file)
def get_audios(path):
    """Return every supported audio file found under `path`."""
    supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
    return [
        item for sublist in [[os.path.join(dir, file) for file in files]
                             for dir, _, files in list(os.walk(path))]
        for item in sublist if os.path.splitext(item)[1] in supported_formats
    ]
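
# Example usage (illustrative, not part of the original file):
#     audio_files = get_audios("./audio_corpus")
#     print(len(audio_files), "audio files found")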
Output:
```bash
demo [ -1.3251206  7.8606825  -4.620626  0.3000721  2.2648535
  -1.1931441  3.0647137  7.673595  -6.0044727  -12.02426
  -1.9496069  3.1269536  1.618838  -7.6383104  -1.2299773
  -12.338331  2.1373026  -5.3957124  9.717328  5.6752305
  3.7805123  3.0597172  3.429692  8.97601  13.174125
  -0.53132284  8.9424715  4.46511  -4.4262476  -9.726503
  8.399328  7.2239175  -7.435854  2.9441683  -4.3430395
  -13.886965  -1.6346735  -10.9027405  -5.311245  3.8007221
  3.8976038  -2.1230774  -2.3521194  4.151031  -7.4048667
  0.13911647  2.4626107  4.9664545  0.9897574  5.4839754
  -3.3574002  10.1340065  -0.6120171  -10.403095  4.6007543
  16.00935  -7.7836914  -4.1945305  -6.9368606  1.1789556
  11.490801  4.2380238  9.550931  8.375046  7.5089145
  -0.65707296  -0.30051577  2.8406055  3.0828028  0.730817
  6.148354  0.13766119  -13.424735  -7.7461405  -2.3227983
  -8.305252  2.9879124  -10.995229  0.15211068  -2.3820348
  -1.7984174  8.495629  -5.8522367  -3.755498  0.6989711
  -5.2702994  -2.6188622  -1.8828466  -4.64665  14.078544
  -0.5495333  10.579158  -3.2160501  9.349004  -4.381078
  -11.675817  -2.8630207  4.5721755  2.246612  -4.574342
  1.8610188  2.3767874  5.6257877  -9.784078  0.64967257
  -1.4579505  0.4263264  -4.9211264  -2.454784  3.4869802
  -0.42654222  8.341269  1.356552  7.0966883  -13.102829
  8.016734  -7.1159344  1.8699781  0.208721  14.699384
  -1.025278  -2.6107233  -2.5082312  8.427193  6.9138527
  -6.2912464  0.6157366  2.489688  -3.4668267  9.921763
  11.200815  -0.1966403  7.4916005  -0.62312716  -0.25848144
  -9.947997  -0.9611041  1.1649219  -2.1907122  -1.5028487
  -0.51926106  15.165954  2.4649463  -0.9980445  7.4416637
  -2.0768049  3.5896823  -7.3055434  -7.5620847  4.323335
  0.0804418  -6.56401  -2.3148053  -1.7642345  -2.4708817
  -7.675618  -9.548878  -1.0177554  0.16986446  2.5877135
  -1.8752296  -0.36614323  -6.0493784  -2.3965611  -5.9453387
  0.9424033  -13.155974  -7.457801  0.14658108  -3.742797
  5.8414927  -1.2872906  5.5694313  12.57059  1.0939219
  2.2142086  1.9181576  6.9914207  -5.888139  3.1409824
  -2.003628  2.4434285  9.973139  5.03668  2.0051203
  2.8615603  5.860224  2.9176188  -1.6311141  2.0292206
  -4.070415  -6.831437 ]
```
- Python API
```python
import paddle
from paddlespeech.cli import VectorExecutor

vector_executor = VectorExecutor()
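# The original snippet is truncated in this diff. A hedged sketch of how the
# executor might be used follows; the `audio_file` keyword and the
# `get_embeddings_score` helper are assumptions inferred from the printed
# output below, not taken from the original example.
audio_emb = vector_executor(audio_file="./85236145389.wav")
test_emb = vector_executor(audio_file="./123456789.wav")
score = vector_executor.get_embeddings_score(audio_emb, test_emb)
print(score)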
```
```bash
# Vector Result:
Audio embedding Result:
[ -1.3251206  7.8606825  -4.620626  0.3000721  2.2648535
  -1.1931441  3.0647137  7.673595  -6.0044727  -12.02426
  -1.9496069  3.1269536  1.618838  -7.6383104  -1.2299773
  -12.338331  2.1373026  -5.3957124  9.717328  5.6752305
  3.7805123  3.0597172  3.429692  8.97601  13.174125
  -0.53132284  8.9424715  4.46511  -4.4262476  -9.726503
  8.399328  7.2239175  -7.435854  2.9441683  -4.3430395
  -13.886965  -1.6346735  -10.9027405  -5.311245  3.8007221
  3.8976038  -2.1230774  -2.3521194  4.151031  -7.4048667
  0.13911647  2.4626107  4.9664545  0.9897574  5.4839754
  -3.3574002  10.1340065  -0.6120171  -10.403095  4.6007543
  16.00935  -7.7836914  -4.1945305  -6.9368606  1.1789556
  11.490801  4.2380238  9.550931  8.375046  7.5089145
  -0.65707296  -0.30051577  2.8406055  3.0828028  0.730817
  6.148354  0.13766119  -13.424735  -7.7461405  -2.3227983
  -8.305252  2.9879124  -10.995229  0.15211068  -2.3820348
  -1.7984174  8.495629  -5.8522367  -3.755498  0.6989711
  -5.2702994  -2.6188622  -1.8828466  -4.64665  14.078544
  -0.5495333  10.579158  -3.2160501  9.349004  -4.381078
  -11.675817  -2.8630207  4.5721755  2.246612  -4.574342
  1.8610188  2.3767874  5.6257877  -9.784078  0.64967257
  -1.4579505  0.4263264  -4.9211264  -2.454784  3.4869802
  -0.42654222  8.341269  1.356552  7.0966883  -13.102829
  8.016734  -7.1159344  1.8699781  0.208721  14.699384
  -1.025278  -2.6107233  -2.5082312  8.427193  6.9138527
  -6.2912464  0.6157366  2.489688  -3.4668267  9.921763
  11.200815  -0.1966403  7.4916005  -0.62312716  -0.25848144
  -9.947997  -0.9611041  1.1649219  -2.1907122  -1.5028487
  -0.51926106  15.165954  2.4649463  -0.9980445  7.4416637
  -2.0768049  3.5896823  -7.3055434  -7.5620847  4.323335
  0.0804418  -6.56401  -2.3148053  -1.7642345  -2.4708817
  -7.675618  -9.548878  -1.0177554  0.16986446  2.5877135
  -1.8752296  -0.36614323  -6.0493784  -2.3965611  -5.9453387
  0.9424033  -13.155974  -7.457801  0.14658108  -3.742797
  5.8414927  -1.2872906  5.5694313  12.57059  1.0939219
  2.2142086  1.9181576  6.9914207  -5.888139  3.1409824
  -2.003628  2.4434285  9.973139  5.03668  2.0051203
  2.8615603  5.860224  2.9176188  -1.6311141  2.0292206
  -4.070415  -6.831437 ]
# get the test embedding
Test embedding Result:
[ 2.5247195  5.119042  -4.335273  4.4583654  5.047907
  3.5059214  1.6159848  0.49364898  -11.6899185  -3.1014526
  -5.6589785  -0.42684984  2.674276  -11.937654  6.2248464
  -10.776924  -5.694543  1.112041  1.5709964  1.0961034
  1.3976512  2.324352  1.339981  5.279319  13.734659
  -2.5753925  13.651442  -2.2357535  5.1575427  -3.251567
  1.4023279  6.1191974  -6.0845175  -1.3646189  -2.6789894
  -15.220778  9.779349  -9.411551  -6.388947  6.8313975
  -9.245996  0.31196198  2.5509644  -4.413065  6.1649427
  6.793837  2.6328635  8.620976  3.4832475  0.52491665
  2.9115407  5.8392377  0.6702376  -3.2726715  2.6694255
  16.91701  -5.5811176  0.23362345  -4.5573606  -11.801059
  14.728292  -0.5198082  -3.999922  7.0927105  -7.0459595
  -5.4389  -0.46420583  -5.1085467  10.376568  -8.889225
  -0.37705845  -1.659806  2.6731026  -7.1909504  1.4608804
  -2.163136  -0.17949677  4.0241547  0.11319201  0.601279
  2.039692  3.1910992  -11.649526  -8.121584  -4.8707457
  0.3851982  1.4231744  -2.3321972  0.99332285  14.121717
  5.899413  0.7384519  -17.760096  10.555021  4.1366534
  -0.3391071  -0.20792882  3.208204  0.8847948  -8.721497
  -6.432868  13.006379  4.8956  -9.155822  -1.9441519
  5.7815638  -2.066733  10.425042  -0.8802383  -2.4314315
  -9.869258  0.35095334  -5.3549943  2.1076174  -8.290468
  8.4433365  -4.689333  9.334139  -2.172678  -3.0250976
  8.394216  -3.2110903  -7.93868  2.3960824  -2.3213403
  -1.4963245  -3.476059  4.132903  -10.893354  4.362673
  -0.45456508  10.258634  -1.1655927  -6.7799754  0.22885278
  -4.399287  2.333433  -4.84745  -4.2752337  -1.3577863
  -1.0685898  9.505196  7.3062205  0.08708266  12.927811
  -9.57974  1.3936648  -1.9444873  5.776769  15.251903
  10.6118355  -1.4903594  -9.535318  -3.6553776  -1.6699586
  -0.5933151  7.600357  -4.8815503  -8.698617  -15.855757
  0.25632986  -7.2235737  0.9506656  0.7128582  -9.051738
  8.74869  -1.6426028  -6.5762258  2.506905  -6.7431564
  5.129912  -12.189555  -3.6435068  12.068113  -6.0059533
  -2.3535995  2.9014351  22.3082  -1.5563312  13.193291
  2.7583609  -7.468798  1.3407065  -4.599617  -6.2345777
  10.7689295  7.137627  5.099476  0.3473359  9.647881
  -2.0484571  -5.8549366 ]
# get the score between enroll and test
Eembeddings Score: 0.45332613587379456
```
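The embedding score printed above is a similarity between the enroll and test embeddings. As a point of reference, a plain cosine similarity between two such vectors can be computed as sketched below; this only illustrates the idea and is not PaddleSpeech's internal implementation.

```python
# Cosine similarity between two speaker embeddings (illustrative sketch).
import numpy as np

def cosine_score(enroll: np.ndarray, test: np.ndarray) -> float:
    return float(np.dot(enroll, test) /
                 (np.linalg.norm(enroll) * np.linalg.norm(test)))
```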
### 4.Pretrained Models
Output:
```bash
[ -1.3251206  7.8606825  -4.620626  0.3000721  2.2648535
  -1.1931441  3.0647137  7.673595  -6.0044727  -12.02426
  -1.9496069  3.1269536  1.618838  -7.6383104  -1.2299773
  -12.338331  2.1373026  -5.3957124  9.717328  5.6752305
  3.7805123  3.0597172  3.429692  8.97601  13.174125
  -0.53132284  8.9424715  4.46511  -4.4262476  -9.726503
  8.399328  7.2239175  -7.435854  2.9441683  -4.3430395
  -13.886965  -1.6346735  -10.9027405  -5.311245  3.8007221
  3.8976038  -2.1230774  -2.3521194  4.151031  -7.4048667
  0.13911647  2.4626107  4.9664545  0.9897574  5.4839754
  -3.3574002  10.1340065  -0.6120171  -10.403095  4.6007543
  16.00935  -7.7836914  -4.1945305  -6.9368606  1.1789556
  11.490801  4.2380238  9.550931  8.375046  7.5089145
  -0.65707296  -0.30051577  2.8406055  3.0828028  0.730817
  6.148354  0.13766119  -13.424735  -7.7461405  -2.3227983
  -8.305252  2.9879124  -10.995229  0.15211068  -2.3820348
  -1.7984174  8.495629  -5.8522367  -3.755498  0.6989711
  -5.2702994  -2.6188622  -1.8828466  -4.64665  14.078544
  -0.5495333  10.579158  -3.2160501  9.349004  -4.381078
  -11.675817  -2.8630207  4.5721755  2.246612  -4.574342
  1.8610188  2.3767874  5.6257877  -9.784078  0.64967257
  -1.4579505  0.4263264  -4.9211264  -2.454784  3.4869802
  -0.42654222  8.341269  1.356552  7.0966883  -13.102829
  8.016734  -7.1159344  1.8699781  0.208721  14.699384
  -1.025278  -2.6107233  -2.5082312  8.427193  6.9138527
  -6.2912464  0.6157366  2.489688  -3.4668267  9.921763
  11.200815  -0.1966403  7.4916005  -0.62312716  -0.25848144
  -9.947997  -0.9611041  1.1649219  -2.1907122  -1.5028487
  -0.51926106  15.165954  2.4649463  -0.9980445  7.4416637
  -2.0768049  3.5896823  -7.3055434  -7.5620847  4.323335
  0.0804418  -6.56401  -2.3148053  -1.7642345  -2.4708817
  -7.675618  -9.548878  -1.0177554  0.16986446  2.5877135
  -1.8752296  -0.36614323  -6.0493784  -2.3965611  -5.9453387
  0.9424033  -13.155974  -7.457801  0.14658108  -3.742797
  5.8414927  -1.2872906  5.5694313  12.57059  1.0939219
  2.2142086  1.9181576  6.9914207  -5.888139  3.1409824
  -2.003628  2.4434285  9.973139  5.03668  2.0051203
  2.8615603  5.860224  2.9176188  -1.6311141  2.0292206
  -4.070415  -6.831437 ]
```
- Python API
```bash
# Vector Result:
Audio embedding Result:
[ -1.3251206  7.8606825  -4.620626  0.3000721  2.2648535
  -1.1931441  3.0647137  7.673595  -6.0044727  -12.02426
  -1.9496069  3.1269536  1.618838  -7.6383104  -1.2299773
  -12.338331  2.1373026  -5.3957124  9.717328  5.6752305
  3.7805123  3.0597172  3.429692  8.97601  13.174125
  -0.53132284  8.9424715  4.46511  -4.4262476  -9.726503
  8.399328  7.2239175  -7.435854  2.9441683  -4.3430395
  -13.886965  -1.6346735  -10.9027405  -5.311245  3.8007221
  3.8976038  -2.1230774  -2.3521194  4.151031  -7.4048667
  0.13911647  2.4626107  4.9664545  0.9897574  5.4839754
  -3.3574002  10.1340065  -0.6120171  -10.403095  4.6007543
  16.00935  -7.7836914  -4.1945305  -6.9368606  1.1789556
  11.490801  4.2380238  9.550931  8.375046  7.5089145
  -0.65707296  -0.30051577  2.8406055  3.0828028  0.730817
  6.148354  0.13766119  -13.424735  -7.7461405  -2.3227983
  -8.305252  2.9879124  -10.995229  0.15211068  -2.3820348
  -1.7984174  8.495629  -5.8522367  -3.755498  0.6989711
  -5.2702994  -2.6188622  -1.8828466  -4.64665  14.078544
  -0.5495333  10.579158  -3.2160501  9.349004  -4.381078
  -11.675817  -2.8630207  4.5721755  2.246612  -4.574342
  1.8610188  2.3767874  5.6257877  -9.784078  0.64967257
  -1.4579505  0.4263264  -4.9211264  -2.454784  3.4869802
  -0.42654222  8.341269  1.356552  7.0966883  -13.102829
  8.016734  -7.1159344  1.8699781  0.208721  14.699384
  -1.025278  -2.6107233  -2.5082312  8.427193  6.9138527
  -6.2912464  0.6157366  2.489688  -3.4668267  9.921763
  11.200815  -0.1966403  7.4916005  -0.62312716  -0.25848144
  -9.947997  -0.9611041  1.1649219  -2.1907122  -1.5028487
  -0.51926106  15.165954  2.4649463  -0.9980445  7.4416637
  -2.0768049  3.5896823  -7.3055434  -7.5620847  4.323335
  0.0804418  -6.56401  -2.3148053  -1.7642345  -2.4708817
  -7.675618  -9.548878  -1.0177554  0.16986446  2.5877135
  -1.8752296  -0.36614323  -6.0493784  -2.3965611  -5.9453387
  0.9424033  -13.155974  -7.457801  0.14658108  -3.742797
  5.8414927  -1.2872906  5.5694313  12.57059  1.0939219
  2.2142086  1.9181576  6.9914207  -5.888139  3.1409824
  -2.003628  2.4434285  9.973139  5.03668  2.0051203
  2.8615603  5.860224  2.9176188  -1.6311141  2.0292206
  -4.070415  -6.831437 ]
# get the test embedding
Test embedding Result:
[ 2.5247195  5.119042  -4.335273  4.4583654  5.047907
  3.5059214  1.6159848  0.49364898  -11.6899185  -3.1014526
  -5.6589785  -0.42684984  2.674276  -11.937654  6.2248464
  -10.776924  -5.694543  1.112041  1.5709964  1.0961034
  1.3976512  2.324352  1.339981  5.279319  13.734659
  -2.5753925  13.651442  -2.2357535  5.1575427  -3.251567
  1.4023279  6.1191974  -6.0845175  -1.3646189  -2.6789894
  -15.220778  9.779349  -9.411551  -6.388947  6.8313975
  -9.245996  0.31196198  2.5509644  -4.413065  6.1649427
  6.793837  2.6328635  8.620976  3.4832475  0.52491665
  2.9115407  5.8392377  0.6702376  -3.2726715  2.6694255
  16.91701  -5.5811176  0.23362345  -4.5573606  -11.801059
  14.728292  -0.5198082  -3.999922  7.0927105  -7.0459595
  -5.4389  -0.46420583  -5.1085467  10.376568  -8.889225
  -0.37705845  -1.659806  2.6731026  -7.1909504  1.4608804
  -2.163136  -0.17949677  4.0241547  0.11319201  0.601279
  2.039692  3.1910992  -11.649526  -8.121584  -4.8707457
  0.3851982  1.4231744  -2.3321972  0.99332285  14.121717
  5.899413  0.7384519  -17.760096  10.555021  4.1366534
  -0.3391071  -0.20792882  3.208204  0.8847948  -8.721497
  -6.432868  13.006379  4.8956  -9.155822  -1.9441519
  5.7815638  -2.066733  10.425042  -0.8802383  -2.4314315
  -9.869258  0.35095334  -5.3549943  2.1076174  -8.290468
  8.4433365  -4.689333  9.334139  -2.172678  -3.0250976
  8.394216  -3.2110903  -7.93868  2.3960824  -2.3213403
  -1.4963245  -3.476059  4.132903  -10.893354  4.362673
  -0.45456508  10.258634  -1.1655927  -6.7799754  0.22885278
  -4.399287  2.333433  -4.84745  -4.2752337  -1.3577863
  -1.0685898  9.505196  7.3062205  0.08708266  12.927811
  -9.57974  1.3936648  -1.9444873  5.776769  15.251903
  10.6118355  -1.4903594  -9.535318  -3.6553776  -1.6699586
  -0.5933151  7.600357  -4.8815503  -8.698617  -15.855757
  0.25632986  -7.2235737  0.9506656  0.7128582  -9.051738
  8.74869  -1.6426028  -6.5762258  2.506905  -6.7431564
  5.129912  -12.189555  -3.6435068  12.068113  -6.0059533
  -2.3535995  2.9014351  22.3082  -1.5563312  13.193291
  2.7583609  -7.468798  1.3407065  -4.599617  -6.2345777
  10.7689295  7.137627  5.099476  0.3473359  9.647881
  -2.0484571  -5.8549366 ]
# get the score between enroll and test
Eembeddings Score: 0.45332613587379456
```
### 4.Pretrained Models
Output:
```bash
[2022-05-25 12:25:36,165] [ INFO] - vector http client start
[2022-05-25 12:25:36,165] [ INFO] - the input audio: 85236145389.wav
[2022-05-25 12:25:36,165] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector
[2022-05-25 12:25:36,166] [ INFO] - http://127.0.0.1:8790/paddlespeech/vector
[2022-05-25 12:25:36,324] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [-1.3251205682754517, 7.860682487487793, -4.620625972747803, 0.3000721037387848, 2.2648534774780273, ..., 2.0292205810546875, -4.070415019989014, -6.831437110900879]}}
[2022-05-25 12:25:36,324] [ INFO] - Response time 0.159053 s.
```
* Python API
Output:
``` bash
{'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [-1.3251205682754517, 7.860682487487793, -4.620625972747803, 0.3000721037387848, 2.2648534774780273, ..., 2.0292205810546875, -4.070415019989014, -6.831437110900879]}}
``` ```
#### 7.2 Get the score between two speaker audio embeddings #### 7.2 Get the score between two speaker audio embeddings
...@@ -331,12 +331,12 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -331,12 +331,12 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
Output: Output:
``` bash ``` bash
[2022-05-09 10:28:40,556] [ INFO] - vector score http client start [2022-05-25 12:33:24,527] [ INFO] - vector score http client start
[2022-05-09 10:28:40,556] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav [2022-05-25 12:33:24,527] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav
[2022-05-09 10:28:40,556] [ INFO] - endpoint: http://127.0.0.1:8090/paddlespeech/vector/score [2022-05-25 12:33:24,528] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector/score
[2022-05-09 10:28:40,731] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.4292638897895813}} [2022-05-25 12:33:24,695] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}}
[2022-05-09 10:28:40,731] [ INFO] - The vector: None [2022-05-25 12:33:24,696] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}}
[2022-05-09 10:28:40,731] [ INFO] - Response time 0.175514 s. [2022-05-25 12:33:24,696] [ INFO] - Response time 0.168271 s.
``` ```
* Python API * Python API
...@@ -358,10 +358,11 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -358,10 +358,11 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
Output: Output:
``` bash ``` bash
[2022-05-09 10:34:54,769] [ INFO] - vector score http client start [2022-05-25 12:30:14,143] [ INFO] - vector score http client start
[2022-05-09 10:34:54,771] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav [2022-05-25 12:30:14,143] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav
[2022-05-09 10:34:54,771] [ INFO] - endpoint: http://127.0.0.1:8090/paddlespeech/vector/score [2022-05-25 12:30:14,143] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector/score
[2022-05-09 10:34:55,026] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.4292638897895813}} [2022-05-25 12:30:14,363] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}}
{'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}}
``` ```
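The score above is a single similarity value between the enrolled utterance and the test utterance. If you want to sanity-check such a score locally from the two `vec` arrays returned by the vector endpoint, a minimal sketch is shown below; it assumes the score is the plain cosine similarity of the two embeddings, which is our assumption and not something the server output guarantees.
```python
import numpy as np

def cosine_score(enroll_vec, test_vec):
    """Cosine similarity between two speaker embeddings (lists of floats)."""
    enroll = np.asarray(enroll_vec, dtype=np.float32)
    test = np.asarray(test_vec, dtype=np.float32)
    return float(np.dot(enroll, test) / (np.linalg.norm(enroll) * np.linalg.norm(test)))

# Example: score = cosine_score(enroll_response['result']['vec'], test_response['result']['vec'])
```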
### 8. Punctuation prediction ### 8. Punctuation prediction
......
...@@ -277,12 +277,12 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -277,12 +277,12 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
Output: Output:
``` bash ``` bash
[2022-05-08 00:18:44,249] [ INFO] - vector http client start [2022-05-25 12:25:36,165] [ INFO] - vector http client start
[2022-05-08 00:18:44,250] [ INFO] - the input audio: 85236145389.wav [2022-05-25 12:25:36,165] [ INFO] - the input audio: 85236145389.wav
[2022-05-08 00:18:44,250] [ INFO] - endpoint: http://127.0.0.1:8090/paddlespeech/vector [2022-05-25 12:25:36,165] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector
[2022-05-08 00:18:44,250] [ INFO] - http://127.0.0.1:8590/paddlespeech/vector [2022-05-25 12:25:36,166] [ INFO] - http://127.0.0.1:8790/paddlespeech/vector
[2022-05-08 00:18:44,406] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [1.421751856803894, 5.626245498657227, -5.342077255249023, 1.1773887872695923, 3.3080549240112305, 1.7565933465957642, 5.167886257171631, 10.806358337402344, -3.8226819038391113, -5.614140033721924, 2.6238479614257812, -0.8072972893714905, 1.9635076522827148, -7.312870025634766, 0.011035939678549767, -9.723129272460938, 0.6619706153869629, -6.976806163787842, 10.213476181030273, 7.494769096374512, 2.9105682373046875, 3.8949244022369385, 3.799983501434326, 7.106168746948242, 16.90532875061035, -7.149388313293457, 8.733108520507812, 3.423006296157837, -4.831653594970703, -11.403363227844238, 11.232224464416504, 7.127461910247803, -4.282842636108398, 2.452359437942505, -5.130749702453613, -18.17766761779785, -2.6116831302642822, -11.000344276428223, -6.731433391571045, 1.6564682722091675, 0.7618281245231628, 1.125300407409668, -2.0838370323181152, 4.725743293762207, -8.782588005065918, -3.5398752689361572, 3.8142364025115967, 5.142068862915039, 2.1620609760284424, 4.09643030166626, -6.416214942932129, 12.747446060180664, 1.9429892301559448, -15.15294361114502, 6.417416095733643, 16.09701156616211, -9.716667175292969, -1.9920575618743896, -3.36494779586792, -1.8719440698623657, 11.567351341247559, 3.6978814601898193, 11.258262634277344, 7.442368507385254, 9.183408737182617, 4.528149127960205, -1.2417854070663452, 4.395912170410156, 6.6727728843688965, 5.88988733291626, 7.627128601074219, -0.6691966652870178, -11.889698028564453, -9.20886516571045, -7.42740535736084, -3.777663230895996, 6.917238712310791, -9.848755836486816, -2.0944676399230957, -5.1351165771484375, 0.4956451654434204, 9.317537307739258, -5.914181232452393, -1.809860348701477, -0.11738915741443634, -7.1692705154418945, -1.057827353477478, -5.721670627593994, -5.117385387420654, 16.13765525817871, -4.473617076873779, 7.6624321937561035, -0.55381840467453, 9.631585121154785, -6.470459461212158, -8.548508644104004, 4.371616840362549, -0.7970245480537415, 4.4789886474609375, -2.975860834121704, 3.2721822261810303, 2.838287830352783, 5.134591102600098, -9.19079875946045, -0.5657302737236023, -4.8745832443237305, 2.3165574073791504, -5.984319686889648, -2.1798853874206543, 0.3554139733314514, -0.3178512752056122, 9.493552207946777, 2.1144471168518066, 4.358094692230225, -12.089824676513672, 8.451693534851074, -7.925466537475586, 4.624246597290039, 4.428936958312988, 18.69200897216797, -2.6204581260681152, -5.14918851852417, -0.3582090139389038, 8.488558769226074, 4.98148775100708, -9.326835632324219, -2.2544219493865967, 6.641760349273682, 1.2119598388671875, 10.977124214172363, 16.555034637451172, 3.3238420486450195, 9.551861763000488, -1.6676981449127197, -0.7953944206237793, -8.605667114257812, -0.4735655188560486, 2.674196243286133, -5.359177112579346, -2.66738224029541, 0.6660683155059814, 15.44322681427002, 4.740593433380127, -3.472534418106079, 11.592567443847656, -2.0544962882995605, 1.736127495765686, -8.265326499938965, -9.30447769165039, 5.406829833984375, -1.518022894859314, -7.746612548828125, -6.089611053466797, 0.07112743705511093, -0.3490503430366516, -8.64989185333252, -9.998957633972168, -2.564845085144043, -0.5399947762489319, 2.6018123626708984, -0.3192799389362335, -1.8815255165100098, -2.0721492767333984, -3.410574436187744, -8.29980754852295, 1.483638048171997, -15.365986824035645, -8.288211822509766, 3.884779930114746, -3.4876468181610107, 7.362999439239502, 
0.4657334089279175, 3.1326050758361816, 12.438895225524902, -1.8337041139602661, 4.532927989959717, 2.7264339923858643, 10.14534854888916, -6.521963596343994, 2.897155523300171, -3.392582654953003, 5.079153060913086, 7.7597246170043945, 4.677570819854736, 5.845779895782471, 2.402411460876465, 7.7071051597595215, 3.9711380004882812, -6.39003849029541, 6.12687873840332, -3.776029348373413, -11.118121147155762]}} [2022-05-25 12:25:36,324] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [-1.3251205682754517, 7.860682487487793, -4.620625972747803, 0.3000721037387848, 2.2648534774780273, -1.1931440830230713, 3.064713716506958, 7.673594951629639, -6.004472732543945, -12.024259567260742, -1.9496068954467773, 3.126953601837158, 1.6188379526138306, -7.638310432434082, -1.2299772500991821, -12.33833122253418, 2.1373026371002197, -5.395712375640869, 9.717328071594238, 5.675230503082275, 3.7805123329162598, 3.0597171783447266, 3.429692029953003, 8.9760103225708, 13.174124717712402, -0.5313228368759155, 8.942471504211426, 4.465109825134277, -4.426247596740723, -9.726503372192383, 8.399328231811523, 7.223917484283447, -7.435853958129883, 2.9441683292388916, -4.343039512634277, -13.886964797973633, -1.6346734762191772, -10.902740478515625, -5.311244964599609, 3.800722122192383, 3.897603750228882, -2.123077392578125, -2.3521194458007812, 4.151031017303467, -7.404866695404053, 0.13911646604537964, 2.4626107215881348, 4.96645450592041, 0.9897574186325073, 5.483975410461426, -3.3574001789093018, 10.13400650024414, -0.6120170950889587, -10.403095245361328, 4.600754261016846, 16.009349822998047, -7.78369140625, -4.194530487060547, -6.93686056137085, 1.1789555549621582, 11.490800857543945, 4.23802375793457, 9.550930976867676, 8.375045776367188, 7.508914470672607, -0.6570729613304138, -0.3005157709121704, 2.8406054973602295, 3.0828027725219727, 0.7308170199394226, 6.1483540534973145, 0.1376611888408661, -13.424735069274902, -7.746140480041504, -2.322798252105713, -8.305252075195312, 2.98791241645813, -10.99522876739502, 0.15211068093776703, -2.3820347785949707, -1.7984174489974976, 8.49562931060791, -5.852236747741699, -3.755497932434082, 0.6989710927009583, -5.270299434661865, -2.6188621520996094, -1.8828465938568115, -4.6466498374938965, 14.078543663024902, -0.5495333075523376, 10.579157829284668, -3.216050148010254, 9.349003791809082, -4.381077766418457, -11.675816535949707, -2.863020658493042, 4.5721755027771, 2.246612071990967, -4.574341773986816, 1.8610187768936157, 2.3767874240875244, 5.625787734985352, -9.784077644348145, 0.6496725678443909, -1.457950472831726, 0.4263263940811157, -4.921126365661621, -2.4547839164733887, 3.4869801998138428, -0.4265422224998474, 8.341268539428711, 1.356552004814148, 7.096688270568848, -13.102828979492188, 8.01673412322998, -7.115934371948242, 1.8699780702590942, 0.20872099697589874, 14.699383735656738, -1.0252779722213745, -2.6107232570648193, -2.5082311630249023, 8.427192687988281, 6.913852691650391, -6.29124641418457, 0.6157366037368774, 2.489687919616699, -3.4668266773223877, 9.92176342010498, 11.200815200805664, -0.19664029777050018, 7.491600513458252, -0.6231271624565125, -0.2584814429283142, -9.947997093200684, -0.9611040949821472, 1.1649218797683716, -2.1907122135162354, -1.502848744392395, -0.5192610621452332, 15.165953636169434, 2.4649462699890137, -0.998044490814209, 7.44166374206543, -2.0768048763275146, 3.5896823406219482, -7.305543422698975, -7.562084674835205, 4.32333517074585, 0.08044180274009705, 
-6.564010143280029, -2.314805269241333, -1.7642345428466797, -2.470881700515747, -7.6756181716918945, -9.548877716064453, -1.017755389213562, 0.1698644608259201, 2.5877134799957275, -1.8752295970916748, -0.36614322662353516, -6.049378395080566, -2.3965611457824707, -5.945338726043701, 0.9424033164978027, -13.155974388122559, -7.45780086517334, 0.14658108353614807, -3.7427968978881836, 5.841492652893066, -1.2872905731201172, 5.569431304931641, 12.570590019226074, 1.0939218997955322, 2.2142086029052734, 1.9181575775146484, 6.991420745849609, -5.888138771057129, 3.1409823894500732, -2.0036280155181885, 2.4434285163879395, 9.973138809204102, 5.036680221557617, 2.005120277404785, 2.861560344696045, 5.860223770141602, 2.917618751525879, -1.63111412525177, 2.0292205810546875, -4.070415019989014, -6.831437110900879]}}
[2022-05-08 00:18:44,406] [ INFO] - Response time 0.156481 s. [2022-05-25 12:25:36,324] [ INFO] - Response time 0.159053 s.
``` ```
* Python API * Python API
...@@ -302,7 +302,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -302,7 +302,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
Output: Output:
``` bash ``` bash
{'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [1.421751856803894, 5.626245498657227, -5.342077255249023, 1.1773887872695923, 3.3080549240112305, 1.7565933465957642, 5.167886257171631, 10.806358337402344, -3.8226819038391113, -5.614140033721924, 2.6238479614257812, -0.8072972893714905, 1.9635076522827148, -7.312870025634766, 0.011035939678549767, -9.723129272460938, 0.6619706153869629, -6.976806163787842, 10.213476181030273, 7.494769096374512, 2.9105682373046875, 3.8949244022369385, 3.799983501434326, 7.106168746948242, 16.90532875061035, -7.149388313293457, 8.733108520507812, 3.423006296157837, -4.831653594970703, -11.403363227844238, 11.232224464416504, 7.127461910247803, -4.282842636108398, 2.452359437942505, -5.130749702453613, -18.17766761779785, -2.6116831302642822, -11.000344276428223, -6.731433391571045, 1.6564682722091675, 0.7618281245231628, 1.125300407409668, -2.0838370323181152, 4.725743293762207, -8.782588005065918, -3.5398752689361572, 3.8142364025115967, 5.142068862915039, 2.1620609760284424, 4.09643030166626, -6.416214942932129, 12.747446060180664, 1.9429892301559448, -15.15294361114502, 6.417416095733643, 16.09701156616211, -9.716667175292969, -1.9920575618743896, -3.36494779586792, -1.8719440698623657, 11.567351341247559, 3.6978814601898193, 11.258262634277344, 7.442368507385254, 9.183408737182617, 4.528149127960205, -1.2417854070663452, 4.395912170410156, 6.6727728843688965, 5.88988733291626, 7.627128601074219, -0.6691966652870178, -11.889698028564453, -9.20886516571045, -7.42740535736084, -3.777663230895996, 6.917238712310791, -9.848755836486816, -2.0944676399230957, -5.1351165771484375, 0.4956451654434204, 9.317537307739258, -5.914181232452393, -1.809860348701477, -0.11738915741443634, -7.1692705154418945, -1.057827353477478, -5.721670627593994, -5.117385387420654, 16.13765525817871, -4.473617076873779, 7.6624321937561035, -0.55381840467453, 9.631585121154785, -6.470459461212158, -8.548508644104004, 4.371616840362549, -0.7970245480537415, 4.4789886474609375, -2.975860834121704, 3.2721822261810303, 2.838287830352783, 5.134591102600098, -9.19079875946045, -0.5657302737236023, -4.8745832443237305, 2.3165574073791504, -5.984319686889648, -2.1798853874206543, 0.3554139733314514, -0.3178512752056122, 9.493552207946777, 2.1144471168518066, 4.358094692230225, -12.089824676513672, 8.451693534851074, -7.925466537475586, 4.624246597290039, 4.428936958312988, 18.69200897216797, -2.6204581260681152, -5.14918851852417, -0.3582090139389038, 8.488558769226074, 4.98148775100708, -9.326835632324219, -2.2544219493865967, 6.641760349273682, 1.2119598388671875, 10.977124214172363, 16.555034637451172, 3.3238420486450195, 9.551861763000488, -1.6676981449127197, -0.7953944206237793, -8.605667114257812, -0.4735655188560486, 2.674196243286133, -5.359177112579346, -2.66738224029541, 0.6660683155059814, 15.44322681427002, 4.740593433380127, -3.472534418106079, 11.592567443847656, -2.0544962882995605, 1.736127495765686, -8.265326499938965, -9.30447769165039, 5.406829833984375, -1.518022894859314, -7.746612548828125, -6.089611053466797, 0.07112743705511093, -0.3490503430366516, -8.64989185333252, -9.998957633972168, -2.564845085144043, -0.5399947762489319, 2.6018123626708984, -0.3192799389362335, -1.8815255165100098, -2.0721492767333984, -3.410574436187744, -8.29980754852295, 1.483638048171997, -15.365986824035645, -8.288211822509766, 3.884779930114746, -3.4876468181610107, 7.362999439239502, 0.4657334089279175, 3.1326050758361816, 12.438895225524902, 
-1.8337041139602661, 4.532927989959717, 2.7264339923858643, 10.14534854888916, -6.521963596343994, 2.897155523300171, -3.392582654953003, 5.079153060913086, 7.7597246170043945, 4.677570819854736, 5.845779895782471, 2.402411460876465, 7.7071051597595215, 3.9711380004882812, -6.39003849029541, 6.12687873840332, -3.776029348373413, -11.118121147155762]}} {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'vec': [-1.3251205682754517, 7.860682487487793, -4.620625972747803, 0.3000721037387848, 2.2648534774780273, -1.1931440830230713, 3.064713716506958, 7.673594951629639, -6.004472732543945, -12.024259567260742, -1.9496068954467773, 3.126953601837158, 1.6188379526138306, -7.638310432434082, -1.2299772500991821, -12.33833122253418, 2.1373026371002197, -5.395712375640869, 9.717328071594238, 5.675230503082275, 3.7805123329162598, 3.0597171783447266, 3.429692029953003, 8.9760103225708, 13.174124717712402, -0.5313228368759155, 8.942471504211426, 4.465109825134277, -4.426247596740723, -9.726503372192383, 8.399328231811523, 7.223917484283447, -7.435853958129883, 2.9441683292388916, -4.343039512634277, -13.886964797973633, -1.6346734762191772, -10.902740478515625, -5.311244964599609, 3.800722122192383, 3.897603750228882, -2.123077392578125, -2.3521194458007812, 4.151031017303467, -7.404866695404053, 0.13911646604537964, 2.4626107215881348, 4.96645450592041, 0.9897574186325073, 5.483975410461426, -3.3574001789093018, 10.13400650024414, -0.6120170950889587, -10.403095245361328, 4.600754261016846, 16.009349822998047, -7.78369140625, -4.194530487060547, -6.93686056137085, 1.1789555549621582, 11.490800857543945, 4.23802375793457, 9.550930976867676, 8.375045776367188, 7.508914470672607, -0.6570729613304138, -0.3005157709121704, 2.8406054973602295, 3.0828027725219727, 0.7308170199394226, 6.1483540534973145, 0.1376611888408661, -13.424735069274902, -7.746140480041504, -2.322798252105713, -8.305252075195312, 2.98791241645813, -10.99522876739502, 0.15211068093776703, -2.3820347785949707, -1.7984174489974976, 8.49562931060791, -5.852236747741699, -3.755497932434082, 0.6989710927009583, -5.270299434661865, -2.6188621520996094, -1.8828465938568115, -4.6466498374938965, 14.078543663024902, -0.5495333075523376, 10.579157829284668, -3.216050148010254, 9.349003791809082, -4.381077766418457, -11.675816535949707, -2.863020658493042, 4.5721755027771, 2.246612071990967, -4.574341773986816, 1.8610187768936157, 2.3767874240875244, 5.625787734985352, -9.784077644348145, 0.6496725678443909, -1.457950472831726, 0.4263263940811157, -4.921126365661621, -2.4547839164733887, 3.4869801998138428, -0.4265422224998474, 8.341268539428711, 1.356552004814148, 7.096688270568848, -13.102828979492188, 8.01673412322998, -7.115934371948242, 1.8699780702590942, 0.20872099697589874, 14.699383735656738, -1.0252779722213745, -2.6107232570648193, -2.5082311630249023, 8.427192687988281, 6.913852691650391, -6.29124641418457, 0.6157366037368774, 2.489687919616699, -3.4668266773223877, 9.92176342010498, 11.200815200805664, -0.19664029777050018, 7.491600513458252, -0.6231271624565125, -0.2584814429283142, -9.947997093200684, -0.9611040949821472, 1.1649218797683716, -2.1907122135162354, -1.502848744392395, -0.5192610621452332, 15.165953636169434, 2.4649462699890137, -0.998044490814209, 7.44166374206543, -2.0768048763275146, 3.5896823406219482, -7.305543422698975, -7.562084674835205, 4.32333517074585, 0.08044180274009705, -6.564010143280029, -2.314805269241333, -1.7642345428466797, -2.470881700515747, -7.6756181716918945, 
-9.548877716064453, -1.017755389213562, 0.1698644608259201, 2.5877134799957275, -1.8752295970916748, -0.36614322662353516, -6.049378395080566, -2.3965611457824707, -5.945338726043701, 0.9424033164978027, -13.155974388122559, -7.45780086517334, 0.14658108353614807, -3.7427968978881836, 5.841492652893066, -1.2872905731201172, 5.569431304931641, 12.570590019226074, 1.0939218997955322, 2.2142086029052734, 1.9181575775146484, 6.991420745849609, -5.888138771057129, 3.1409823894500732, -2.0036280155181885, 2.4434285163879395, 9.973138809204102, 5.036680221557617, 2.005120277404785, 2.861560344696045, 5.860223770141602, 2.917618751525879, -1.63111412525177, 2.0292205810546875, -4.070415019989014, -6.831437110900879]}}
``` ```
#### 7.2 Audio voiceprint scoring #### 7.2 Audio voiceprint scoring
...@@ -333,12 +333,12 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -333,12 +333,12 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
Output: Output:
``` bash ``` bash
[2022-05-09 10:28:40,556] [ INFO] - vector score http client start [2022-05-25 12:33:24,527] [ INFO] - vector score http client start
[2022-05-09 10:28:40,556] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav [2022-05-25 12:33:24,527] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav
[2022-05-09 10:28:40,556] [ INFO] - endpoint: http://127.0.0.1:8090/paddlespeech/vector/score [2022-05-25 12:33:24,528] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector/score
[2022-05-09 10:28:40,731] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.4292638897895813}} [2022-05-25 12:33:24,695] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}}
[2022-05-09 10:28:40,731] [ INFO] - The vector: None [2022-05-25 12:33:24,696] [ INFO] - The vector: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}}
[2022-05-09 10:28:40,731] [ INFO] - Response time 0.175514 s. [2022-05-25 12:33:24,696] [ INFO] - Response time 0.168271 s.
``` ```
* Python API * Python API
...@@ -360,10 +360,11 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee ...@@ -360,10 +360,11 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
Output: Output:
``` bash ``` bash
[2022-05-09 10:34:54,769] [ INFO] - vector score http client start [2022-05-25 12:30:14,143] [ INFO] - vector score http client start
[2022-05-09 10:34:54,771] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav [2022-05-25 12:30:14,143] [ INFO] - enroll audio: 85236145389.wav, test audio: 123456789.wav
[2022-05-09 10:34:54,771] [ INFO] - endpoint: http://127.0.0.1:8590/paddlespeech/vector/score [2022-05-25 12:30:14,143] [ INFO] - endpoint: http://127.0.0.1:8790/paddlespeech/vector/score
[2022-05-09 10:34:55,026] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.4292638897895813}} [2022-05-25 12:30:14,363] [ INFO] - The vector score is: {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}}
{'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'score': 0.45332613587379456}}
``` ```
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# SERVER SETTING # # SERVER SETTING #
################################################################################# #################################################################################
host: 0.0.0.0 host: 0.0.0.0
port: 8090 port: 8091
# The task format in the engin_list is: <speech task>_<engine type> # The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online'] # task choices = ['asr_online']
......
...@@ -13,9 +13,7 @@ ...@@ -13,9 +13,7 @@
# limitations under the License. # limitations under the License.
#!/usr/bin/python #!/usr/bin/python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}' # script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
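# RTF (real-time factor) = decoding time / audio duration; the one-liner above averages the value after "=" on every RTF line of log.txt and also prints the sum and the number of utterances.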
import argparse import argparse
import asyncio import asyncio
import codecs import codecs
......
...@@ -92,5 +92,3 @@ server 的 demo: [streaming_asr_server](https://github.com/PaddlePaddle/Paddle ...@@ -92,5 +92,3 @@ server 的 demo: [streaming_asr_server](https://github.com/PaddlePaddle/Paddle
## 4. Quick Start ## 4. Quick Start
To learn how to use PP-ASR, see the [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md) guide, which provides three installation methods: **Easy**, **Medium**, and **Hard**. If you just want to try the inference features of paddlespeech, the **Easy** installation is enough. To learn how to use PP-ASR, see the [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md) guide, which provides three installation methods: **Easy**, **Medium**, and **Hard**. If you just want to try the inference features of paddlespeech, the **Easy** installation is enough.
...@@ -4,7 +4,7 @@ There are 3 ways to use `PaddleSpeech`. According to the degree of difficulty, t ...@@ -4,7 +4,7 @@ There are 3 ways to use `PaddleSpeech`. According to the degree of difficulty, t
| Way | Function | Support| | Way | Function | Support|
|:---- |:----------------------------------------------------------- |:----| |:---- |:----------------------------------------------------------- |:----|
| Easy | (1) Use the command-line functions of PaddleSpeech. <br> (2) Experience PaddleSpeech on AI Studio. | Linux, Mac (M1 chip not supported), Windows | | Easy | (1) Use the command-line functions of PaddleSpeech. <br> (2) Experience PaddleSpeech on AI Studio. | Linux, Mac (M1 chip not supported), Windows (for more information about installation, see [#1195](https://github.com/PaddlePaddle/PaddleSpeech/discussions/1195)) |
| Medium | Supports major functions, such as using the ready-made examples and using PaddleSpeech to train your own model. | Linux | | Medium | Supports major functions, such as using the ready-made examples and using PaddleSpeech to train your own model. | Linux |
| Hard | Supports the full functionality of PaddleSpeech, including using the join CTC decoder with Kaldi, training an n-gram language model, Montreal-Forced-Aligner, and so on, and better prepares you to become a developer! | Ubuntu | | Hard | Supports the full functionality of PaddleSpeech, including using the join CTC decoder with Kaldi, training an n-gram language model, Montreal-Forced-Aligner, and so on, and better prepares you to become a developer! | Ubuntu |
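As an illustration of the Easy way in the table above, the command-line functions are also exposed as a Python API. A minimal sketch follows; it assumes the `paddlespeech.cli` executors are available under the import path used in the project's demos, so treat the exact path and argument names as assumptions rather than a guaranteed interface.
```python
from paddlespeech.cli.asr.infer import ASRExecutor

# Run Mandarin speech recognition on a local 16 kHz wav file.
asr = ASRExecutor()
text = asr(audio_file="./zh.wav")
print(text)
```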
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
`PaddleSpeech` can be installed in three ways. By difficulty, these are **Easy**, **Medium**, and **Hard**. `PaddleSpeech` can be installed in three ways. By difficulty, these are **Easy**, **Medium**, and **Hard**.
| Way | Function | Supported Systems | | Way | Function | Supported Systems |
| :--- | :----------------------------------------------------------- | :------------------ | | :--- | :----------------------------------------------------------- | :------------------ |
| Easy | (1) Use the command-line functions of PaddleSpeech. <br> (2) Experience PaddleSpeech on AI Studio. | Linux, Mac (M1 chip not supported), Windows | | Easy | (1) Use the command-line functions of PaddleSpeech. <br> (2) Experience PaddleSpeech on AI Studio. | Linux, Mac (M1 chip not supported), Windows (see [#1195](https://github.com/PaddlePaddle/PaddleSpeech/discussions/1195) for installation details) |
| Medium | Supports the major functions of PaddleSpeech, such as using the models in the existing examples and training your own model with PaddleSpeech. | Linux | | Medium | Supports the major functions of PaddleSpeech, such as using the models in the existing examples and training your own model with PaddleSpeech. | Linux |
| Hard | Supports all functions of PaddleSpeech, including decoding with the join CTC decoder together with Kaldi, training language models, forced alignment, and more, and better prepares you to become a developer! | Ubuntu | | Hard | Supports all functions of PaddleSpeech, including decoding with the join CTC decoder together with Kaldi, training language models, forced alignment, and more, and better prepares you to become a developer! | Ubuntu |
## Prerequisites ## Prerequisites
......
...@@ -82,7 +82,7 @@ PANN | ESC-50 |[pann-esc50](../../examples/esc50/cls0)|[esc50_cnn6.tar.gz](https ...@@ -82,7 +82,7 @@ PANN | ESC-50 |[pann-esc50](../../examples/esc50/cls0)|[esc50_cnn6.tar.gz](https
Model Type | Dataset| Example Link | Pretrained Models | Static Models Model Type | Dataset| Example Link | Pretrained Models | Static Models
:-------------:| :------------:| :-----: | :-----: | :-----: :-------------:| :------------:| :-----: | :-----: | :-----:
PANN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz) | - ECAPA-TDNN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz) | -
## Punctuation Restoration Models ## Punctuation Restoration Models
Model Type | Dataset| Example Link | Pretrained Models Model Type | Dataset| Example Link | Pretrained Models
......
...@@ -6,15 +6,8 @@ AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpu ...@@ -6,15 +6,8 @@ AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpu
We use AISHELL-3 to train a multi-speaker fastspeech2 model here. We use AISHELL-3 to train a multi-speaker fastspeech2 model here.
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download AISHELL-3. Download AISHELL-3 from its [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`.
```bash
wget https://www.openslr.org/resources/93/data_aishell3.tgz
```
Extract AISHELL-3.
```bash
mkdir data_aishell3
tar zxvf data_aishell3.tgz -C data_aishell3
```
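Since the explicit wget/tar commands were removed above, a minimal Python sketch of the "download and extract to `~/datasets`" step follows; the archive name `data_aishell3.tgz` and the flat archive layout are assumptions carried over from the removed commands.
```python
import os
import tarfile

# Assumes data_aishell3.tgz has already been downloaded from the official website.
archive = os.path.expanduser("~/data_aishell3.tgz")
target = os.path.expanduser("~/datasets/data_aishell3")

os.makedirs(target, exist_ok=True)
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall(path=target)  # mirrors `mkdir data_aishell3 && tar zxvf data_aishell3.tgz -C data_aishell3`
```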
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2. We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2.
You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo. You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo.
......
...@@ -6,15 +6,8 @@ This example contains code used to train a [Tacotron2](https://arxiv.org/abs/171 ...@@ -6,15 +6,8 @@ This example contains code used to train a [Tacotron2](https://arxiv.org/abs/171
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download AISHELL-3. Download AISHELL-3 from its [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`.
```bash
wget https://www.openslr.org/resources/93/data_aishell3.tgz
```
Extract AISHELL-3.
```bash
mkdir data_aishell3
tar zxvf data_aishell3.tgz -C data_aishell3
```
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for Tacotron2; the durations from MFA are not needed here. We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for Tacotron2; the durations from MFA are not needed here.
You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo. You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo.
......
...@@ -6,15 +6,8 @@ This example contains code used to train a [FastSpeech2](https://arxiv.org/abs/2 ...@@ -6,15 +6,8 @@ This example contains code used to train a [FastSpeech2](https://arxiv.org/abs/2
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download AISHELL-3. Download AISHELL-3 from its [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`.
```bash
wget https://www.openslr.org/resources/93/data_aishell3.tgz
```
Extract AISHELL-3.
```bash
mkdir data_aishell3
tar zxvf data_aishell3.tgz -C data_aishell3
```
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2. We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2.
You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo. You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo.
......
...@@ -4,15 +4,8 @@ This example contains code used to train a [parallel wavegan](http://arxiv.org/a ...@@ -4,15 +4,8 @@ This example contains code used to train a [parallel wavegan](http://arxiv.org/a
AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that could be used to train multi-speaker Text-to-Speech (TTS) systems. AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that could be used to train multi-speaker Text-to-Speech (TTS) systems.
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download AISHELL-3. Download AISHELL-3 from its [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`.
```bash
wget https://www.openslr.org/resources/93/data_aishell3.tgz
```
Extract AISHELL-3.
```bash
mkdir data_aishell3
tar zxvf data_aishell3.tgz -C data_aishell3
```
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2. We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2.
You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo. You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo.
......
...@@ -4,15 +4,7 @@ This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010. ...@@ -4,15 +4,7 @@ This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.
AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that could be used to train multi-speaker Text-to-Speech (TTS) systems. AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that could be used to train multi-speaker Text-to-Speech (TTS) systems.
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download AISHELL-3. Download AISHELL-3 from its [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`.
```bash
wget https://www.openslr.org/resources/93/data_aishell3.tgz
```
Extract AISHELL-3.
```bash
mkdir data_aishell3
tar zxvf data_aishell3.tgz -C data_aishell3
```
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2. We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2.
You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo. You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo.
......
...@@ -26,4 +26,7 @@ Use the following command to run diarization on AMI corpus. ...@@ -26,4 +26,7 @@ Use the following command to run diarization on AMI corpus.
./run.sh --data_folder ./amicorpus --manual_annot_folder ./ami_public_manual_1.6.2 ./run.sh --data_folder ./amicorpus --manual_annot_folder ./ami_public_manual_1.6.2
``` ```
## Results (DER) coming soon! :) ## Best performance in terms of Diarization Error Rate (DER).
| System | Mic. |Orcl. (Dev)|Orcl. (Eval)| Est. (Dev) |Est. (Eval)|
| --------|-------- | ---------|----------- | --------|-----------|
| ECAPA-TDNN + SC | HeadsetMix| 1.54 % | 3.07 %| 1.56 %| 3.28 % |
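For reference, DER aggregates three error types over the scored speech time. A minimal sketch of the metric follows (arguments in seconds; the function name is ours):
```python
def diarization_error_rate(false_alarm, missed, confusion, total_speech):
    """DER = (false alarm + missed speech + speaker confusion) / total scored speech time."""
    return (false_alarm + missed + confusion) / total_speech

# e.g. 1.2 s false alarm, 0.8 s missed, 1.1 s confusion over 100 s of speech -> 3.1 %
print(f"{diarization_error_rate(1.2, 0.8, 1.1, 100.0):.1%}")
```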
...@@ -3,7 +3,7 @@ This example contains code used to train a [Tacotron2](https://arxiv.org/abs/171 ...@@ -3,7 +3,7 @@ This example contains code used to train a [Tacotron2](https://arxiv.org/abs/171
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/source). Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`.
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for Tacotron2; the durations from MFA are not needed here. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for Tacotron2; the durations from MFA are not needed here.
......
...@@ -3,7 +3,7 @@ This example contains code used to train a [SpeedySpeech](http://arxiv.org/abs/2 ...@@ -3,7 +3,7 @@ This example contains code used to train a [SpeedySpeech](http://arxiv.org/abs/2
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/source). Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`.
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for SPEEDYSPEECH. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for SPEEDYSPEECH.
......
...@@ -4,7 +4,7 @@ This example contains code used to train a [Fastspeech2](https://arxiv.org/abs/2 ...@@ -4,7 +4,7 @@ This example contains code used to train a [Fastspeech2](https://arxiv.org/abs/2
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/source). Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`.
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2.
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download the dataset from the [official website](https://test.data-baker.com/data/index/source). Download the dataset from the [official website](https://test.data-baker.com/data/index/TNtts/).
### Get the MFA Results and Extract ### Get the MFA Results and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to obtain the phoneme durations for fastspeech2. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to obtain the phoneme durations for fastspeech2.
......
# This configuration tested on 4 GPUs (V100) with 32GB GPU
# memory. It takes around 2 weeks to finish the training
# but 100k iters model should generate reasonable results.
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 22050             # sampling rate (Hz)
n_fft: 1024           # FFT size (samples).
n_shift: 256          # Hop size (samples), ~11.6 ms at fs=22050.
win_length: null      # Window length (samples).
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
##########################################################
# TTS MODEL SETTING #
##########################################################
model:
# generator related
generator_type: vits_generator
generator_params:
hidden_channels: 192
spks: -1
global_channels: -1
segment_size: 32
text_encoder_attention_heads: 2
text_encoder_ffn_expand: 4
text_encoder_blocks: 6
text_encoder_positionwise_layer_type: "conv1d"
text_encoder_positionwise_conv_kernel_size: 3
text_encoder_positional_encoding_layer_type: "rel_pos"
text_encoder_self_attention_layer_type: "rel_selfattn"
text_encoder_activation_type: "swish"
text_encoder_normalize_before: True
text_encoder_dropout_rate: 0.1
text_encoder_positional_dropout_rate: 0.0
text_encoder_attention_dropout_rate: 0.1
use_macaron_style_in_text_encoder: True
use_conformer_conv_in_text_encoder: False
text_encoder_conformer_kernel_size: -1
decoder_kernel_size: 7
decoder_channels: 512
decoder_upsample_scales: [8, 8, 2, 2]
decoder_upsample_kernel_sizes: [16, 16, 4, 4]
decoder_resblock_kernel_sizes: [3, 7, 11]
decoder_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
use_weight_norm_in_decoder: True
posterior_encoder_kernel_size: 5
posterior_encoder_layers: 16
posterior_encoder_stacks: 1
posterior_encoder_base_dilation: 1
posterior_encoder_dropout_rate: 0.0
use_weight_norm_in_posterior_encoder: True
flow_flows: 4
flow_kernel_size: 5
flow_base_dilation: 1
flow_layers: 4
flow_dropout_rate: 0.0
use_weight_norm_in_flow: True
use_only_mean_in_flow: True
stochastic_duration_predictor_kernel_size: 3
stochastic_duration_predictor_dropout_rate: 0.5
stochastic_duration_predictor_flows: 4
stochastic_duration_predictor_dds_conv_layers: 3
# discriminator related
discriminator_type: hifigan_multi_scale_multi_period_discriminator
discriminator_params:
scales: 1
scale_downsample_pooling: "AvgPool1D"
scale_downsample_pooling_params:
kernel_size: 4
stride: 2
padding: 2
scale_discriminator_params:
in_channels: 1
out_channels: 1
kernel_sizes: [15, 41, 5, 3]
channels: 128
max_downsample_channels: 1024
max_groups: 16
bias: True
downsample_scales: [2, 2, 4, 4, 1]
nonlinear_activation: "leakyrelu"
nonlinear_activation_params:
negative_slope: 0.1
use_weight_norm: True
use_spectral_norm: False
follow_official_norm: False
periods: [2, 3, 5, 7, 11]
period_discriminator_params:
in_channels: 1
out_channels: 1
kernel_sizes: [5, 3]
channels: 32
downsample_scales: [3, 3, 3, 3, 1]
max_downsample_channels: 1024
bias: True
nonlinear_activation: "leakyrelu"
nonlinear_activation_params:
negative_slope: 0.1
use_weight_norm: True
use_spectral_norm: False
# others
sampling_rate: 22050 # needed in the inference for saving wav
cache_generator_outputs: True # whether to cache generator outputs in the training
###########################################################
# LOSS SETTING #
###########################################################
# loss function related
generator_adv_loss_params:
average_by_discriminators: False # whether to average loss value by #discriminators
loss_type: mse # loss type, "mse" or "hinge"
discriminator_adv_loss_params:
average_by_discriminators: False # whether to average loss value by #discriminators
loss_type: mse # loss type, "mse" or "hinge"
feat_match_loss_params:
average_by_discriminators: False # whether to average loss value by #discriminators
average_by_layers: False # whether to average loss value by #layers of each discriminator
include_final_outputs: True # whether to include final outputs for loss calculation
mel_loss_params:
fs: 22050 # must be the same as the training data
fft_size: 1024 # fft points
hop_size: 256 # hop size
win_length: null # window length
window: hann # window type
num_mels: 80 # number of Mel basis
fmin: 0 # minimum frequency for Mel basis
fmax: null # maximum frequency for Mel basis
log_base: null # null represents natural log
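# Note: with fmax set to null, the upper Mel band edge typically falls back to fs / 2 (11025 Hz here).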
###########################################################
# ADVERSARIAL LOSS SETTING #
###########################################################
lambda_adv: 1.0 # loss scaling coefficient for adversarial loss
lambda_mel: 45.0 # loss scaling coefficient for Mel loss
lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss
lambda_dur: 1.0 # loss scaling coefficient for duration loss
lambda_kl: 1.0 # loss scaling coefficient for KL divergence loss
# others
sampling_rate: 22050 # needed in the inference for saving wav
cache_generator_outputs: True # whether to cache generator outputs in the training
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 64 # Batch size.
num_workers: 4 # Number of workers in DataLoader.
##########################################################
# OPTIMIZER & SCHEDULER SETTING #
##########################################################
# optimizer setting for generator
generator_optimizer_params:
beta1: 0.8
beta2: 0.99
epsilon: 1.0e-9
weight_decay: 0.0
generator_scheduler: exponential_decay
generator_scheduler_params:
learning_rate: 2.0e-4
gamma: 0.999875
# optimizer setting for discriminator
discriminator_optimizer_params:
beta1: 0.8
beta2: 0.99
epsilon: 1.0e-9
weight_decay: 0.0
discriminator_scheduler: exponential_decay
discriminator_scheduler_params:
learning_rate: 2.0e-4
gamma: 0.999875
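# With exponential_decay, the learning rate after t scheduler steps is learning_rate * gamma**t
# (for both optimizers above); gamma=0.999875 roughly halves the rate every ~5,500 steps,
# assuming the scheduler is stepped once per training iteration.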
generator_first: False # whether to start updating generator first
##########################################################
# OTHER TRAINING SETTING #
##########################################################
max_epoch: 1000 # number of epochs
num_snapshots: 10 # max number of snapshots to keep while training
seed: 777 # random seed number
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
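# usage: ./local/preprocess.sh <config_path>  (MAIN_ROOT and BIN_DIR are expected from path.sh, which run.sh sources)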
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# get durations from MFA's result
echo "Generate durations.txt from MFA results ..."
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
--inputdir=./baker_alignment_tone \
--output=durations.txt \
--config=${config_path}
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/preprocess.py \
--dataset=baker \
--rootdir=~/datasets/BZNSYP/ \
--dumpdir=dump \
--dur-file=durations.txt \
--config=${config_path} \
--num-cpu=20 \
--cut-sil=True
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# get features' stats(mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="feats"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# normalize and convert phone/speaker to id; dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \
--skip-wav-copy
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \
--skip-wav-copy
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \
--skip-wav-copy
fi
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
--config=${config_path} \
--ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--phones_dict=dump/phone_id_map.txt \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test
fi
\ No newline at end of file
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize_e2e.py \
--config=${config_path} \
--ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--phones_dict=dump/phone_id_map.txt \
--output_dir=${train_output_path}/test_e2e \
--text=${BIN_DIR}/../sentences.txt
fi
#!/bin/bash
config_path=$1
train_output_path=$2
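# NOTE: --ngpu below should match the number of devices exposed via CUDA_VISIBLE_DEVICES (gpus= in run.sh).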
python3 ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=4 \
--phones-dict=dump/phone_id_map.txt
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=vits
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
\ No newline at end of file
#!/bin/bash
set -e
source path.sh
gpus=0,1
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz
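# checkpoint under ${train_output_path}/checkpoints/ loaded by the synthesis stages; change it to the snapshot you trained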
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this cannot be mixed with positional arguments `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
This example contains code used to train a [parallel wavegan](http://arxiv.org/abs/1910.11480) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html). This example contains code used to train a [parallel wavegan](http://arxiv.org/abs/1910.11480) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. Download CSMSC from its [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`.
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
This example contains code used to train a [Multi Band MelGAN](https://arxiv.org/abs/2005.05106) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html). This example contains code used to train a [Multi Band MelGAN](https://arxiv.org/abs/2005.05106) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. Download CSMSC from its [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`.
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
This example contains code used to train a [Style MelGAN](https://arxiv.org/abs/2011.01557) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html). This example contains code used to train a [Style MelGAN](https://arxiv.org/abs/2011.01557) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. Download CSMSC from its [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`.
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html). This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. Download CSMSC from its [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`.
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
This example contains code used to train a [WaveRNN](https://arxiv.org/abs/1802.08435) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html). This example contains code used to train a [WaveRNN](https://arxiv.org/abs/1802.08435) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).
## Dataset ## Dataset
### Download and Extract ### Download and Extract
Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. Download CSMSC from its [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`.
### Get MFA Result and Extract ### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio.
......
...@@ -3,7 +3,7 @@ This example contains code used to train a [Tacotron2](https://arxiv.org/abs/171 ...
## Dataset
### Download and Extract
Download LJSpeech-1.1 from its [official website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`.
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for Tacotron2; the durations from MFA are not needed here.
...
# TransformerTTS with LJSpeech
## Dataset
### Download and Extract
Download LJSpeech-1.1 from its [official website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`.
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
## Get Started
Assume the path to the dataset is `~/datasets/LJSpeech-1.1`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
...
...@@ -3,7 +3,7 @@ This example contains code used to train a [Fastspeech2](https://arxiv.org/abs/2 ...
## Dataset
### Download and Extract
Download LJSpeech-1.1 from its [official website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`.
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2.
...
# WaveFlow with LJSpeech
## Dataset
### Download and Extract
Download LJSpeech-1.1 from its [official website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`.
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
## Get Started
Assume the path to the dataset is `~/datasets/LJSpeech-1.1`.
Assume the path to the Tacotron2 generated mels is `../tts0/output/test`.
...
...@@ -2,7 +2,7 @@
This example contains code used to train a [parallel wavegan](http://arxiv.org/abs/1910.11480) model with [LJSpeech-1.1](https://keithito.com/LJ-Speech-Dataset/).
## Dataset
### Download and Extract
Download LJSpeech-1.1 from its [official website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`.
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio.
You can download it from [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.
...
...@@ -2,7 +2,7 @@
This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [LJSpeech-1.1](https://keithito.com/LJ-Speech-Dataset/).
## Dataset
### Download and Extract
Download LJSpeech-1.1 from its [official website](https://keithito.com/LJ-Speech-Dataset/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/LJSpeech-1.1`.
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio.
You can download it from [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.
...
...@@ -3,7 +3,7 @@ This example contains code used to train a [Fastspeech2](https://arxiv.org/abs/2 ...
## Dataset
### Download and Extract the dataset
Download VCTK-0.92 from its [official website](https://datashare.ed.ac.uk/handle/10283/3443) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/VCTK-Corpus-0.92`.
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2.
...
...@@ -3,7 +3,7 @@ This example contains code used to train a [parallel wavegan](http://arxiv.org/a ...
## Dataset
### Download and Extract
Download VCTK-0.92 from its [official website](https://datashare.ed.ac.uk/handle/10283/3443) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/VCTK-Corpus-0.92`.
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio.
...
...@@ -3,7 +3,7 @@ This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010. ...
## Dataset
### Download and Extract
Download VCTK-0.92 from its [official website](https://datashare.ed.ac.uk/handle/10283/3443) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/VCTK-Corpus-0.92`.
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio.
...
...@@ -141,11 +141,11 @@ using the `tar` scripts to unpack the model and then you can use the script to t ...
For example:
```
wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz
tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz
source path.sh
# If you have processed the data and get the manifest file, you can skip the following 2 steps
CUDA_VISIBLE_DEVICES= bash ./local/test.sh ./data sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1/model/ conf/ecapa_tdnn.yaml
```
The performance of the released models is shown in [RESULTS.md](./RESULTS.md).
...@@ -4,4 +4,4 @@
| Model | Number of Params | Release | Config | dim | Test set | Cosine | Cosine + S-Norm |
| --- | --- | --- | --- | --- | --- | --- | ---- |
| ECAPA-TDNN | 85M | 0.2.1 | conf/ecapa_tdnn.yaml | 192 | test | 0.8188 | 0.7815 |
...@@ -59,3 +59,11 @@ global_embedding_norm: True
embedding_mean_norm: True
embedding_std_norm: False
###########################################
# score-norm #
###########################################
score_norm: s-norm
cohort_size: 20000 # number of impostor utterances in the normalization cohort
n_train_snts: 400000 # used for normalization stats
...@@ -58,3 +58,10 @@ global_embedding_norm: True
embedding_mean_norm: True
embedding_std_norm: False
###########################################
# score-norm #
###########################################
score_norm: s-norm
cohort_size: 20000 # number of impostor utterances in the normalization cohort
n_train_snts: 400000 # used for normalization stats
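For context, `score_norm: s-norm` selects symmetric score normalization: the raw cosine score of an enrollment/test pair is Z-normalized against the enrollment-vs-cohort score statistics and T-normalized against the test-vs-cohort statistics, then the two parts are averaged. The cohort is drawn from the training utterances (`n_train_snts`) and capped at `cohort_size`. A minimal NumPy sketch of that idea (illustrative only; the function and variable names are not the repo's API):

```python
# Illustrative S-norm sketch, not the repo's implementation. Embeddings are
# assumed L2-normalized so a dot product is the cosine score; `cohort` plays
# the role of the impostor utterances selected from the training set.
import numpy as np

def s_norm(enroll, test, cohort):
    """enroll, test: (D,) embeddings; cohort: (N, D) impostor embeddings."""
    raw = float(np.dot(enroll, test))                        # raw cosine score
    e_scores = cohort @ enroll                               # enroll vs. cohort
    t_scores = cohort @ test                                 # test vs. cohort
    z = (raw - e_scores.mean()) / (e_scores.std() + 1e-12)   # Z-norm part
    t = (raw - t_scores.mean()) / (t_scores.std() + 1e-12)   # T-norm part
    return 0.5 * (z + t)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, n = 192, 1000                                         # 192-d like the ECAPA embedding
    norm = lambda x: x / np.linalg.norm(x, axis=-1, keepdims=True)
    enroll, test = norm(rng.normal(size=d)), norm(rng.normal(size=d))
    cohort = norm(rng.normal(size=(n, d)))
    print(s_norm(enroll, test, cohort))
```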
...@@ -181,7 +181,7 @@ class ASRExecutor(BaseExecutor):
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type:
self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
self.text_feature = TextFeaturizer(
...@@ -205,7 +205,7 @@ class ASRExecutor(BaseExecutor):
self.model.set_state_dict(model_dict)
# compute the max len limit
if "conformer" in model_type or "transformer" in model_type:
# in transformer like model, we may use the subsample rate cnn network
subsample_rate = self.model.subsampling_rate()
frame_shift_ms = self.config.preprocess_config.process[0][
...@@ -242,7 +242,7 @@ class ASRExecutor(BaseExecutor):
self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}")
elif "conformer" in model_type or "transformer" in model_type:
logger.info("get the preprocess conf")
preprocess_conf = self.config.preprocess_config
preprocess_args = {"train": False}
...
...@@ -23,6 +23,7 @@ import paddle
import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.utils.dynamic_import import dynamic_import
from ..executor import BaseExecutor
from ..log import logger
...@@ -30,7 +31,7 @@ from ..utils import cli_register
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
__all__ = ['CLSExecutor']
...
...@@ -86,7 +86,7 @@ def get_path_from_url(url,
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.distributed import ParallelEnv
assert _is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
...
...@@ -36,8 +36,8 @@ from .pretrained_models import kaldi_bins
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ["STExecutor"]
...
...@@ -21,7 +21,6 @@ from typing import Union
import paddle
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
...@@ -29,6 +28,7 @@ from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from .pretrained_models import tokenizer_alias
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TextExecutor']
...
...@@ -32,10 +32,10 @@ from ..utils import cli_register
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TTSExecutor']
...
...@@ -24,11 +24,11 @@ from typing import Any
from typing import Dict
import paddle
import requests
import yaml
from paddle.framework import load
import paddleaudio
from . import download
from .entry import commands
try:
...
...@@ -32,7 +32,7 @@ from ..utils import cli_register
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.utils.dynamic_import import dynamic_import
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
...
...@@ -19,9 +19,9 @@ pretrained_models = {
# "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav"
"ecapatdnn_voxceleb12-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz',
'md5':
'67c7ff8885d5246bd16e0f5ac1cba99f',
'cfg_path':
'conf/model.yaml', # the yaml config path
'ckpt_path':
...
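The `md5` value above is what the downloader verifies the fetched archive against before unpacking it. A minimal sketch of that kind of check (hypothetical helper name, not the CLI's code path):

```python
import hashlib

def md5_of(path: str, chunk_size: int = 4096) -> str:
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# e.g. compare against the 'md5' declared for ecapatdnn_voxceleb12-16k above
expected = "67c7ff8885d5246bd16e0f5ac1cba99f"
# assert md5_of("sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz") == expected
```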
...@@ -22,7 +22,7 @@ from paddleaudio.features import LogMelSpectrogram
from paddleaudio.utils import logger
from paddlespeech.cls.models import SoundClassifier
from paddlespeech.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
...
...@@ -21,7 +21,7 @@ from paddleaudio.utils import logger
from paddleaudio.utils import Timer
from paddlespeech.cls.models import SoundClassifier
from paddlespeech.utils.dynamic_import import dynamic_import
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
...
...@@ -37,6 +37,12 @@ if __name__ == "__main__":
"--export_path", type=str, help="path of the jit model to save")
parser.add_argument(
"--model_type", type=str, default='offline', help="offline/online")
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
args = parser.parse_args()
print("model_type:{}".format(args.model_type))
print_arguments(args)
...
...@@ -37,6 +37,12 @@ if __name__ == "__main__":
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
args = parser.parse_args()
print_arguments(args, globals())
print("model_type:{}".format(args.model_type))
...
...@@ -40,6 +40,12 @@ if __name__ == "__main__":
"--export_path", type=str, help="path of the jit model to save")
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
parser.add_argument(
"--enable-auto-log", action="store_true", help="use auto log")
args = parser.parse_args()
...
...@@ -33,6 +33,12 @@ if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
args = parser.parse_args()
print("model_type:{}".format(args.model_type))
print_arguments(args, globals())
...
...@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
...
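The `zip(*[iter(seq)] * batch_size)` idiom above chops the shifted index list into fixed-size batches by reusing one iterator `batch_size` times; any tail shorter than a full batch is dropped. A toy standalone illustration (not repo code):

```python
# Toy illustration of the grouping idiom used in _batch_shuffle.
indices = list(range(10))
batch_size = 3
shift_len = 1  # pretend rng.randint(0, batch_size - 1) returned 1

# One iterator object repeated batch_size times: zip pulls consecutive items
# into each tuple, so the flat list is chopped into batches of batch_size.
batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
print(batches)            # [(1, 2, 3), (4, 5, 6), (7, 8, 9)]

flattened = [i for batch in batches for i in batch]
print(flattened)          # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```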
...@@ -112,7 +112,16 @@ class Trainer():
logger.info(f"Rank: {self.rank}/{self.world_size}")
# set device
if self.args.ngpu == 0:
if self.args.nxpu == 0:
paddle.set_device('cpu')
else:
paddle.set_device('xpu')
elif self.args.ngpu > 0:
paddle.set_device("gpu")
else:
raise Exception("invalid device")
if self.parallel:
self.init_parallel()
...
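Taken together with the new `--nxpu` flag, the branch above picks the Paddle device roughly as follows; this is a standalone restatement for clarity, not a helper that exists in the repo:

```python
def pick_device(ngpu: int, nxpu: int) -> str:
    """Mirror of the Trainer branch: GPU wins if requested, else XPU, else CPU."""
    if ngpu > 0:
        return "gpu"
    if ngpu == 0:
        return "xpu" if nxpu > 0 else "cpu"
    raise ValueError("invalid device: ngpu must be >= 0")

assert pick_device(ngpu=1, nxpu=0) == "gpu"
assert pick_device(ngpu=0, nxpu=1) == "xpu"
assert pick_device(ngpu=0, nxpu=0) == "cpu"
```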
...@@ -752,6 +752,7 @@ class VectorClientExecutor(BaseExecutor):
res = handler.run(enroll_audio, test_audio, audio_format,
sample_rate)
logger.info(f"The vector score is: {res}")
return res
else:
logger.error(f"Sorry, we have not support such task {task}")
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import os
import os.path as osp
import shutil
import subprocess
import tarfile
import time
import zipfile
import requests
from tqdm import tqdm
from paddlespeech.cli.log import logger
__all__ = ['get_path_from_url']
DOWNLOAD_RETRY_LIMIT = 3
def _is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def _get_unique_endpoints(trainer_endpoints):
# Sorting is to avoid different environmental variables for each card
trainer_endpoints.sort()
ips = set()
unique_endpoints = set()
for endpoint in trainer_endpoints:
ip = endpoint.split(":")[0]
if ip in ips:
continue
ips.add(ip)
unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints
def get_path_from_url(url,
root_dir,
md5sum=None,
check_exist=True,
decompress=True,
method='get'):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.fluid.dygraph.parallel import ParallelEnv
assert _is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
# Mainly used to solve the problem of downloading data from different
# machines in the case of multiple machines. Different ips will download
# data, and the same ip will only download data once.
unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum, method=method)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().current_endpoint in unique_endpoints:
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
def _get_download(url, fullname):
# using requests.get method
fname = osp.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _wget_download(url, fullname):
# using wget to download url
tmp_fullname = fullname + "_tmp"
# --user-agent
command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT,
url)
subprc = subprocess.Popen(
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
_ = subprc.communicate()
if subprc.returncode != 0:
raise RuntimeError(
'{} failed. Please make sure `wget` is installed or {} exists'.
format(command, url))
shutil.move(tmp_fullname, fullname)
return fullname
_download_methods = {
'get': _get_download,
'wget': _wget_download,
}
def _download(url, path, md5sum=None, method='get'):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
"""
assert method in _download_methods, 'make sure `{}` implemented'.format(
method)
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _download_methods[method](url, fullname):
time.sleep(1)
continue
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
def _uncompress_file_zip(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
files = tarfile.open(filepath, mode)
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < 0:  # no separator means a single file
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True
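For illustration, the helper above could be exercised like this; the URL and MD5 come from the `ecapatdnn_voxceleb12-16k` entry shown earlier in this commit, and the local directory is a placeholder:

```python
# Hypothetical usage sketch for the download helper defined above.
if __name__ == "__main__":
    ckpt_dir = get_path_from_url(
        url="https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz",
        root_dir="./pretrained",
        md5sum="67c7ff8885d5246bd16e0f5ac1cba99f",  # from the CLI's pretrained_models entry
        decompress=True,
        method="get")
    print("model unpacked to:", ckpt_dir)
```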
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -16,6 +16,7 @@ import json
import os
import re
import numpy as np
import paddle
import soundfile
import websocket
...@@ -44,11 +45,10 @@ class ACSEngine(BaseEngine):
logger.info("Init the acs engine")
try:
self.config = config
self.device = self.config.get("device", paddle.get_device())
# websocket default ping timeout is 20 seconds
self.ping_timeout = self.config.get("ping_timeout", 20)
paddle.set_device(self.device)
logger.info(f"ACS Engine set the device: {self.device}")
...@@ -100,8 +100,8 @@ class ACSEngine(BaseEngine):
logger.error("No asr server, please input valid ip and port")
return ""
ws = websocket.WebSocket()
logger.info(f"set the ping timeout: {self.ping_timeout} seconds")
ws.connect(self.url, ping_timeout=self.ping_timeout)
audio_info = json.dumps(
{
"name": "test.wav",
...@@ -116,8 +116,8 @@ class ACSEngine(BaseEngine):
logger.info("client receive msg={}".format(msg))
# send the total audio data
for chunk_data in self.read_wave(audio_data):
ws.send_binary(chunk_data.tobytes())
msg = ws.recv()
msg = json.loads(msg)
logger.info(f"audio result: {msg}")
...@@ -142,6 +142,39 @@ class ACSEngine(BaseEngine):
return msg
def read_wave(self, audio_data: str):
"""read the audio file from specific wavfile path
Args:
audio_data (str): the audio data,
we assume that audio sample rate matches the model
Yields:
numpy.array: the small package of audio pcm data
"""
samples, sample_rate = soundfile.read(audio_data, dtype='int16')
x_len = len(samples)
assert sample_rate == 16000
chunk_size = int(85 * sample_rate / 1000) # 85ms, sample_rate = 16kHz
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0)
assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_size
end = start + chunk_size
x_chunk = padded_x[start:end]
yield x_chunk
def get_macthed_word(self, msg):
"""Get the matched info in msg
...
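`read_wave` above streams the wav in 85 ms chunks (1360 samples at 16 kHz) and zero-pads the tail so every chunk is full-sized. A standalone sketch of the same chunking on a synthetic signal (illustrative; it does not import the engine):

```python
import numpy as np

def chunk_pcm(samples: np.ndarray, sample_rate: int = 16000, chunk_ms: int = 85):
    """Yield fixed-size pcm chunks, zero-padding the last one (mirrors read_wave)."""
    chunk_size = int(chunk_ms * sample_rate / 1000)   # 85 ms -> 1360 samples at 16 kHz
    pad = (-len(samples)) % chunk_size
    padded = np.concatenate([samples, np.zeros(pad, dtype=samples.dtype)])
    for start in range(0, len(padded), chunk_size):
        yield padded[start:start + chunk_size]

# e.g. one second of audio -> ceil(16000 / 1360) = 12 chunks of 1360 samples
chunks = list(chunk_pcm(np.zeros(16000, dtype=np.int16)))
print(len(chunks), chunks[0].shape)                   # 12 (1360,)
```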
...@@ -38,7 +38,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
# ASR server connection process class
...@@ -53,7 +53,7 @@ class PaddleASRConnectionHanddler:
logger.info(
"create an paddle asr connection handler to process the websocket connection"
)
self.config = asr_engine.config # server config
self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine
...@@ -67,7 +67,7 @@ class PaddleASRConnectionHanddler:
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
if "deepspeech2" in self.model_type:
from paddlespeech.s2t.io.collator import SpeechCollator
self.am_predictor = self.asr_engine.executor.am_predictor
...@@ -89,8 +89,8 @@ class PaddleASRConnectionHanddler:
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# frame window and frame shift, in samples unit
self.win_length = int(self.model_config.window_ms / 1000 *
self.sample_rate)
self.n_shift = int(self.model_config.stride_ms / 1000 *
...@@ -109,16 +109,15 @@ class PaddleASRConnectionHanddler:
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
else:
raise ValueError(f"Not supported: {self.model_type}")
def extract_feat(self, samples):
# we compute the elapsed time of first char occuring
# and we record the start time at the first pcm sample arraving
if "deepspeech2online" in self.model_type:
# self.reamined_wav stores all the samples,
...@@ -154,28 +153,27 @@ class PaddleASRConnectionHanddler:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
feat = self.collate_fn_test.augmentation.transform_feature(spectrum)
# audio_len is frame num
frame_num = feat.shape[0]
feat = paddle.to_tensor(feat, dtype='float32')
feat = paddle.unsqueeze(feat, axis=0)
if self.cached_feat is None:
self.cached_feat = feat
else:
assert (len(feat.shape) == 3)
assert (len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat(
[self.cached_feat, feat], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
self.num_frames += frame_num
self.remained_wav = self.remained_wav[self.n_shift * frame_num:]
logger.info(
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
...@@ -183,25 +181,30 @@ class PaddleASRConnectionHanddler:
logger.info(
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
)
elif "conformer_online" in self.model_type: elif "conformer_online" in self.model_type:
logger.info("Online ASR extract the feat") logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16) samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1 assert samples.ndim == 1
logger.info(f"This package receive {samples.shape[0]} pcm data")
self.num_samples += samples.shape[0] self.num_samples += samples.shape[0]
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.reamined_wav stores all the samples, # self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples # include the original remained_wav and this package samples
if self.remained_wav is None: if self.remained_wav is None:
self.remained_wav = samples self.remained_wav = samples
else: else:
assert self.remained_wav.ndim == 1 assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples]) self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info( logger.info(
f"The connection remain the audio samples: {self.remained_wav.shape}" f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
) )
if len(self.remained_wav) < self.win_length: if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0 return 0
# fbank # fbank
...@@ -209,11 +212,13 @@ class PaddleASRConnectionHanddler:
**self.preprocess_args)
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
# feature cache
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
...@@ -221,56 +226,93 @@ class PaddleASRConnectionHanddler:
if self.device is None:
self.device = self.cached_feat.place
# cur frame step
num_frames = x_chunk.shape[1]
# global frame step
self.num_frames += num_frames
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
)
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
else:
raise ValueError(f"not supported: {self.model_type}")
def reset(self):
if "deepspeech2" in self.model_type:
# for deepspeech2
# init state
self.chunk_state_h_box = np.zeros(
(self.model_config.num_rnn_layers, 1, self.model_config.rnn_layer_size),
dtype=np.float32)
self.chunk_state_c_box = np.zeros(
(self.model_config.num_rnn_layers, 1, self.model_config.rnn_layer_size),
dtype=np.float32)
self.decoder.reset_decoder(batch_size=1)
self.device = None
## common
# global sample and frame step
self.num_samples = 0
self.num_frames = 0
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
# partial/ending decoding results
self.result_transcripts = ['']
## conformer
# cache for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
# conformer decoding state
self.chunk_num = 0 # global decoding chunk num
self.offset = 0 # global offset in decoding frame unit
self.global_frame_offset = 0
self.hyps = []
# token timestamp result
self.word_time_stamp = []
# one best timestamp viterbi prob is large.
self.time_stamp = []
def decode(self, is_finished=False):
"""advance decoding
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Raises:
Exception: when not support model.
Returns:
None: nothing
"""
if "deepspeech2online" in self.model_type: if "deepspeech2online" in self.model_type:
# x_chunk 是特征数据 decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit
decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model context = 7 # context=7, in audio frame unit
context = 7 # context=7 in deepspeech2 model subsampling = 4 # subsampling=4, in audio frame unit
subsampling = 4 # subsampling=4 in deepspeech2 model
stride = subsampling * decoding_chunk_size
cached_feature_num = context - subsampling cached_feature_num = context - subsampling
# decoding window for model # decoding window for model, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context decoding_window = (decoding_chunk_size - 1) * subsampling + context
# decoding stride for model, in audio frame unit
stride = subsampling * decoding_chunk_size
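# Worked example with the deepspeech2 values above: decoding_chunk_size=1, context=7, subsampling=4
# => decoding_window = (1 - 1) * 4 + 7 = 7 audio frames per forward step,
#    stride = 4 * 1 = 4 audio frames between steps,
#    cached_feature_num = 7 - 4 = 3 frames kept as left context for the next chunk.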
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
...@@ -280,6 +322,7 @@ class PaddleASRConnectionHanddler:
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
)
# the cached feat must be larger decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(
...@@ -293,6 +336,7 @@ class PaddleASRConnectionHanddler:
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
# num_frames - context + 1 ensure that current frame can get context window
if is_finished:
...@@ -302,6 +346,7 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# extract the audio
...@@ -311,7 +356,9 @@ class PaddleASRConnectionHanddler:
self.result_transcripts = [trans_best]
# update feat cache
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
# return trans_best[0]
elif "conformer" in self.model_type or "transformer" in self.model_type:
try:
...@@ -328,7 +375,16 @@ class PaddleASRConnectionHanddler:
@paddle.no_grad()
def decode_one_chunk(self, x_chunk, x_chunk_lens):
logger.info("start to decoce one chunk with deepspeech2 model") """forward one chunk frames
Args:
x_chunk (np.ndarray): (B,T,D), audio frames.
x_chunk_lens ([type]): (B,), audio frame lens
Returns:
logprob: posterior probability.
"""
logger.info("start to decoce one chunk for deepspeech2")
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
...@@ -365,24 +421,32 @@ class PaddleASRConnectionHanddler:
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
return trans_best[0]
@paddle.no_grad()
def advance_decoding(self, is_finished=False):
logger.info(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg = self.ctc_decode_config
# cur chunk size, in decoding frame unit
decoding_chunk_size = cfg.decoding_chunk_size
# using num of history chunks
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
subsampling = self.model.encoder.embed.subsampling_rate
context = self.model.encoder.embed.right_context + 1
# processed chunk feature cached for next chunk
cached_feature_num = context - subsampling
# decoding stride, in audio frame unit
stride = subsampling * decoding_chunk_size
# decoding window, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context
if self.cached_feat is None: if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data") logger.info("no audio feat, please input more pcm data")
return return
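The comments above distinguish two units: decoding frames (after encoder subsampling) and audio feature frames. A small sketch of the arithmetic with assumed values for a typical streaming conformer (subsampling rate 4, right context 6, chunk size 16); the real numbers come from the loaded model and config.

```python
# Chunk-window arithmetic, with assumed values for illustration only.
subsampling = 4                      # audio frames per decoding frame
right_context = 6
context = right_context + 1          # audio frames one decoding frame needs
decoding_chunk_size = 16             # chunk size, in decoding frame unit

cached_feature_num = context - subsampling                           # 3
stride = subsampling * decoding_chunk_size                           # 64
decoding_window = (decoding_chunk_size - 1) * subsampling + context  # 67

print(cached_feature_num, stride, decoding_window)  # 3 64 67
```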
...@@ -407,6 +471,7 @@ class PaddleASRConnectionHanddler: ...@@ -407,6 +471,7 @@ class PaddleASRConnectionHanddler:
return None, None return None, None
logger.info("start to do model forward") logger.info("start to do model forward")
# history of chunks, in decoding frame unit
required_cache_size = decoding_chunk_size * num_decoding_left_chunks required_cache_size = decoding_chunk_size * num_decoding_left_chunks
outputs = [] outputs = []
...@@ -423,8 +488,11 @@ class PaddleASRConnectionHanddler: ...@@ -423,8 +488,11 @@ class PaddleASRConnectionHanddler:
for cur in range(0, num_frames - left_frames + 1, stride): for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames) end = min(cur + decoding_window, num_frames)
# global chunk_num
self.chunk_num += 1 self.chunk_num += 1
# cur chunk
chunk_xs = self.cached_feat[:, cur:end, :] chunk_xs = self.cached_feat[:, cur:end, :]
# forward chunk
(y, self.subsampling_cache, self.elayers_output_cache, (y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk( self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size, chunk_xs, self.offset, required_cache_size,
...@@ -432,7 +500,7 @@ class PaddleASRConnectionHanddler: ...@@ -432,7 +500,7 @@ class PaddleASRConnectionHanddler:
self.conformer_cnn_cache) self.conformer_cnn_cache)
outputs.append(y) outputs.append(y)
# update the offset # update the global offset, in decoding frame unit
self.offset += y.shape[1] self.offset += y.shape[1]
ys = paddle.cat(outputs, 1) ys = paddle.cat(outputs, 1)
...@@ -445,12 +513,15 @@ class PaddleASRConnectionHanddler: ...@@ -445,12 +513,15 @@ class PaddleASRConnectionHanddler:
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0) ctc_probs = ctc_probs.squeeze(0)
# advance decoding
self.searcher.search(ctc_probs, self.cached_feat.place) self.searcher.search(ctc_probs, self.cached_feat.place)
# get one best hyps
self.hyps = self.searcher.get_one_best_hyps() self.hyps = self.searcher.get_one_best_hyps()
assert self.cached_feat.shape[0] == 1 assert self.cached_feat.shape[0] == 1
assert end >= cached_feature_num assert end >= cached_feature_num
# advance cache of feat
self.cached_feat = self.cached_feat[0, end - self.cached_feat = self.cached_feat[0, end -
cached_feature_num:, :].unsqueeze(0) cached_feature_num:, :].unsqueeze(0)
assert len( assert len(
...@@ -462,50 +533,81 @@ class PaddleASRConnectionHanddler: ...@@ -462,50 +533,81 @@ class PaddleASRConnectionHanddler:
) )
def update_result(self): def update_result(self):
"""Conformer/Transformer hyps to result.
"""
logger.info("update the final result") logger.info("update the final result")
hyps = self.hyps hyps = self.hyps
# output results and tokenids
self.result_transcripts = [ self.result_transcripts = [
self.text_feature.defeaturize(hyp) for hyp in hyps self.text_feature.defeaturize(hyp) for hyp in hyps
] ]
self.result_tokenids = [hyp for hyp in hyps] self.result_tokenids = [hyp for hyp in hyps]
def get_result(self): def get_result(self):
"""return partial/ending asr result.
Returns:
str: one best result of partial/ending.
"""
if len(self.result_transcripts) > 0: if len(self.result_transcripts) > 0:
return self.result_transcripts[0] return self.result_transcripts[0]
else: else:
return '' return ''
def get_word_time_stamp(self): def get_word_time_stamp(self):
"""return token timestamp result.
Returns:
list: list of {'w': token, 'bg': begin time, 'ed': end time} dicts
"""
return self.word_time_stamp return self.word_time_stamp
@paddle.no_grad() @paddle.no_grad()
def rescoring(self): def rescoring(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: """Second-Pass Decoding,
only for conformer and transformer models.
"""
if "deepspeech2" in self.model_type:
logger.info("deepspeech2 not support rescoring decoding.")
return return
logger.info("rescoring the final result")
if "attention_rescoring" != self.ctc_decode_config.decoding_method: if "attention_rescoring" != self.ctc_decode_config.decoding_method:
logger.info(
f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
)
return return
logger.info("rescoring the final result")
# last decoding for last audio
self.searcher.finalize_search() self.searcher.finalize_search()
# update beam search results
self.update_result() self.update_result()
beam_size = self.ctc_decode_config.beam_size beam_size = self.ctc_decode_config.beam_size
hyps = self.searcher.get_hyps() hyps = self.searcher.get_hyps()
if hyps is None or len(hyps) == 0: if hyps is None or len(hyps) == 0:
logger.info("No Hyps!")
return return
# rescore by decoder post probability
# assert len(hyps) == beam_size # assert len(hyps) == beam_size
# list of Tensor
hyp_list = [] hyp_list = []
for hyp in hyps: for hyp in hyps:
hyp_content = hyp[0] hyp_content = hyp[0]
# Prevent the hyp is empty # Prevent the hyp is empty
if len(hyp_content) == 0: if len(hyp_content) == 0:
hyp_content = (self.model.ctc.blank_id, ) hyp_content = (self.model.ctc.blank_id, )
hyp_content = paddle.to_tensor( hyp_content = paddle.to_tensor(
hyp_content, place=self.device, dtype=paddle.long) hyp_content, place=self.device, dtype=paddle.long)
hyp_list.append(hyp_content) hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
hyps_pad = pad_sequence(
hyp_list, batch_first=True, padding_value=self.model.ignore_id)
hyps_lens = paddle.to_tensor( hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=self.device, [len(hyp[0]) for hyp in hyps], place=self.device,
dtype=paddle.long) # (beam_size,) dtype=paddle.long) # (beam_size,)
...@@ -531,10 +633,12 @@ class PaddleASRConnectionHanddler: ...@@ -531,10 +633,12 @@ class PaddleASRConnectionHanddler:
score = 0.0 score = 0.0
for j, w in enumerate(hyp[0]): for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w] score += decoder_out[i][j][w]
# last decoder output token is `eos`, for the last decoder input token.
score += decoder_out[i][len(hyp[0])][self.model.eos] score += decoder_out[i][len(hyp[0])][self.model.eos]
# add ctc score (which in ln domain) # add ctc score (which in ln domain)
score += hyp[1] * self.ctc_decode_config.ctc_weight score += hyp[1] * self.ctc_decode_config.ctc_weight
if score > best_score: if score > best_score:
best_score = score best_score = score
best_index = i best_index = i
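Attention rescoring scores each CTC beam hypothesis by summing the decoder's log-probabilities over its tokens plus the final `<eos>` step, then adding the CTC prefix score scaled by `ctc_weight`, as in the loop above. A short sketch with made-up numbers:

```python
# Sketch of the attention-rescoring score; all values are made up.
import numpy as np

eos = 2
ctc_weight = 0.5
hyps = [((3, 4), -1.2), ((3, 5), -0.9)]              # (token_ids, ctc_log_score)
decoder_out = np.log(np.full((2, 3, 6), 1.0 / 6.0))  # (beam, max_len + 1, vocab)

best_score, best_index = -float("inf"), 0
for i, (tokens, ctc_score) in enumerate(hyps):
    score = sum(decoder_out[i][j][w] for j, w in enumerate(tokens))
    score += decoder_out[i][len(tokens)][eos]   # last decoder step emits <eos>
    score += ctc_weight * ctc_score             # CTC prefix score, ln domain
    if score > best_score:
        best_score, best_index = score, i
print(best_index)  # 1: decoder terms are equal here, the CTC score decides
```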
...@@ -542,43 +646,52 @@ class PaddleASRConnectionHanddler: ...@@ -542,43 +646,52 @@ class PaddleASRConnectionHanddler:
# update the one best result # update the one best result
# hyps stored the beam results and each fields is: # hyps stored the beam results and each fields is:
logger.info(f"best index: {best_index}") logger.info(f"best hyp index: {best_index}")
# logger.info(f'best result: {hyps[best_index]}') # logger.info(f'best result: {hyps[best_index]}')
# the field of the hyps is: # the field of the hyps is:
## asr results
# hyps[0][0]: the sentence word-id in the vocab with a tuple # hyps[0][0]: the sentence word-id in the vocab with a tuple
# hyps[0][1]: the sentence decoding probability with all paths # hyps[0][1]: the sentence decoding probability with all paths
## timestamp
# hyps[0][2]: viterbi_blank ending probability # hyps[0][2]: viterbi_blank ending probability
# hyps[0][3]: viterbi_non_blank probability # hyps[0][3]: viterbi_non_blank ending probability
# hyps[0][4]: current_token_prob, # hyps[0][4]: current_token_prob,
# hyps[0][5]: times_viterbi_blank, # hyps[0][5]: times_viterbi_blank ending timestamp,
# hyps[0][6]: times_titerbi_non_blank # hyps[0][6]: times_titerbi_non_blank ending timestamp.
self.hyps = [hyps[best_index][0]] self.hyps = [hyps[best_index][0]]
logger.info(f"best hyp ids: {self.hyps}")
# update the hyps time stamp # update the hyps time stamp
self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[ self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[
best_index][3] else hyps[best_index][6] best_index][3] else hyps[best_index][6]
logger.info(f"time stamp: {self.time_stamp}") logger.info(f"time stamp: {self.time_stamp}")
# update one best result
self.update_result() self.update_result()
# update each word start and end time stamp # update each word start and end time stamp
frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate # decoding frame to audio frame
logger.info(f"frame shift ms: {frame_shift_in_ms}") frame_shift = self.model.encoder.embed.subsampling_rate
frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate)
logger.info(f"frame shift sec: {frame_shift_in_sec}")
word_time_stamp = [] word_time_stamp = []
for idx, _ in enumerate(self.time_stamp): for idx, _ in enumerate(self.time_stamp):
start = (self.time_stamp[idx - 1] + self.time_stamp[idx] start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
) / 2.0 if idx > 0 else 0 ) / 2.0 if idx > 0 else 0
start = start * frame_shift_in_ms start = start * frame_shift_in_sec
end = (self.time_stamp[idx] + self.time_stamp[idx + 1] end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
end = end * frame_shift_in_ms
end = end * frame_shift_in_sec
word_time_stamp.append({ word_time_stamp.append({
"w": self.result_transcripts[0][idx], "w": self.result_transcripts[0][idx],
"bg": start, "bg": start,
"ed": end "ed": end
}) })
# logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}") # logger.info(f"{word_time_stamp[-1]}")
self.word_time_stamp = word_time_stamp self.word_time_stamp = word_time_stamp
logger.info(f"word time stamp: {self.word_time_stamp}") logger.info(f"word time stamp: {self.word_time_stamp}")
...@@ -610,6 +723,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -610,6 +723,7 @@ class ASRServerExecutor(ASRExecutor):
self.sample_rate = sample_rate self.sample_rate = sample_rate
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
if cfg_path is None or am_model is None or am_params is None: if cfg_path is None or am_model is None or am_params is None:
logger.info(f"Load the pretrained model, tag = {tag}") logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
...@@ -639,7 +753,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -639,7 +753,7 @@ class ASRServerExecutor(ASRExecutor):
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
with UpdateConfig(self.config): with UpdateConfig(self.config):
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: if "deepspeech2" in model_type:
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
self.vocab = self.config.vocab_filepath self.vocab = self.config.vocab_filepath
self.config.decode.lang_model_path = os.path.join( self.config.decode.lang_model_path = os.path.join(
...@@ -655,6 +769,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -655,6 +769,7 @@ class ASRServerExecutor(ASRExecutor):
self.download_lm( self.download_lm(
lm_url, lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5) os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
logger.info("start to create the stream conformer asr engine") logger.info("start to create the stream conformer asr engine")
if self.config.spm_model_prefix: if self.config.spm_model_prefix:
...@@ -682,7 +797,8 @@ class ASRServerExecutor(ASRExecutor): ...@@ -682,7 +797,8 @@ class ASRServerExecutor(ASRExecutor):
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
else: else:
raise Exception("wrong type") raise Exception("wrong type")
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
if "deepspeech2" in model_type:
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor") logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
...@@ -690,35 +806,6 @@ class ASRServerExecutor(ASRExecutor): ...@@ -690,35 +806,6 @@ class ASRServerExecutor(ASRExecutor):
model_file=self.am_model, model_file=self.am_model,
params_file=self.am_params, params_file=self.am_params,
predictor_conf=self.am_predictor_conf) predictor_conf=self.am_predictor_conf)
# decoder
logger.info("ASR engine start to create the ctc decoder instance")
self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2,
blank_id=self.config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder
logger.info("ASR engine start to init the ctc decoder")
cfg = self.config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# init state box
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset} '_')] # model_type: {model_name}_{dataset}
...@@ -733,281 +820,14 @@ class ASRServerExecutor(ASRExecutor): ...@@ -733,281 +820,14 @@ class ASRServerExecutor(ASRExecutor):
model_dict = paddle.load(self.am_model) model_dict = paddle.load(self.am_model)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
logger.info("create the transformer like model success") logger.info("create the transformer like model success")
# update the ctc decoding
self.searcher = CTCPrefixBeamSearch(self.config.decode)
self.transformer_decode_reset()
return True
def reset_decoder_and_chunk(self):
"""reset decoder and chunk state for an new audio
"""
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
self.decoder.reset_decoder(batch_size=1)
# init state box, for new audio request
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in self.model_type or "transformer" in self.model_type:
self.transformer_decode_reset()
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
"""decode one chunk
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
model_type (str): online model type
Returns:
str: one best result
"""
logger.info("start to decoce chunk by chunk")
if "deepspeech2online" in model_type:
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(
input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(self.chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
c_box_handle.reshape(self.chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
output_names[3])
self.am_predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result: {trans_best[0]}")
return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type:
try:
logger.info(
f"we will use the transformer like model : {self.model_type}"
)
self.advanced_decoding(x_chunk, x_chunk_lens)
self.update_result()
return self.result_transcripts[0]
except Exception as e:
logger.exception(e)
else: else:
raise Exception("invalid model name") raise ValueError(f"Not support: {model_type}")
def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
logger.info("start to decode with advanced_decoding method")
encoder_out, encoder_mask = self.encoder_forward(xs)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
self.searcher.search(ctc_probs, xs.place)
# update the one best result
self.hyps = self.searcher.get_one_best_hyps()
# now we support ctc_prefix_beam_search and attention_rescoring
if "attention_rescoring" in self.config.decode.decoding_method:
self.rescoring(encoder_out, xs.place)
def encoder_forward(self, xs):
logger.info("get the model out from the feat")
cfg = self.config.decode
decoding_chunk_size = cfg.decoding_chunk_size
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
subsampling = self.model.encoder.embed.subsampling_rate
context = self.model.encoder.embed.right_context + 1
stride = subsampling * decoding_chunk_size
# decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
logger.info("start to do model forward")
outputs = []
# num_frames - context + 1 ensure that current frame can get context window
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache)
outputs.append(y)
self.offset += y.shape[1]
ys = paddle.cat(outputs, 1)
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
return ys, masks
def rescoring(self, encoder_out, device):
logger.info("start to rescoring the hyps")
beam_size = self.config.decode.beam_size
hyps = self.searcher.get_hyps()
assert len(hyps) == beam_size
hyp_list = []
for hyp in hyps:
hyp_content = hyp[0]
# Prevent the hyp is empty
if len(hyp_content) == 0:
hyp_content = (self.model.ctc.blank_id, )
hyp_content = paddle.to_tensor(
hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
self.model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for i, hyp in enumerate(hyps):
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for the last decoder input token.
score += decoder_out[i][len(hyp[0])][self.model.eos]
# add ctc score (which in ln domain)
score += hyp[1] * self.config.decode.ctc_weight
if score > best_score:
best_score = score
best_index = i
# update the one best result
self.hyps = [hyps[best_index][0]]
return hyps[best_index][0]
def transformer_decode_reset(self):
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.offset = 0
# decoding reset
self.searcher.reset()
def update_result(self):
logger.info("update the final result")
hyps = self.hyps
self.result_transcripts = [
self.text_feature.defeaturize(hyp) for hyp in hyps
]
self.result_tokenids = [hyp for hyp in hyps]
def extract_feat(self, samples, sample_rate):
"""extract feat
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
if "deepspeech2online" in self.model_type:
# pcm16 -> pcm 32
samples = pcm2float(samples)
# read audio
speech_segment = SpeechSegment.from_pcm(
samples, sample_rate, transcript=" ")
# audio augment
self.collate_fn_test.augmentation.transform_audio(speech_segment)
# extract speech feature
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
speech_segment, self.collate_fn_test.keep_transcription_text)
# CMVN spectrum
if self.collate_fn_test._normalizer:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(
spectrum)
audio_len = audio.shape[0] return True
audio = paddle.to_tensor(audio, dtype='float32')
# audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
x_chunk = audio.numpy()
x_chunk_lens = np.array([audio_len])
return x_chunk, x_chunk_lens
elif "conformer_online" in self.model_type:
if sample_rate != self.sample_rate:
logger.info(f"audio sample rate {sample_rate} is not match,"
"the model sample_rate is {self.sample_rate}")
logger.info(f"ASR Engine use the {self.model_type} to process")
logger.info("Create the preprocess instance")
preprocess_conf = self.config.preprocess_config
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
logger.info("Read the audio file")
logger.info(f"audio shape: {samples.shape}")
# fbank
x_chunk = preprocessing(samples, **preprocess_args)
x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
logger.info(
f"process the audio feature success, feat shape: {x_chunk.shape}"
)
return x_chunk, x_chunk_lens
class ASREngine(BaseEngine): class ASREngine(BaseEngine):
"""ASR server engine """ASR server resource
Args: Args:
metaclass: Defaults to Singleton. metaclass: Defaults to Singleton.
...@@ -1015,7 +835,7 @@ class ASREngine(BaseEngine): ...@@ -1015,7 +835,7 @@ class ASREngine(BaseEngine):
def __init__(self): def __init__(self):
super(ASREngine, self).__init__() super(ASREngine, self).__init__()
logger.info("create the online asr engine instance") logger.info("create the online asr engine resource instance")
def init(self, config: dict) -> bool: def init(self, config: dict) -> bool:
"""init engine resource """init engine resource
...@@ -1026,16 +846,11 @@ class ASREngine(BaseEngine): ...@@ -1026,16 +846,11 @@ class ASREngine(BaseEngine):
Returns: Returns:
bool: init failed or success bool: init failed or success
""" """
self.input = None
self.output = ""
self.executor = ASRServerExecutor()
self.config = config self.config = config
self.executor = ASRServerExecutor()
try: try:
if self.config.get("device", None): self.device = self.config.get("device", paddle.get_device())
self.device = self.config.device
else:
self.device = paddle.get_device()
logger.info(f"paddlespeech_server set the device: {self.device}")
paddle.set_device(self.device) paddle.set_device(self.device)
except BaseException as e: except BaseException as e:
logger.error( logger.error(
...@@ -1045,6 +860,8 @@ class ASREngine(BaseEngine): ...@@ -1045,6 +860,8 @@ class ASREngine(BaseEngine):
"If all GPU or XPU is used, you can set the server to 'cpu'") "If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1) sys.exit(-1)
logger.info(f"paddlespeech_server set the device: {self.device}")
if not self.executor._init_from_path( if not self.executor._init_from_path(
model_type=self.config.model_type, model_type=self.config.model_type,
am_model=self.config.am_model, am_model=self.config.am_model,
...@@ -1062,42 +879,11 @@ class ASREngine(BaseEngine): ...@@ -1062,42 +879,11 @@ class ASREngine(BaseEngine):
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully.")
return True return True
def preprocess(self, def preprocess(self, *args, **kwargs):
samples, raise NotImplementedError("Online not using this.")
sample_rate,
model_type="deepspeech2online_aishell-zh-16k"):
"""preprocess
Args: def run(self, *args, **kwargs):
samples (numpy.array): numpy.float32 raise NotImplementedError("Online not using this.")
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
# if "deepspeech" in model_type:
x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
return x_chunk, x_chunk_lens
def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
"""run online engine
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
decoder_chunk_size(int)
"""
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
self.config.model_type)
def postprocess(self): def postprocess(self):
"""postprocess raise NotImplementedError("Online not using this.")
"""
return self.output
def reset(self):
"""reset engine decoder and inference state
"""
self.executor.reset_decoder_and_chunk()
self.output = ""
...@@ -25,7 +25,6 @@ from yacs.config import CfgNode ...@@ -25,7 +25,6 @@ from yacs.config import CfgNode
from .pretrained_models import pretrained_models from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import denorm from paddlespeech.server.utils.util import denorm
...@@ -33,6 +32,7 @@ from paddlespeech.server.utils.util import get_chunks ...@@ -33,6 +32,7 @@ from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TTSEngine'] __all__ = ['TTSEngine']
......
...@@ -17,12 +17,12 @@ from typing import List ...@@ -17,12 +17,12 @@ from typing import List
from fastapi import APIRouter from fastapi import APIRouter
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.restful.acs_api import router as acs_router
from paddlespeech.server.restful.asr_api import router as asr_router from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.text_api import router as text_router from paddlespeech.server.restful.text_api import router as text_router
from paddlespeech.server.restful.tts_api import router as tts_router from paddlespeech.server.restful.tts_api import router as tts_router
from paddlespeech.server.restful.vector_api import router as vec_router from paddlespeech.server.restful.vector_api import router as vec_router
from paddlespeech.server.restful.acs_api import router as acs_router
_router = APIRouter() _router = APIRouter()
......
...@@ -29,9 +29,9 @@ import requests ...@@ -29,9 +29,9 @@ import requests
import yaml import yaml
from paddle.framework import load from paddle.framework import load
from . import download
from .entry import client_commands from .entry import client_commands
from .entry import server_commands from .entry import server_commands
from paddlespeech.cli import download
try: try:
from .. import __version__ from .. import __version__
except ImportError: except ImportError:
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
class Frame(object): class Frame(object):
"""Represents a "frame" of audio data.""" """Represents a "frame" of audio data."""
...@@ -77,8 +78,8 @@ class ChunkBuffer(object): ...@@ -77,8 +78,8 @@ class ChunkBuffer(object):
offset = 0 offset = 0
while offset + self.window_bytes <= len(audio): while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], self.timestamp, yield Frame(audio[offset:offset + self.window_bytes],
self.window_sec) self.timestamp, self.window_sec)
self.timestamp += self.shift_sec self.timestamp += self.shift_sec
offset += self.shift_bytes offset += self.shift_bytes
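The loop above slices the byte buffer into fixed-size windows advanced by a shift, both expressed in bytes. A small sketch of the same framing with assumed 16 kHz, 16-bit mono PCM and a 25 ms window / 10 ms shift:

```python
# Framing a raw PCM byte buffer into overlapping windows (assumed settings).
sample_rate = 16000
bytes_per_sample = 2
window_sec, shift_sec = 0.025, 0.010
window_bytes = int(window_sec * sample_rate * bytes_per_sample)   # 800
shift_bytes = int(shift_sec * sample_rate * bytes_per_sample)     # 320

audio = bytes(4000)          # 125 ms of silence
offset, timestamp, frames = 0, 0.0, []
while offset + window_bytes <= len(audio):
    frames.append((timestamp, audio[offset:offset + window_bytes]))
    timestamp += shift_sec
    offset += shift_bytes
print(len(frames))           # 11 overlapping frames
```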
......
...@@ -293,3 +293,45 @@ def transformer_single_spk_batch_fn(examples): ...@@ -293,3 +293,45 @@ def transformer_single_spk_batch_fn(examples):
"speech_lengths": speech_lengths, "speech_lengths": speech_lengths,
} }
return batch return batch
def vits_single_spk_batch_fn(examples):
"""
Returns:
Dict[str, Any]:
- text (Tensor): Text index tensor (B, T_text).
- text_lengths (Tensor): Text length tensor (B,).
- feats (Tensor): Feature tensor (B, T_feats, aux_channels).
- feats_lengths (Tensor): Feature length tensor (B,).
- speech (Tensor): Speech waveform tensor (B, T_wav).
"""
# fields = ["text", "text_lengths", "feats", "feats_lengths", "speech"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
speech = [np.array(item["wave"], dtype=np.float32) for item in examples]
text_lengths = [
np.array(item["text_lengths"], dtype=np.int64) for item in examples
]
feats_lengths = [
np.array(item["feats_lengths"], dtype=np.int64) for item in examples
]
text = batch_sequences(text)
feats = batch_sequences(feats)
speech = batch_sequences(speech)
# convert each batch to paddle.Tensor
text = paddle.to_tensor(text)
feats = paddle.to_tensor(feats)
text_lengths = paddle.to_tensor(text_lengths)
feats_lengths = paddle.to_tensor(feats_lengths)
batch = {
"text": text,
"text_lengths": text_lengths,
"feats": feats,
"feats_lengths": feats_lengths,
"speech": speech
}
return batch
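A tiny usage sketch for the collate function above, assuming the module's `paddle`, `numpy`, and `batch_sequences` are in scope; the two examples and their lengths are made up, and the shapes follow the docstring.

```python
import numpy as np

examples = [
    {"text": [1, 2, 3], "text_lengths": 3,
     "feats": np.zeros((5, 80), dtype=np.float32), "feats_lengths": 5,
     "wave": np.zeros(1200, dtype=np.float32)},
    {"text": [4, 5], "text_lengths": 2,
     "feats": np.zeros((7, 80), dtype=np.float32), "feats_lengths": 7,
     "wave": np.zeros(1680, dtype=np.float32)},
]
batch = vits_single_spk_batch_fn(examples)
print(batch["text"].shape)    # 2 x 3, padded to the longest text
print(batch["feats"].shape)   # 2 x 7 x 80
print(batch["speech"].shape)  # 2 x 1680
```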
...@@ -167,7 +167,6 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32): ...@@ -167,7 +167,6 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
def batch_sequences(sequences, axis=0, pad_value=0): def batch_sequences(sequences, axis=0, pad_value=0):
# import pdb; pdb.set_trace()
seq = sequences[0] seq = sequences[0]
ndim = seq.ndim ndim = seq.ndim
if axis < 0: if axis < 0:
......
...@@ -20,15 +20,14 @@ from scipy.interpolate import interp1d ...@@ -20,15 +20,14 @@ from scipy.interpolate import interp1d
class LogMelFBank(): class LogMelFBank():
def __init__(self, def __init__(self,
sr=24000, sr: int=24000,
n_fft=2048, n_fft: int=2048,
hop_length=300, hop_length: int=300,
win_length=None, win_length: int=None,
window="hann", window: str="hann",
n_mels=80, n_mels: int=80,
fmin=80, fmin: int=80,
fmax=7600, fmax: int=7600):
eps=1e-10):
self.sr = sr self.sr = sr
# stft # stft
self.n_fft = n_fft self.n_fft = n_fft
...@@ -54,7 +53,7 @@ class LogMelFBank(): ...@@ -54,7 +53,7 @@ class LogMelFBank():
fmax=self.fmax) fmax=self.fmax)
return mel_filter return mel_filter
def _stft(self, wav): def _stft(self, wav: np.ndarray):
D = librosa.core.stft( D = librosa.core.stft(
wav, wav,
n_fft=self.n_fft, n_fft=self.n_fft,
...@@ -65,11 +64,11 @@ class LogMelFBank(): ...@@ -65,11 +64,11 @@ class LogMelFBank():
pad_mode=self.pad_mode) pad_mode=self.pad_mode)
return D return D
def _spectrogram(self, wav): def _spectrogram(self, wav: np.ndarray):
D = self._stft(wav) D = self._stft(wav)
return np.abs(D) return np.abs(D)
def _mel_spectrogram(self, wav): def _mel_spectrogram(self, wav: np.ndarray):
S = self._spectrogram(wav) S = self._spectrogram(wav)
mel = np.dot(self.mel_filter, S) mel = np.dot(self.mel_filter, S)
return mel return mel
...@@ -90,14 +89,18 @@ class LogMelFBank(): ...@@ -90,14 +89,18 @@ class LogMelFBank():
class Pitch(): class Pitch():
def __init__(self, sr=24000, hop_length=300, f0min=80, f0max=7600): def __init__(self,
sr: int=24000,
hop_length: int=300,
f0min: int=80,
f0max: int=7600):
self.sr = sr self.sr = sr
self.hop_length = hop_length self.hop_length = hop_length
self.f0min = f0min self.f0min = f0min
self.f0max = f0max self.f0max = f0max
def _convert_to_continuous_f0(self, f0: np.array) -> np.array: def _convert_to_continuous_f0(self, f0: np.ndarray) -> np.ndarray:
if (f0 == 0).all(): if (f0 == 0).all():
print("All frames seems to be unvoiced.") print("All frames seems to be unvoiced.")
return f0 return f0
...@@ -120,9 +123,9 @@ class Pitch(): ...@@ -120,9 +123,9 @@ class Pitch():
return f0 return f0
def _calculate_f0(self, def _calculate_f0(self,
input: np.array, input: np.ndarray,
use_continuous_f0=True, use_continuous_f0: bool=True,
use_log_f0=True) -> np.array: use_log_f0: bool=True) -> np.ndarray:
input = input.astype(np.float) input = input.astype(np.float)
frame_period = 1000 * self.hop_length / self.sr frame_period = 1000 * self.hop_length / self.sr
f0, timeaxis = pyworld.dio( f0, timeaxis = pyworld.dio(
...@@ -139,7 +142,8 @@ class Pitch(): ...@@ -139,7 +142,8 @@ class Pitch():
f0[nonzero_idxs] = np.log(f0[nonzero_idxs]) f0[nonzero_idxs] = np.log(f0[nonzero_idxs])
return f0.reshape(-1) return f0.reshape(-1)
def _average_by_duration(self, input: np.array, d: np.array) -> np.array: def _average_by_duration(self, input: np.ndarray,
d: np.ndarray) -> np.ndarray:
d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant') d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant')
arr_list = [] arr_list = []
for start, end in zip(d_cumsum[:-1], d_cumsum[1:]): for start, end in zip(d_cumsum[:-1], d_cumsum[1:]):
...@@ -154,11 +158,11 @@ class Pitch(): ...@@ -154,11 +158,11 @@ class Pitch():
return arr_list return arr_list
def get_pitch(self, def get_pitch(self,
wav, wav: np.ndarray,
use_continuous_f0=True, use_continuous_f0: bool=True,
use_log_f0=True, use_log_f0: bool=True,
use_token_averaged_f0=True, use_token_averaged_f0: bool=True,
duration=None): duration: np.ndarray=None):
f0 = self._calculate_f0(wav, use_continuous_f0, use_log_f0) f0 = self._calculate_f0(wav, use_continuous_f0, use_log_f0)
if use_token_averaged_f0 and duration is not None: if use_token_averaged_f0 and duration is not None:
f0 = self._average_by_duration(f0, duration) f0 = self._average_by_duration(f0, duration)
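`_average_by_duration` collapses the frame-level f0 track to one value per token by averaging the voiced (non-zero) frames inside each duration span. The visible hunk above is truncated, so the sketch below only mirrors that intent, with made-up durations and f0 values.

```python
# Averaging a frame-level f0 track over token durations (illustrative values).
import numpy as np

f0 = np.array([0.0, 5.0, 5.2, 0.0, 4.8, 4.6], dtype=np.float32)
durations = np.array([2, 4], dtype=np.int64)                  # frames per token

d_cumsum = np.pad(durations.cumsum(0), (1, 0), "constant")    # [0, 2, 6]
token_f0 = []
for start, end in zip(d_cumsum[:-1], d_cumsum[1:]):
    seg = f0[start:end]
    voiced = seg[seg > 0]                                     # ignore unvoiced frames
    token_f0.append(float(voiced.mean()) if len(voiced) else 0.0)
print(token_f0)   # one averaged value per token: [5.0, ≈4.87]
```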
...@@ -167,15 +171,13 @@ class Pitch(): ...@@ -167,15 +171,13 @@ class Pitch():
class Energy(): class Energy():
def __init__(self, def __init__(self,
sr=24000, n_fft: int=2048,
n_fft=2048, hop_length: int=300,
hop_length=300, win_length: int=None,
win_length=None, window: str="hann",
window="hann", center: bool=True,
center=True, pad_mode: str="reflect"):
pad_mode="reflect"):
self.sr = sr
self.n_fft = n_fft self.n_fft = n_fft
self.win_length = win_length self.win_length = win_length
self.hop_length = hop_length self.hop_length = hop_length
...@@ -183,7 +185,7 @@ class Energy(): ...@@ -183,7 +185,7 @@ class Energy():
self.center = center self.center = center
self.pad_mode = pad_mode self.pad_mode = pad_mode
def _stft(self, wav): def _stft(self, wav: np.ndarray):
D = librosa.core.stft( D = librosa.core.stft(
wav, wav,
n_fft=self.n_fft, n_fft=self.n_fft,
...@@ -194,7 +196,7 @@ class Energy(): ...@@ -194,7 +196,7 @@ class Energy():
pad_mode=self.pad_mode) pad_mode=self.pad_mode)
return D return D
def _calculate_energy(self, input): def _calculate_energy(self, input: np.ndarray):
input = input.astype(np.float32) input = input.astype(np.float32)
input_stft = self._stft(input) input_stft = self._stft(input)
input_power = np.abs(input_stft)**2 input_power = np.abs(input_stft)**2
...@@ -203,7 +205,8 @@ class Energy(): ...@@ -203,7 +205,8 @@ class Energy():
np.sum(input_power, axis=0), a_min=1.0e-10, a_max=float('inf'))) np.sum(input_power, axis=0), a_min=1.0e-10, a_max=float('inf')))
return energy return energy
def _average_by_duration(self, input: np.array, d: np.array) -> np.array: def _average_by_duration(self, input: np.ndarray,
d: np.ndarray) -> np.ndarray:
d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant') d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant')
arr_list = [] arr_list = []
for start, end in zip(d_cumsum[:-1], d_cumsum[1:]): for start, end in zip(d_cumsum[:-1], d_cumsum[1:]):
...@@ -214,8 +217,49 @@ class Energy(): ...@@ -214,8 +217,49 @@ class Energy():
arr_list = np.expand_dims(np.array(arr_list), 0).T arr_list = np.expand_dims(np.array(arr_list), 0).T
return arr_list return arr_list
def get_energy(self, wav, use_token_averaged_energy=True, duration=None): def get_energy(self,
wav: np.ndarray,
use_token_averaged_energy: bool=True,
duration: np.ndarray=None):
energy = self._calculate_energy(wav) energy = self._calculate_energy(wav)
if use_token_averaged_energy and duration is not None: if use_token_averaged_energy and duration is not None:
energy = self._average_by_duration(energy, duration) energy = self._average_by_duration(energy, duration)
return energy return energy
class LinearSpectrogram():
def __init__(
self,
n_fft: int=1024,
win_length: int=None,
hop_length: int=256,
window: str="hann",
center: bool=True, ):
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.center = center
self.n_fft = n_fft
self.pad_mode = "reflect"
def _stft(self, wav: np.ndarray):
D = librosa.core.stft(
wav,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode)
return D
def _spectrogram(self, wav: np.ndarray):
D = self._stft(wav)
return np.abs(D)
def get_linear_spectrogram(self, wav: np.ndarray):
linear_spectrogram = self._spectrogram(wav)
linear_spectrogram = np.clip(
linear_spectrogram, a_min=1e-10, a_max=float("inf"))
return linear_spectrogram.T
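A usage sketch for `LinearSpectrogram` above, assuming `librosa` and `numpy` are importable as in this module; it returns the clipped magnitude STFT transposed to `(num_frames, n_fft // 2 + 1)`.

```python
import numpy as np

sr = 24000
t = np.arange(sr) / sr
wav = 0.5 * np.sin(2 * np.pi * 220.0 * t).astype(np.float32)  # 1 s test tone

extractor = LinearSpectrogram(n_fft=1024, hop_length=256, window="hann")
spec = extractor.get_linear_spectrogram(wav)
print(spec.shape)   # (94, 513): (num_frames, n_fft // 2 + 1)
```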
...@@ -147,10 +147,17 @@ def process_sentences(config, ...@@ -147,10 +147,17 @@ def process_sentences(config,
spk_emb_dir: Path=None): spk_emb_dir: Path=None):
if nprocs == 1: if nprocs == 1:
results = [] results = []
for fp in fps: for fp in tqdm.tqdm(fps, total=len(fps)):
record = process_sentence(config, fp, sentences, output_dir, record = process_sentence(
mel_extractor, pitch_extractor, config=config,
energy_extractor, cut_sil, spk_emb_dir) fp=fp,
sentences=sentences,
output_dir=output_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
cut_sil=cut_sil,
spk_emb_dir=spk_emb_dir)
if record: if record:
results.append(record) results.append(record)
else: else:
...@@ -325,7 +332,6 @@ def main(): ...@@ -325,7 +332,6 @@ def main():
f0min=config.f0min, f0min=config.f0min,
f0max=config.f0max) f0max=config.f0max)
energy_extractor = Energy( energy_extractor = Energy(
sr=config.fs,
n_fft=config.n_fft, n_fft=config.n_fft,
hop_length=config.n_shift, hop_length=config.n_shift,
win_length=config.win_length, win_length=config.win_length,
...@@ -334,36 +340,36 @@ def main(): ...@@ -334,36 +340,36 @@ def main():
# process for the 3 sections # process for the 3 sections
if train_wav_files: if train_wav_files:
process_sentences( process_sentences(
config, config=config,
train_wav_files, fps=train_wav_files,
sentences, sentences=sentences,
train_dump_dir, output_dir=train_dump_dir,
mel_extractor, mel_extractor=mel_extractor,
pitch_extractor, pitch_extractor=pitch_extractor,
energy_extractor, energy_extractor=energy_extractor,
nprocs=args.num_cpu, nprocs=args.num_cpu,
cut_sil=args.cut_sil, cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir) spk_emb_dir=spk_emb_dir)
if dev_wav_files: if dev_wav_files:
process_sentences( process_sentences(
config, config=config,
dev_wav_files, fps=dev_wav_files,
sentences, sentences=sentences,
dev_dump_dir, output_dir=dev_dump_dir,
mel_extractor, mel_extractor=mel_extractor,
pitch_extractor, pitch_extractor=pitch_extractor,
energy_extractor, energy_extractor=energy_extractor,
cut_sil=args.cut_sil, cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir) spk_emb_dir=spk_emb_dir)
if test_wav_files: if test_wav_files:
process_sentences( process_sentences(
config, config=config,
test_wav_files, fps=test_wav_files,
sentences, sentences=sentences,
test_dump_dir, output_dir=test_dump_dir,
mel_extractor, mel_extractor=mel_extractor,
pitch_extractor, pitch_extractor=pitch_extractor,
energy_extractor, energy_extractor=energy_extractor,
nprocs=args.num_cpu, nprocs=args.num_cpu,
cut_sil=args.cut_sil, cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir) spk_emb_dir=spk_emb_dir)
......
...@@ -88,15 +88,17 @@ def process_sentence(config: Dict[str, Any], ...@@ -88,15 +88,17 @@ def process_sentence(config: Dict[str, Any],
y, (0, num_frames * config.n_shift - y.size), mode="reflect") y, (0, num_frames * config.n_shift - y.size), mode="reflect")
else: else:
y = y[:num_frames * config.n_shift] y = y[:num_frames * config.n_shift]
num_sample = y.shape[0] num_samples = y.shape[0]
mel_path = output_dir / (utt_id + "_feats.npy") mel_path = output_dir / (utt_id + "_feats.npy")
wav_path = output_dir / (utt_id + "_wave.npy") wav_path = output_dir / (utt_id + "_wave.npy")
np.save(wav_path, y) # (num_samples, ) # (num_samples, )
np.save(mel_path, logmel) # (num_frames, n_mels) np.save(wav_path, y)
# (num_frames, n_mels)
np.save(mel_path, logmel)
record = { record = {
"utt_id": utt_id, "utt_id": utt_id,
"num_samples": num_sample, "num_samples": num_samples,
"num_frames": num_frames, "num_frames": num_frames,
"feats": str(mel_path), "feats": str(mel_path),
"wave": str(wav_path), "wave": str(wav_path),
...@@ -111,11 +113,17 @@ def process_sentences(config, ...@@ -111,11 +113,17 @@ def process_sentences(config,
mel_extractor=None, mel_extractor=None,
nprocs: int=1, nprocs: int=1,
cut_sil: bool=True): cut_sil: bool=True):
if nprocs == 1: if nprocs == 1:
results = [] results = []
for fp in tqdm.tqdm(fps, total=len(fps)): for fp in tqdm.tqdm(fps, total=len(fps)):
record = process_sentence(config, fp, sentences, output_dir, record = process_sentence(
mel_extractor, cut_sil) config=config,
fp=fp,
sentences=sentences,
output_dir=output_dir,
mel_extractor=mel_extractor,
cut_sil=cut_sil)
if record: if record:
results.append(record) results.append(record)
else: else:
...@@ -150,7 +158,7 @@ def main(): ...@@ -150,7 +158,7 @@ def main():
"--dataset", "--dataset",
default="baker", default="baker",
type=str, type=str,
help="name of dataset, should in {baker, ljspeech, vctk} now") help="name of dataset, should in {baker, aishell3, ljspeech, vctk} now")
parser.add_argument( parser.add_argument(
"--rootdir", default=None, type=str, help="directory to dataset.") "--rootdir", default=None, type=str, help="directory to dataset.")
parser.add_argument( parser.add_argument(
...@@ -264,28 +272,28 @@ def main(): ...@@ -264,28 +272,28 @@ def main():
# process for the 3 sections # process for the 3 sections
if train_wav_files: if train_wav_files:
process_sentences( process_sentences(
config, config=config,
train_wav_files, fps=train_wav_files,
sentences, sentences=sentences,
train_dump_dir, output_dir=train_dump_dir,
mel_extractor=mel_extractor, mel_extractor=mel_extractor,
nprocs=args.num_cpu, nprocs=args.num_cpu,
cut_sil=args.cut_sil) cut_sil=args.cut_sil)
if dev_wav_files: if dev_wav_files:
process_sentences( process_sentences(
config, config=config,
dev_wav_files, fps=dev_wav_files,
sentences, sentences=sentences,
dev_dump_dir, output_dir=dev_dump_dir,
mel_extractor=mel_extractor, mel_extractor=mel_extractor,
nprocs=args.num_cpu, nprocs=args.num_cpu,
cut_sil=args.cut_sil) cut_sil=args.cut_sil)
if test_wav_files: if test_wav_files:
process_sentences( process_sentences(
config, config=config,
test_wav_files, fps=test_wav_files,
sentences, sentences=sentences,
test_dump_dir, output_dir=test_dump_dir,
mel_extractor=mel_extractor, mel_extractor=mel_extractor,
nprocs=args.num_cpu, nprocs=args.num_cpu,
cut_sil=args.cut_sil) cut_sil=args.cut_sil)
......
...@@ -126,11 +126,17 @@ def process_sentences(config, ...@@ -126,11 +126,17 @@ def process_sentences(config,
nprocs: int=1, nprocs: int=1,
cut_sil: bool=True, cut_sil: bool=True,
use_relative_path: bool=False): use_relative_path: bool=False):
if nprocs == 1: if nprocs == 1:
results = [] results = []
for fp in tqdm.tqdm(fps, total=len(fps)): for fp in tqdm.tqdm(fps, total=len(fps)):
record = process_sentence(config, fp, sentences, output_dir, record = process_sentence(
mel_extractor, cut_sil) config=config,
fp=fp,
sentences=sentences,
output_dir=output_dir,
mel_extractor=mel_extractor,
cut_sil=cut_sil)
if record: if record:
results.append(record) results.append(record)
else: else:
...@@ -268,30 +274,30 @@ def main(): ...@@ -268,30 +274,30 @@ def main():
# process for the 3 sections # process for the 3 sections
if train_wav_files: if train_wav_files:
process_sentences( process_sentences(
config, config=config,
train_wav_files, fps=train_wav_files,
sentences, sentences=sentences,
train_dump_dir, output_dir=train_dump_dir,
mel_extractor, mel_extractor=mel_extractor,
nprocs=args.num_cpu, nprocs=args.num_cpu,
cut_sil=args.cut_sil, cut_sil=args.cut_sil,
use_relative_path=args.use_relative_path) use_relative_path=args.use_relative_path)
if dev_wav_files: if dev_wav_files:
process_sentences( process_sentences(
config, config=config,
dev_wav_files, fps=dev_wav_files,
sentences, sentences=sentences,
dev_dump_dir, output_dir=dev_dump_dir,
mel_extractor, mel_extractor=mel_extractor,
cut_sil=args.cut_sil, cut_sil=args.cut_sil,
use_relative_path=args.use_relative_path) use_relative_path=args.use_relative_path)
if test_wav_files: if test_wav_files:
process_sentences( process_sentences(
config, config=config,
test_wav_files, fps=test_wav_files,
sentences, sentences=sentences,
test_dump_dir, output_dir=test_dump_dir,
mel_extractor, mel_extractor=mel_extractor,
nprocs=args.num_cpu, nprocs=args.num_cpu,
cut_sil=args.cut_sil, cut_sil=args.cut_sil,
use_relative_path=args.use_relative_path) use_relative_path=args.use_relative_path)
......
...@@ -176,7 +176,10 @@ def main(): ...@@ -176,7 +176,10 @@ def main():
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.") "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.")
parser.add_argument( parser.add_argument(
"--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.") "--nxpu",
type=int,
default=0,
help="if nxpu == 0 and ngpu == 0, use cpu.")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
......
...@@ -188,7 +188,10 @@ def main(): ...@@ -188,7 +188,10 @@ def main():
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument( parser.add_argument(
"--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.") "--nxpu",
type=int,
default=0,
help="if nxpu == 0 and ngpu == 0, use cpu.")
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu") "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu")
......
...@@ -27,11 +27,11 @@ from paddle import jit ...@@ -27,11 +27,11 @@ from paddle import jit
from paddle.static import InputSpec from paddle.static import InputSpec
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
model_alias = { model_alias = {
# acoustic model # acoustic model
......
...@@ -125,7 +125,7 @@ def evaluate(args): ...@@ -125,7 +125,7 @@ def evaluate(args):
def parse_args(): def parse_args():
# parse args and config and redirect to train_sp # parse args and config
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Synthesize with acoustic model & vocoder") description="Synthesize with acoustic model & vocoder")
# acoustic model # acoustic model
...@@ -143,7 +143,7 @@ def parse_args(): ...@@ -143,7 +143,7 @@ def parse_args():
'--am_config', '--am_config',
type=str, type=str,
default=None, default=None,
help='Config of acoustic model. Use deault config when it is None.') help='Config of acoustic model.')
parser.add_argument( parser.add_argument(
'--am_ckpt', '--am_ckpt',
type=str, type=str,
...@@ -182,7 +182,7 @@ def parse_args(): ...@@ -182,7 +182,7 @@ def parse_args():
'--voc_config', '--voc_config',
type=str, type=str,
default=None, default=None,
help='Config of voc. Use deault config when it is None.') help='Config of voc.')
parser.add_argument( parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument( parser.add_argument(
......
...@@ -159,7 +159,7 @@ def evaluate(args): ...@@ -159,7 +159,7 @@ def evaluate(args):
def parse_args(): def parse_args():
# parse args and config and redirect to train_sp # parse args and config
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Synthesize with acoustic model & vocoder") description="Synthesize with acoustic model & vocoder")
# acoustic model # acoustic model
...@@ -177,7 +177,7 @@ def parse_args(): ...@@ -177,7 +177,7 @@ def parse_args():
'--am_config', '--am_config',
type=str, type=str,
default=None, default=None,
help='Config of acoustic model. Use deault config when it is None.') help='Config of acoustic model.')
parser.add_argument( parser.add_argument(
'--am_ckpt', '--am_ckpt',
type=str, type=str,
...@@ -223,7 +223,7 @@ def parse_args(): ...@@ -223,7 +223,7 @@ def parse_args():
'--voc_config', '--voc_config',
type=str, type=str,
default=None, default=None,
help='Config of voc. Use deault config when it is None.') help='Config of voc.')
parser.add_argument( parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument( parser.add_argument(
......
...@@ -24,7 +24,6 @@ from paddle.static import InputSpec ...@@ -24,7 +24,6 @@ from paddle.static import InputSpec
from timer import timer from timer import timer
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.exps.syn_utils import denorm from paddlespeech.t2s.exps.syn_utils import denorm
from paddlespeech.t2s.exps.syn_utils import get_chunks from paddlespeech.t2s.exps.syn_utils import get_chunks
from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_frontend
...@@ -33,6 +32,7 @@ from paddlespeech.t2s.exps.syn_utils import get_voc_inference ...@@ -33,6 +32,7 @@ from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.exps.syn_utils import model_alias from paddlespeech.t2s.exps.syn_utils import model_alias
from paddlespeech.t2s.exps.syn_utils import voc_to_static from paddlespeech.t2s.exps.syn_utils import voc_to_static
from paddlespeech.t2s.utils import str2bool from paddlespeech.t2s.utils import str2bool
from paddlespeech.utils.dynamic_import import dynamic_import
def evaluate(args): def evaluate(args):
...@@ -201,7 +201,7 @@ def evaluate(args): ...@@ -201,7 +201,7 @@ def evaluate(args):
def parse_args(): def parse_args():
# parse args and config and redirect to train_sp # parse args and config
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Synthesize with acoustic model & vocoder") description="Synthesize with acoustic model & vocoder")
# acoustic model # acoustic model
...@@ -212,10 +212,7 @@ def parse_args(): ...@@ -212,10 +212,7 @@ def parse_args():
choices=['fastspeech2_csmsc'], choices=['fastspeech2_csmsc'],
help='Choose acoustic model type of tts task.') help='Choose acoustic model type of tts task.')
parser.add_argument( parser.add_argument(
'--am_config', '--am_config', type=str, default=None, help='Config of acoustic model.')
type=str,
default=None,
help='Config of acoustic model. Use deault config when it is None.')
parser.add_argument( parser.add_argument(
'--am_ckpt', '--am_ckpt',
type=str, type=str,
...@@ -245,10 +242,7 @@ def parse_args(): ...@@ -245,10 +242,7 @@ def parse_args():
], ],
help='Choose vocoder type of tts task.') help='Choose vocoder type of tts task.')
parser.add_argument( parser.add_argument(
'--voc_config', '--voc_config', type=str, default=None, help='Config of voc.')
type=str,
default=None,
help='Config of voc. Use deault config when it is None.')
parser.add_argument( parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument( parser.add_argument(
......
@@ -125,9 +125,15 @@ def process_sentences(config,
                       spk_emb_dir: Path=None):
     if nprocs == 1:
         results = []
-        for fp in fps:
-            record = process_sentence(config, fp, sentences, output_dir,
-                                       mel_extractor, cut_sil, spk_emb_dir)
+        for fp in tqdm.tqdm(fps, total=len(fps)):
+            record = process_sentence(
+                config=config,
+                fp=fp,
+                sentences=sentences,
+                output_dir=output_dir,
+                mel_extractor=mel_extractor,
+                cut_sil=cut_sil,
+                spk_emb_dir=spk_emb_dir)
             if record:
                 results.append(record)
     else:
@@ -299,30 +305,30 @@ def main():
     # process for the 3 sections
     if train_wav_files:
         process_sentences(
-            config,
-            train_wav_files,
-            sentences,
-            train_dump_dir,
-            mel_extractor,
+            config=config,
+            fps=train_wav_files,
+            sentences=sentences,
+            output_dir=train_dump_dir,
+            mel_extractor=mel_extractor,
             nprocs=args.num_cpu,
             cut_sil=args.cut_sil,
             spk_emb_dir=spk_emb_dir)
     if dev_wav_files:
         process_sentences(
-            config,
-            dev_wav_files,
-            sentences,
-            dev_dump_dir,
-            mel_extractor,
+            config=config,
+            fps=dev_wav_files,
+            sentences=sentences,
+            output_dir=dev_dump_dir,
+            mel_extractor=mel_extractor,
             cut_sil=args.cut_sil,
             spk_emb_dir=spk_emb_dir)
     if test_wav_files:
         process_sentences(
-            config,
-            test_wav_files,
-            sentences,
-            test_dump_dir,
-            mel_extractor,
+            config=config,
+            fps=test_wav_files,
+            sentences=sentences,
+            output_dir=test_dump_dir,
+            mel_extractor=mel_extractor,
             nprocs=args.num_cpu,
             cut_sil=args.cut_sil,
             spk_emb_dir=spk_emb_dir)
......
@@ -125,11 +125,16 @@ def process_sentences(config,
                       output_dir: Path,
                       mel_extractor=None,
                       nprocs: int=1):
     if nprocs == 1:
         results = []
         for fp in tqdm.tqdm(fps, total=len(fps)):
-            record = process_sentence(config, fp, sentences, output_dir,
-                                       mel_extractor)
+            record = process_sentence(
+                config=config,
+                fp=fp,
+                sentences=sentences,
+                output_dir=output_dir,
+                mel_extractor=mel_extractor)
             if record:
                 results.append(record)
     else:
@@ -247,27 +252,27 @@ def main():
     # process for the 3 sections
     if train_wav_files:
         process_sentences(
-            config,
-            train_wav_files,
-            sentences,
-            train_dump_dir,
-            mel_extractor,
+            config=config,
+            fps=train_wav_files,
+            sentences=sentences,
+            output_dir=train_dump_dir,
+            mel_extractor=mel_extractor,
             nprocs=args.num_cpu)
     if dev_wav_files:
         process_sentences(
-            config,
-            dev_wav_files,
-            sentences,
-            dev_dump_dir,
-            mel_extractor,
+            config=config,
+            fps=dev_wav_files,
+            sentences=sentences,
+            output_dir=dev_dump_dir,
+            mel_extractor=mel_extractor,
             nprocs=args.num_cpu)
     if test_wav_files:
         process_sentences(
-            config,
-            test_wav_files,
-            sentences,
-            test_dump_dir,
-            mel_extractor,
+            config=config,
+            fps=test_wav_files,
+            sentences=sentences,
+            output_dir=test_dump_dir,
+            mel_extractor=mel_extractor,
             nprocs=args.num_cpu)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize feature files and dump them."""
import argparse
import logging
from operator import itemgetter
from pathlib import Path
import jsonlines
import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from paddlespeech.t2s.datasets.data_table import DataTable
def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(
description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
)
parser.add_argument(
"--metadata",
type=str,
required=True,
help="directory including feature files to be normalized. "
"you need to specify either *-scp or rootdir.")
parser.add_argument(
"--dumpdir",
type=str,
required=True,
help="directory to dump normalized feature files.")
parser.add_argument(
"--feats-stats",
type=str,
required=True,
help="speech statistics file.")
parser.add_argument(
"--skip-wav-copy",
default=False,
action="store_true",
help="whether to skip the copy of wav files.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
args = parser.parse_args()
# set logger
if args.verbose > 1:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
elif args.verbose > 0:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
logging.warning('Skip DEBUG/INFO messages')
dumpdir = Path(args.dumpdir).expanduser()
# use absolute path
dumpdir = dumpdir.resolve()
dumpdir.mkdir(parents=True, exist_ok=True)
# get dataset
with jsonlines.open(args.metadata, 'r') as reader:
metadata = list(reader)
dataset = DataTable(
metadata,
converters={
"feats": np.load,
"wave": None if args.skip_wav_copy else np.load,
})
logging.info(f"The number of files = {len(dataset)}.")
# restore scaler
feats_scaler = StandardScaler()
feats_scaler.mean_ = np.load(args.feats_stats)[0]
feats_scaler.scale_ = np.load(args.feats_stats)[1]
feats_scaler.n_features_in_ = feats_scaler.mean_.shape[0]
vocab_phones = {}
with open(args.phones_dict, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
vocab_phones[phn] = int(id)
vocab_speaker = {}
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
for spk, id in spk_id:
vocab_speaker[spk] = int(id)
# process each file
output_metadata = []
for item in tqdm(dataset):
utt_id = item['utt_id']
feats = item['feats']
wave = item['wave']
# normalize
feats = feats_scaler.transform(feats)
feats_path = dumpdir / f"{utt_id}_feats.npy"
np.save(feats_path, feats.astype(np.float32), allow_pickle=False)
if not args.skip_wav_copy:
wav_path = dumpdir / f"{utt_id}_wave.npy"
np.save(wav_path, wave.astype(np.float32), allow_pickle=False)
else:
wav_path = wave
phone_ids = [vocab_phones[p] for p in item['phones']]
spk_id = vocab_speaker[item["speaker"]]
record = {
"utt_id": item['utt_id'],
"text": phone_ids,
"text_lengths": item['text_lengths'],
'feats': str(feats_path),
"feats_lengths": item['feats_lengths'],
"wave": str(wav_path),
"spk_id": spk_id,
}
# add spk_emb for voice cloning
if "spk_emb" in item:
record["spk_emb"] = str(item["spk_emb"])
output_metadata.append(record)
output_metadata.sort(key=itemgetter('utt_id'))
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
with jsonlines.open(output_metadata_path, 'w') as writer:
for item in output_metadata:
writer.write(item)
logging.info(f"metadata dumped into {output_metadata_path}")
if __name__ == "__main__":
main()
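As a side note, the script above restores a StandardScaler from a two-row stats file (mean in row 0, scale in row 1) and applies it per utterance. A minimal sketch of that step, with hypothetical file names:

import numpy as np
from sklearn.preprocessing import StandardScaler

# feats_stats.npy layout assumed by the script: row 0 = mean, row 1 = scale
stats = np.load("feats_stats.npy")            # hypothetical path
scaler = StandardScaler()
scaler.mean_ = stats[0]
scaler.scale_ = stats[1]
scaler.n_features_in_ = scaler.mean_.shape[0]

feats = np.load("000001_feats.npy")           # (num_frames, n_fft // 2 + 1)
normed = scaler.transform(feats)              # zero-mean, unit-variance per bin
np.save("000001_feats_norm.npy", normed.astype(np.float32), allow_pickle=False)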
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
import jsonlines
import librosa
import numpy as np
import tqdm
import yaml
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.get_feats import LinearSpectrogram
from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
from paddlespeech.t2s.utils import str2bool
def process_sentence(config: Dict[str, Any],
fp: Path,
sentences: Dict,
output_dir: Path,
spec_extractor=None,
cut_sil: bool=True,
spk_emb_dir: Path=None):
utt_id = fp.stem
# for vctk
if utt_id.endswith("_mic2"):
utt_id = utt_id[:-5]
record = None
if utt_id in sentences:
# reading, resampling may occur
wav, _ = librosa.load(str(fp), sr=config.fs)
if len(wav.shape) != 1:
return record
max_value = np.abs(wav).max()
if max_value > 1.0:
wav = wav / max_value
assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
        assert np.abs(wav).max(
        ) <= 1.0, f"{utt_id} seems not to be normalized 16-bit PCM."
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
speaker = sentences[utt_id][2]
d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
        # slightly less precise than using *.TextGrid directly
times = librosa.frames_to_time(
d_cumsum, sr=config.fs, hop_length=config.n_shift)
if cut_sil:
start = 0
end = d_cumsum[-1]
if phones[0] == "sil" and len(durations) > 1:
start = times[1]
durations = durations[1:]
phones = phones[1:]
if phones[-1] == 'sil' and len(durations) > 1:
end = times[-2]
durations = durations[:-1]
phones = phones[:-1]
sentences[utt_id][0] = phones
sentences[utt_id][1] = durations
start, end = librosa.time_to_samples([start, end], sr=config.fs)
wav = wav[start:end]
# extract mel feats
spec = spec_extractor.get_linear_spectrogram(wav)
# change duration according to mel_length
compare_duration_and_mel_length(sentences, utt_id, spec)
# utt_id may be popped in compare_duration_and_mel_length
if utt_id not in sentences:
return None
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
num_frames = spec.shape[0]
assert sum(durations) == num_frames
if wav.size < num_frames * config.n_shift:
wav = np.pad(
wav, (0, num_frames * config.n_shift - wav.size),
mode="reflect")
else:
wav = wav[:num_frames * config.n_shift]
num_samples = wav.shape[0]
spec_path = output_dir / (utt_id + "_feats.npy")
wav_path = output_dir / (utt_id + "_wave.npy")
# (num_samples, )
np.save(wav_path, wav)
# (num_frames, aux_channels)
np.save(spec_path, spec)
record = {
"utt_id": utt_id,
"phones": phones,
"text_lengths": len(phones),
"feats": str(spec_path),
"feats_lengths": num_frames,
"wave": str(wav_path),
"speaker": speaker
}
if spk_emb_dir:
if speaker in os.listdir(spk_emb_dir):
embed_name = utt_id + ".npy"
embed_path = spk_emb_dir / speaker / embed_name
if embed_path.is_file():
record["spk_emb"] = str(embed_path)
else:
return None
return record
def process_sentences(config,
fps: List[Path],
sentences: Dict,
output_dir: Path,
spec_extractor=None,
nprocs: int=1,
cut_sil: bool=True,
spk_emb_dir: Path=None):
if nprocs == 1:
results = []
for fp in tqdm.tqdm(fps, total=len(fps)):
record = process_sentence(
config=config,
fp=fp,
sentences=sentences,
output_dir=output_dir,
spec_extractor=spec_extractor,
cut_sil=cut_sil,
spk_emb_dir=spk_emb_dir)
if record:
results.append(record)
else:
with ThreadPoolExecutor(nprocs) as pool:
futures = []
with tqdm.tqdm(total=len(fps)) as progress:
for fp in fps:
future = pool.submit(process_sentence, config, fp,
sentences, output_dir, spec_extractor,
cut_sil, spk_emb_dir)
future.add_done_callback(lambda p: progress.update())
futures.append(future)
results = []
for ft in futures:
record = ft.result()
if record:
results.append(record)
results.sort(key=itemgetter("utt_id"))
with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
for item in results:
writer.write(item)
print("Done")
def main():
# parse config and args
parser = argparse.ArgumentParser(
description="Preprocess audio and then extract features.")
parser.add_argument(
"--dataset",
default="baker",
type=str,
help="name of dataset, should in {baker, aishell3, ljspeech, vctk} now")
parser.add_argument(
"--rootdir", default=None, type=str, help="directory to dataset.")
parser.add_argument(
"--dumpdir",
type=str,
required=True,
help="directory to dump feature files.")
parser.add_argument(
"--dur-file", default=None, type=str, help="path to durations.txt.")
parser.add_argument("--config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.")
parser.add_argument(
"--cut-sil",
type=str2bool,
default=True,
help="whether cut sil in the edge of audio")
parser.add_argument(
"--spk_emb_dir",
default=None,
type=str,
help="directory to speaker embedding files.")
args = parser.parse_args()
rootdir = Path(args.rootdir).expanduser()
dumpdir = Path(args.dumpdir).expanduser()
# use absolute path
dumpdir = dumpdir.resolve()
dumpdir.mkdir(parents=True, exist_ok=True)
dur_file = Path(args.dur_file).expanduser()
if args.spk_emb_dir:
spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
else:
spk_emb_dir = None
assert rootdir.is_dir()
assert dur_file.is_file()
with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
if args.verbose > 1:
print(vars(args))
print(config)
sentences, speaker_set = get_phn_dur(dur_file)
merge_silence(sentences)
phone_id_map_path = dumpdir / "phone_id_map.txt"
speaker_id_map_path = dumpdir / "speaker_id_map.txt"
get_input_token(sentences, phone_id_map_path, args.dataset)
get_spk_id_map(speaker_set, speaker_id_map_path)
if args.dataset == "baker":
wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
# split data into 3 sections
num_train = 9800
num_dev = 100
train_wav_files = wav_files[:num_train]
dev_wav_files = wav_files[num_train:num_train + num_dev]
test_wav_files = wav_files[num_train + num_dev:]
elif args.dataset == "aishell3":
sub_num_dev = 5
wav_dir = rootdir / "train" / "wav"
train_wav_files = []
dev_wav_files = []
test_wav_files = []
for speaker in os.listdir(wav_dir):
wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
if len(wav_files) > 100:
train_wav_files += wav_files[:-sub_num_dev * 2]
dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
test_wav_files += wav_files[-sub_num_dev:]
else:
train_wav_files += wav_files
elif args.dataset == "ljspeech":
wav_files = sorted(list((rootdir / "wavs").rglob("*.wav")))
# split data into 3 sections
num_train = 12900
num_dev = 100
train_wav_files = wav_files[:num_train]
dev_wav_files = wav_files[num_train:num_train + num_dev]
test_wav_files = wav_files[num_train + num_dev:]
elif args.dataset == "vctk":
sub_num_dev = 5
wav_dir = rootdir / "wav48_silence_trimmed"
train_wav_files = []
dev_wav_files = []
test_wav_files = []
for speaker in os.listdir(wav_dir):
wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac")))
if len(wav_files) > 100:
train_wav_files += wav_files[:-sub_num_dev * 2]
dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
test_wav_files += wav_files[-sub_num_dev:]
else:
train_wav_files += wav_files
else:
print("dataset should in {baker, aishell3, ljspeech, vctk} now!")
train_dump_dir = dumpdir / "train" / "raw"
train_dump_dir.mkdir(parents=True, exist_ok=True)
dev_dump_dir = dumpdir / "dev" / "raw"
dev_dump_dir.mkdir(parents=True, exist_ok=True)
test_dump_dir = dumpdir / "test" / "raw"
test_dump_dir.mkdir(parents=True, exist_ok=True)
# Extractor
spec_extractor = LinearSpectrogram(
n_fft=config.n_fft,
hop_length=config.n_shift,
win_length=config.win_length,
window=config.window)
# process for the 3 sections
if train_wav_files:
process_sentences(
config=config,
fps=train_wav_files,
sentences=sentences,
output_dir=train_dump_dir,
spec_extractor=spec_extractor,
nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir)
if dev_wav_files:
process_sentences(
config=config,
fps=dev_wav_files,
sentences=sentences,
output_dir=dev_dump_dir,
spec_extractor=spec_extractor,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir)
if test_wav_files:
process_sentences(
config=config,
fps=test_wav_files,
sentences=sentences,
output_dir=test_dump_dir,
spec_extractor=spec_extractor,
nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir)
if __name__ == "__main__":
main()
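For reference, a condensed sketch of the per-utterance feature extraction done in process_sentence() above, using the same LinearSpectrogram extractor; the file name and the config values below are placeholders rather than values taken from this commit.

import librosa
import numpy as np
from paddlespeech.t2s.datasets.get_feats import LinearSpectrogram

fs, n_fft, n_shift, win_length, window = 22050, 1024, 256, None, "hann"  # placeholder config
wav, _ = librosa.load("009901.wav", sr=fs)      # hypothetical utterance
wav = wav / max(np.abs(wav).max(), 1.0)         # rescale only if it exceeds [-1, 1]
extractor = LinearSpectrogram(
    n_fft=n_fft, hop_length=n_shift, win_length=win_length, window=window)
spec = extractor.get_linear_spectrogram(wav)    # (num_frames, n_fft // 2 + 1)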
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import jsonlines
import paddle
import soundfile as sf
import yaml
from timer import timer
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.vits import VITS
def evaluate(args):
# construct dataset for evaluation
with jsonlines.open(args.test_metadata, 'r') as reader:
test_metadata = list(reader)
# Init body.
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
fields = ["utt_id", "text"]
test_dataset = DataTable(data=test_metadata, fields=fields)
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1
vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
vits.eval()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
N = 0
T = 0
for datum in test_dataset:
utt_id = datum["utt_id"]
phone_ids = paddle.to_tensor(datum["text"])
with timer() as t:
with paddle.no_grad():
out = vits.inference(text=phone_ids)
wav = out["wav"]
wav = wav.numpy()
N += wav.size
T += t.elapse
speed = wav.size / t.elapse
rtf = config.fs / speed
print(
f"{utt_id}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
print(f"{utt_id} done!")
print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }")
def parse_args():
# parse args and config
parser = argparse.ArgumentParser(description="Synthesize with VITS")
# model
parser.add_argument(
'--config', type=str, default=None, help='Config of VITS.')
parser.add_argument(
'--ckpt', type=str, default=None, help='Checkpoint file of VITS.')
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
# other
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument("--test_metadata", type=str, help="test metadata.")
parser.add_argument("--output_dir", type=str, help="output dir.")
args = parser.parse_args()
return args
def main():
args = parse_args()
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
paddle.set_device("gpu")
else:
print("ngpu should >= 0 !")
evaluate(args)
if __name__ == "__main__":
main()
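A quick worked example of the speed/RTF bookkeeping in evaluate() above: RTF is the ratio of synthesis time to generated audio duration, so values below 1.0 mean faster than real time. The numbers here are illustrative only.

fs = 22050                           # sampling rate from the config
num_samples, elapsed = 44100, 0.5    # 2 s of audio generated in 0.5 s of wall-clock time
speed = num_samples / elapsed        # samples generated per second -> 88200.0 Hz
rtf = fs / speed                     # 22050 / 88200 = 0.25, i.e. 4x faster than real time
print(f"Hz: {speed}, RTF: {rtf}")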
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import paddle
import soundfile as sf
import yaml
from timer import timer
from yacs.config import CfgNode
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.models.vits import VITS
def evaluate(args):
# Init body.
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
sentences = get_sentences(text_file=args.text, lang=args.lang)
# frontend
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1
vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
vits.eval()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
merge_sentences = False
N = 0
T = 0
for utt_id, sentence in sentences:
with timer() as t:
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
elif args.lang == 'en':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
with paddle.no_grad():
flags = 0
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i]
out = vits.inference(text=part_phone_ids)
wav = out["wav"]
if flags == 0:
wav_all = wav
flags = 1
else:
wav_all = paddle.concat([wav_all, wav])
wav = wav_all.numpy()
N += wav.size
T += t.elapse
speed = wav.size / t.elapse
rtf = config.fs / speed
print(
f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
print(f"{utt_id} done!")
print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }")
def parse_args():
# parse args and config
parser = argparse.ArgumentParser(description="Synthesize with VITS")
# model
parser.add_argument(
'--config', type=str, default=None, help='Config of VITS.')
parser.add_argument(
'--ckpt', type=str, default=None, help='Checkpoint file of VITS.')
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
# other
parser.add_argument(
'--lang',
type=str,
default='zh',
help='Choose model language. zh or en')
parser.add_argument(
"--inference_dir",
type=str,
default=None,
help="dir to save inference models")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument(
"--text",
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument("--output_dir", type=str, help="output dir.")
args = parser.parse_args()
return args
def main():
args = parse_args()
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
paddle.set_device("gpu")
else:
print("ngpu should >= 0 !")
evaluate(args)
if __name__ == "__main__":
main()
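A condensed sketch of the text-to-wave loop above for one Chinese sentence; it assumes `vits` has already been built and restored from a checkpoint as in evaluate(), and the dictionary path is a placeholder.

import paddle
from paddlespeech.t2s.exps.syn_utils import get_frontend

frontend = get_frontend(lang="zh", phones_dict="phone_id_map.txt")  # placeholder path
input_ids = frontend.get_input_ids("欢迎使用语音合成。", merge_sentences=False)
phone_ids = input_ids["phone_ids"]          # one id tensor per sub-sentence
with paddle.no_grad():
    # `vits` is a VITS model loaded as in evaluate() above
    wav = paddle.concat([vits.inference(text=part)["wav"] for part in phone_ids])
print(wav.shape)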
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.models.vits import VITSEvaluator
from paddlespeech.t2s.models.vits import VITSUpdater
from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss
from paddlespeech.t2s.modules.losses import FeatureMatchLoss
from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss
from paddlespeech.t2s.modules.losses import KLDivergenceLoss
from paddlespeech.t2s.modules.losses import MelSpectrogramLoss
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import scheduler_classes
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer
def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
world_size = paddle.distributed.get_world_size()
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu")
else:
paddle.set_device("gpu")
if world_size > 1:
paddle.distributed.init_parallel_env()
# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)
print(
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
    # the DataLoader logger is too verbose, so disable it
logging.getLogger("DataLoader").disabled = True
fields = ["text", "text_lengths", "feats", "feats_lengths", "wave"]
converters = {
"wave": np.load,
"feats": np.load,
}
# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=fields,
converters=converters, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=fields,
converters=converters, )
# collate function and dataloader
train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)
dev_sampler = DistributedBatchSampler(
dev_dataset,
batch_size=config.batch_size,
shuffle=False,
drop_last=False)
print("samplers done!")
train_batch_fn = vits_single_spk_batch_fn
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=train_batch_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
batch_sampler=dev_sampler,
collate_fn=train_batch_fn,
num_workers=config.num_workers)
print("dataloaders done!")
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1
model = VITS(idim=vocab_size, odim=odim, **config["model"])
gen_parameters = model.generator.parameters()
dis_parameters = model.discriminator.parameters()
if world_size > 1:
model = DataParallel(model)
gen_parameters = model._layers.generator.parameters()
dis_parameters = model._layers.discriminator.parameters()
print("model done!")
# loss
criterion_mel = MelSpectrogramLoss(
**config["mel_loss_params"], )
criterion_feat_match = FeatureMatchLoss(
**config["feat_match_loss_params"], )
criterion_gen_adv = GeneratorAdversarialLoss(
**config["generator_adv_loss_params"], )
criterion_dis_adv = DiscriminatorAdversarialLoss(
**config["discriminator_adv_loss_params"], )
criterion_kl = KLDivergenceLoss()
print("criterions done!")
lr_schedule_g = scheduler_classes[config["generator_scheduler"]](
**config["generator_scheduler_params"])
optimizer_g = Adam(
learning_rate=lr_schedule_g,
parameters=gen_parameters,
**config["generator_optimizer_params"])
lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]](
**config["discriminator_scheduler_params"])
optimizer_d = Adam(
learning_rate=lr_schedule_d,
parameters=dis_parameters,
**config["discriminator_optimizer_params"])
print("optimizers done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
if dist.get_rank() == 0:
config_name = args.config.split("/")[-1]
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)
updater = VITSUpdater(
model=model,
optimizers={
"generator": optimizer_g,
"discriminator": optimizer_d,
},
criterions={
"mel": criterion_mel,
"feat_match": criterion_feat_match,
"gen_adv": criterion_gen_adv,
"dis_adv": criterion_dis_adv,
"kl": criterion_kl,
},
schedulers={
"generator": lr_schedule_g,
"discriminator": lr_schedule_d,
},
dataloader=train_dataloader,
lambda_adv=config.lambda_adv,
lambda_mel=config.lambda_mel,
lambda_kl=config.lambda_kl,
lambda_feat_match=config.lambda_feat_match,
lambda_dur=config.lambda_dur,
generator_first=config.generator_first,
output_dir=output_dir)
evaluator = VITSEvaluator(
model=model,
criterions={
"mel": criterion_mel,
"feat_match": criterion_feat_match,
"gen_adv": criterion_gen_adv,
"dis_adv": criterion_dis_adv,
"kl": criterion_kl,
},
dataloader=dev_dataloader,
lambda_adv=config.lambda_adv,
lambda_mel=config.lambda_mel,
lambda_kl=config.lambda_kl,
lambda_feat_match=config.lambda_feat_match,
lambda_dur=config.lambda_dur,
generator_first=config.generator_first,
output_dir=output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
print("Trainer Done!")
trainer.run()
def main():
# parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(description="Train a VITS model.")
parser.add_argument(
"--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
print(
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
)
# dispatch
if args.ngpu > 1:
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
else:
train_sp(args, config)
if __name__ == "__main__":
main()
@@ -122,7 +122,7 @@ def voice_cloning(args):

 def parse_args():
-    # parse args and config and redirect to train_sp
+    # parse args and config
     parser = argparse.ArgumentParser(description="")
     parser.add_argument(
         '--am',
@@ -134,7 +134,7 @@ def parse_args():
         '--am_config',
         type=str,
         default=None,
-        help='Config of acoustic model. Use deault config when it is None.')
+        help='Config of acoustic model.')
     parser.add_argument(
         '--am_ckpt',
         type=str,
@@ -163,7 +163,7 @@ def parse_args():
         '--voc_config',
         type=str,
         default=None,
-        help='Config of voc. Use deault config when it is None.')
+        help='Config of voc.')
     parser.add_argument(
         '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
     parser.add_argument(
......
@@ -195,7 +195,7 @@ class Frontend():
                 new_initials.append(initials[i])
         return new_initials, new_finals

-    def _p2id(self, phonemes: List[str]) -> np.array:
+    def _p2id(self, phonemes: List[str]) -> np.ndarray:
         # replace unk phone with sp
         phonemes = [
             phn if phn in self.vocab_phones else "sp" for phn in phonemes
@@ -203,7 +203,7 @@ class Frontend():
         phone_ids = [self.vocab_phones[item] for item in phonemes]
         return np.array(phone_ids, np.int64)

-    def _t2id(self, tones: List[str]) -> np.array:
+    def _t2id(self, tones: List[str]) -> np.ndarray:
         # replace unk phone with sp
         tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
         tone_ids = [self.vocab_tones[item] for item in tones]
......
@@ -18,5 +18,6 @@ from .parallel_wavegan import *
 from .speedyspeech import *
 from .tacotron2 import *
 from .transformer_tts import *
+from .vits import *
 from .waveflow import *
 from .wavernn import *
@@ -16,6 +16,7 @@ import copy
 from typing import Any
 from typing import Dict
 from typing import List
+from typing import Optional

 import paddle
 import paddle.nn.functional as F
@@ -34,6 +35,7 @@ class HiFiGANGenerator(nn.Layer):
                  in_channels: int=80,
                  out_channels: int=1,
                  channels: int=512,
+                 global_channels: int=-1,
                  kernel_size: int=7,
                  upsample_scales: List[int]=(8, 8, 2, 2),
                  upsample_kernel_sizes: List[int]=(16, 16, 4, 4),
@@ -51,6 +53,7 @@ class HiFiGANGenerator(nn.Layer):
             in_channels (int): Number of input channels.
             out_channels (int): Number of output channels.
             channels (int): Number of hidden representation channels.
+            global_channels (int): Number of global conditioning channels.
             kernel_size (int): Kernel size of initial and final conv layer.
             upsample_scales (list): List of upsampling scales.
             upsample_kernel_sizes (list): List of kernel sizes for upsampling layers.
@@ -119,6 +122,9 @@ class HiFiGANGenerator(nn.Layer):
                 padding=(kernel_size - 1) // 2, ),
             nn.Tanh(), )

+        if global_channels > 0:
+            self.global_conv = nn.Conv1D(global_channels, channels, 1)
+
         nn.initializer.set_global_initializer(None)

         # apply weight norm
@@ -128,15 +134,18 @@ class HiFiGANGenerator(nn.Layer):
         # reset parameters
         self.reset_parameters()

-    def forward(self, c):
+    def forward(self, c, g: Optional[paddle.Tensor]=None):
         """Calculate forward propagation.
         Args:
             c (Tensor): Input tensor (B, in_channels, T).
+            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
         Returns:
             Tensor: Output tensor (B, out_channels, T).
         """
         c = self.input_conv(c)
+        if g is not None:
+            c = c + self.global_conv(g)
         for i in range(self.num_upsamples):
             c = self.upsamples[i](c)
             # initialize
@@ -187,16 +196,19 @@ class HiFiGANGenerator(nn.Layer):
         self.apply(_remove_weight_norm)

-    def inference(self, c):
+    def inference(self, c, g: Optional[paddle.Tensor]=None):
         """Perform inference.
         Args:
             c (Tensor): Input tensor (T, in_channels).
             normalize_before (bool): Whether to perform normalization.
+            g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
         Returns:
             Tensor:
                 Output tensor (T ** prod(upsample_scales), out_channels).
         """
-        c = self.forward(c.transpose([1, 0]).unsqueeze(0))
+        if g is not None:
+            g = g.unsqueeze(0)
+        c = self.forward(c.transpose([1, 0]).unsqueeze(0), g=g)
         return c.squeeze(0).transpose([1, 0])
......
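Illustrative use of the global-conditioning path added in the hunk above: when `global_channels > 0`, an extra tensor `g` of shape (B, global_channels, 1), e.g. a speaker embedding, is projected by `global_conv` and added after the input convolution. The shapes and values below are placeholders, not taken from this commit.

import paddle
from paddlespeech.t2s.models.hifigan import HiFiGANGenerator

gen = HiFiGANGenerator(in_channels=80, global_channels=256)
c = paddle.randn([1, 80, 100])    # mel frames (B, in_channels, T)
g = paddle.randn([1, 256, 1])     # global conditioning, e.g. a speaker embedding
wav = gen(c, g=g)                 # (B, 1, T * prod(upsample_scales))
print(wav.shape)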
@@ -68,8 +68,8 @@ class PWGUpdater(StandardUpdater):
         self.discriminator_train_start_steps = discriminator_train_start_steps
         self.lambda_adv = lambda_adv
         self.lambda_aux = lambda_aux
+        self.state = UpdaterState(iteration=0, epoch=0)

-        self.state = UpdaterState(iteration=0, epoch=0)
         self.train_iterator = iter(self.dataloader)
         log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
......
@@ -16,7 +16,6 @@ from pathlib import Path

 import paddle
 from paddle import distributed as dist
-from paddle.fluid.layers import huber_loss
 from paddle.io import DataLoader
 from paddle.nn import functional as F
 from paddle.nn import Layer
@@ -78,8 +77,11 @@ class SpeedySpeechUpdater(StandardUpdater):
             target_durations.astype(predicted_durations.dtype),
             paddle.to_tensor([1.0]))
         duration_loss = weighted_mean(
-            huber_loss(
-                predicted_durations, paddle.log(target_durations), delta=1.0),
+            F.smooth_l1_loss(
+                predicted_durations,
+                paddle.log(target_durations),
+                delta=1.0,
+                reduction='none', ),
             text_mask, )

         # ssim loss
@@ -146,8 +148,11 @@ class SpeedySpeechEvaluator(StandardEvaluator):
             target_durations.astype(predicted_durations.dtype),
             paddle.to_tensor([1.0]))
         duration_loss = weighted_mean(
-            huber_loss(
-                predicted_durations, paddle.log(target_durations), delta=1.0),
+            F.smooth_l1_loss(
+                predicted_durations,
+                paddle.log(target_durations),
+                delta=1.0,
+                reduction='none', ),
             text_mask, )

         # ssim loss
......
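The hunks above replace the removed paddle.fluid.layers.huber_loss with paddle.nn.functional.smooth_l1_loss, which computes the same Huber-style penalty; reduction='none' keeps per-element losses so weighted_mean can still mask padded positions. A toy check of the call, with made-up tensors:

import paddle
import paddle.nn.functional as F

pred = paddle.to_tensor([[0.2, 1.5, -0.3]])
target = paddle.to_tensor([[0.0, 1.0, 0.0]])
loss = F.smooth_l1_loss(pred, target, delta=1.0, reduction='none')
print(loss)   # per-element losses: quadratic for small errors, linear beyond delta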
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vits import *
from .vits_updater import *
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stochastic duration predictor modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.t2s.models.vits.flow import ConvFlow
from paddlespeech.t2s.models.vits.flow import DilatedDepthSeparableConv
from paddlespeech.t2s.models.vits.flow import ElementwiseAffineFlow
from paddlespeech.t2s.models.vits.flow import FlipFlow
from paddlespeech.t2s.models.vits.flow import LogFlow
class StochasticDurationPredictor(nn.Layer):
"""Stochastic duration predictor module.
This is a module of stochastic duration predictor described in `Conditional
Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2106.06103
"""
def __init__(
self,
channels: int=192,
kernel_size: int=3,
dropout_rate: float=0.5,
flows: int=4,
dds_conv_layers: int=3,
global_channels: int=-1, ):
"""Initialize StochasticDurationPredictor module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
dropout_rate (float): Dropout rate.
flows (int): Number of flows.
dds_conv_layers (int): Number of conv layers in DDS conv.
global_channels (int): Number of global conditioning channels.
"""
super().__init__()
self.pre = nn.Conv1D(channels, channels, 1)
self.dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate, )
self.proj = nn.Conv1D(channels, channels, 1)
self.log_flow = LogFlow()
self.flows = nn.LayerList()
self.flows.append(ElementwiseAffineFlow(2))
for i in range(flows):
self.flows.append(
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers, ))
self.flows.append(FlipFlow())
self.post_pre = nn.Conv1D(1, channels, 1)
self.post_dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate, )
self.post_proj = nn.Conv1D(channels, channels, 1)
self.post_flows = nn.LayerList()
self.post_flows.append(ElementwiseAffineFlow(2))
for i in range(flows):
self.post_flows.append(
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers, ))
self.post_flows.append(FlipFlow())
if global_channels > 0:
self.global_conv = nn.Conv1D(global_channels, channels, 1)
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
w: Optional[paddle.Tensor]=None,
g: Optional[paddle.Tensor]=None,
inverse: bool=False,
noise_scale: float=1.0, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T_text).
x_mask (Tensor): Mask tensor (B, 1, T_text).
w (Optional[Tensor]): Duration tensor (B, 1, T_text).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
inverse (bool): Whether to inverse the flow.
noise_scale (float): Noise scale value.
Returns:
Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
If inverse, log-duration tensor (B, 1, T_text).
"""
# stop gradient
# x = x.detach()
x = self.pre(x)
if g is not None:
# stop gradient
x = x + self.global_conv(g.detach())
x = self.dds(x, x_mask)
x = self.proj(x) * x_mask
if not inverse:
assert w is not None, "w must be provided."
h_w = self.post_pre(w)
h_w = self.post_dds(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (paddle.randn([paddle.shape(w)[0], 2, paddle.shape(w)[2]]) *
x_mask)
z_q = e_q
logdet_tot_q = 0.0
for i, flow in enumerate(self.post_flows):
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = paddle.split(z_q, [1, 1], 1)
u = F.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += paddle.sum(
(F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2])
logq = (paddle.sum(-0.5 *
(math.log(2 * math.pi) +
(e_q**2)) * x_mask, [1, 2]) - logdet_tot_q)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = paddle.concat([z0, z1], 1)
for flow in self.flows:
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
logdet_tot = logdet_tot + logdet
nll = (paddle.sum(0.5 * (math.log(2 * math.pi) +
(z**2)) * x_mask, [1, 2]) - logdet_tot)
# (B,)
return nll + logq
else:
flows = list(reversed(self.flows))
# remove a useless vflow
flows = flows[:-2] + [flows[-1]]
z = (paddle.randn([paddle.shape(x)[0], 2, paddle.shape(x)[2]]) *
noise_scale)
for flow in flows:
z = flow(z, x_mask, g=x, inverse=inverse)
z0, z1 = paddle.split(z, 2, axis=1)
logw = z0
return logw
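A usage sketch for StochasticDurationPredictor following the forward() signature above: passing ground-truth durations w yields a per-utterance NLL during training, while inverse=True samples log-durations at inference. Shapes and values are illustrative only.

import paddle

# assumes the StochasticDurationPredictor class defined above is in scope
sdp = StochasticDurationPredictor(channels=192, kernel_size=3, flows=4)
x = paddle.randn([2, 192, 50])     # text-encoder hiddens (B, channels, T_text)
x_mask = paddle.ones([2, 1, 50])   # all positions valid
w = paddle.ones([2, 1, 50])        # ground-truth durations in frames

nll = sdp(x, x_mask, w=w)                                # training: NLL, shape (B,)
logw = sdp(x, x_mask, inverse=True, noise_scale=0.8)     # inference: (B, 1, T_text)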
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic Flow modules used in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.transform import piecewise_rational_quadratic_transform
class FlipFlow(nn.Layer):
"""Flip flow module."""
def forward(self, x: paddle.Tensor, *args, inverse: bool=False, **kwargs
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Flipped tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
x = paddle.flip(x, [1])
if not inverse:
logdet = paddle.zeros(paddle.shape(x)[0], dtype=x.dtype)
return x, logdet
else:
return x
class LogFlow(nn.Layer):
"""Log flow module."""
def forward(self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
inverse: bool=False,
eps: float=1e-5,
**kwargs
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
eps (float): Epsilon for log.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = paddle.log(paddle.clip(x, min=eps)) * x_mask
logdet = paddle.sum(-y, [1, 2])
return y, logdet
else:
x = paddle.exp(x) * x_mask
return x
class ElementwiseAffineFlow(nn.Layer):
"""Elementwise affine flow module."""
def __init__(self, channels: int):
"""Initialize ElementwiseAffineFlow module.
Args:
channels (int): Number of channels.
"""
super().__init__()
self.channels = channels
m = paddle.zeros([channels, 1])
self.m = paddle.create_parameter(
shape=m.shape,
dtype=str(m.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(m))
logs = paddle.zeros([channels, 1])
self.logs = paddle.create_parameter(
shape=logs.shape,
dtype=str(logs.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(logs))
def forward(self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
inverse: bool=False,
**kwargs
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = self.m + paddle.exp(self.logs) * x
y = y * x_mask
logdet = paddle.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * paddle.exp(-self.logs) * x_mask
return x
class Transpose(nn.Layer):
"""Transpose module for paddle.nn.Sequential()."""
def __init__(self, dim1: int, dim2: int):
"""Initialize Transpose module."""
super().__init__()
self.dim1 = dim1
self.dim2 = dim2
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
"""Transpose."""
len_dim = len(x.shape)
orig_perm = list(range(len_dim))
new_perm = orig_perm[:]
temp = new_perm[self.dim1]
new_perm[self.dim1] = new_perm[self.dim2]
new_perm[self.dim2] = temp
return paddle.transpose(x, new_perm)
class DilatedDepthSeparableConv(nn.Layer):
"""Dilated depth-separable conv module."""
def __init__(
self,
channels: int,
kernel_size: int,
layers: int,
dropout_rate: float=0.0,
eps: float=1e-5, ):
"""Initialize DilatedDepthSeparableConv module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
dropout_rate (float): Dropout rate.
eps (float): Epsilon for layer norm.
"""
super().__init__()
self.convs = nn.LayerList()
for i in range(layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs.append(
nn.Sequential(
nn.Conv1D(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding, ),
Transpose(1, 2),
nn.LayerNorm(channels, epsilon=eps),
Transpose(1, 2),
nn.GELU(),
nn.Conv1D(
channels,
channels,
1, ),
Transpose(1, 2),
nn.LayerNorm(channels, epsilon=eps),
Transpose(1, 2),
nn.GELU(),
nn.Dropout(dropout_rate), ))
def forward(self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, channels, T).
"""
if g is not None:
x = x + g
for f in self.convs:
y = f(x * x_mask)
x = x + y
return x * x_mask
class ConvFlow(nn.Layer):
"""Convolutional flow module."""
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
layers: int,
bins: int=10,
tail_bound: float=5.0, ):
"""Initialize ConvFlow module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
bins (int): Number of bins.
tail_bound (float): Tail bound value.
"""
super().__init__()
self.half_channels = in_channels // 2
self.hidden_channels = hidden_channels
self.bins = bins
self.tail_bound = tail_bound
self.input_conv = nn.Conv1D(
self.half_channels,
hidden_channels,
1, )
self.dds_conv = DilatedDepthSeparableConv(
hidden_channels,
kernel_size,
layers,
dropout_rate=0.0, )
self.proj = nn.Conv1D(
hidden_channels,
self.half_channels * (bins * 3 - 1),
1, )
weight = paddle.zeros(paddle.shape(self.proj.weight))
self.proj.weight = paddle.create_parameter(
shape=weight.shape,
dtype=str(weight.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(weight))
bias = paddle.zeros(paddle.shape(self.proj.bias))
self.proj.bias = paddle.create_parameter(
shape=bias.shape,
dtype=str(bias.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(bias))
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None,
inverse: bool=False,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = x.split(2, 1)
h = self.input_conv(xa)
h = self.dds_conv(h, x_mask, g=g)
# (B, half_channels * (bins * 3 - 1), T)
h = self.proj(h) * x_mask
b, c, t = xa.shape
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
h = h.reshape([b, c, -1, t]).transpose([0, 1, 3, 2])
denom = math.sqrt(self.hidden_channels)
unnorm_widths = h[..., :self.bins] / denom
unnorm_heights = h[..., self.bins:2 * self.bins] / denom
unnorm_derivatives = h[..., 2 * self.bins:]
xb, logdet_abs = piecewise_rational_quadratic_transform(
xb,
unnorm_widths,
unnorm_heights,
unnorm_derivatives,
inverse=inverse,
tails="linear",
tail_bound=self.tail_bound, )
x = paddle.concat([xa, xb], 1) * x_mask
logdet = paddle.sum(logdet_abs * x_mask, [1, 2])
if not inverse:
return x, logdet
else:
return x
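A quick invertibility check for the flows above, using ElementwiseAffineFlow, which is exactly invertible by construction; the sizes are illustrative.

import paddle

# assumes the ElementwiseAffineFlow class defined above is in scope
flow = ElementwiseAffineFlow(channels=2)
x = paddle.randn([4, 2, 60])
x_mask = paddle.ones([4, 1, 60])

y, logdet = flow(x, x_mask)               # forward: output and per-sample log-det
x_rec = flow(y, x_mask, inverse=True)     # inverse: recovers the input
print(bool(paddle.allclose(x, x_rec, atol=1e-6)))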
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generator module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.t2s.models.hifigan import HiFiGANGenerator
from paddlespeech.t2s.models.vits.duration_predictor import StochasticDurationPredictor
from paddlespeech.t2s.models.vits.posterior_encoder import PosteriorEncoder
from paddlespeech.t2s.models.vits.residual_coupling import ResidualAffineCouplingBlock
from paddlespeech.t2s.models.vits.text_encoder import TextEncoder
from paddlespeech.t2s.modules.nets_utils import get_random_segments
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
class VITSGenerator(nn.Layer):
"""Generator module in VITS.
This is a module of VITS described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`_.
    As the text encoder, we use the conformer architecture instead of the relative positional
    Transformer, which contains additional convolution layers.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
        Text-to-Speech`: https://arxiv.org/abs/2106.06103
"""
def __init__(
self,
vocabs: int,
aux_channels: int=513,
hidden_channels: int=192,
spks: Optional[int]=None,
langs: Optional[int]=None,
spk_embed_dim: Optional[int]=None,
global_channels: int=-1,
segment_size: int=32,
text_encoder_attention_heads: int=2,
text_encoder_ffn_expand: int=4,
text_encoder_blocks: int=6,
text_encoder_positionwise_layer_type: str="conv1d",
text_encoder_positionwise_conv_kernel_size: int=1,
text_encoder_positional_encoding_layer_type: str="rel_pos",
text_encoder_self_attention_layer_type: str="rel_selfattn",
text_encoder_activation_type: str="swish",
text_encoder_normalize_before: bool=True,
text_encoder_dropout_rate: float=0.1,
text_encoder_positional_dropout_rate: float=0.0,
text_encoder_attention_dropout_rate: float=0.0,
text_encoder_conformer_kernel_size: int=7,
use_macaron_style_in_text_encoder: bool=True,
use_conformer_conv_in_text_encoder: bool=True,
decoder_kernel_size: int=7,
decoder_channels: int=512,
decoder_upsample_scales: List[int]=[8, 8, 2, 2],
decoder_upsample_kernel_sizes: List[int]=[16, 16, 4, 4],
decoder_resblock_kernel_sizes: List[int]=[3, 7, 11],
decoder_resblock_dilations: List[List[int]]=[[1, 3, 5], [1, 3, 5],
[1, 3, 5]],
use_weight_norm_in_decoder: bool=True,
posterior_encoder_kernel_size: int=5,
posterior_encoder_layers: int=16,
posterior_encoder_stacks: int=1,
posterior_encoder_base_dilation: int=1,
posterior_encoder_dropout_rate: float=0.0,
use_weight_norm_in_posterior_encoder: bool=True,
flow_flows: int=4,
flow_kernel_size: int=5,
flow_base_dilation: int=1,
flow_layers: int=4,
flow_dropout_rate: float=0.0,
use_weight_norm_in_flow: bool=True,
use_only_mean_in_flow: bool=True,
stochastic_duration_predictor_kernel_size: int=3,
stochastic_duration_predictor_dropout_rate: float=0.5,
stochastic_duration_predictor_flows: int=4,
stochastic_duration_predictor_dds_conv_layers: int=3, ):
"""Initialize VITS generator module.
Args:
vocabs (int): Input vocabulary size.
aux_channels (int): Number of acoustic feature channels.
hidden_channels (int): Number of hidden channels.
spks (Optional[int]): Number of speakers. If set to > 1, assume that the
sids will be provided as the input and use sid embedding layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume that the
                lids will be provided as the input and use lid embedding layer.
spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
assume that spembs will be provided as the input.
global_channels (int): Number of global conditioning channels.
segment_size (int): Segment size for decoder.
text_encoder_attention_heads (int): Number of heads in conformer block
of text encoder.
text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
of text encoder.
text_encoder_blocks (int): Number of conformer blocks in text encoder.
text_encoder_positionwise_layer_type (str): Position-wise layer type in
conformer block of text encoder.
text_encoder_positionwise_conv_kernel_size (int): Position-wise convolution
kernel size in conformer block of text encoder. Only used when the
above layer type is conv1d or conv1d-linear.
text_encoder_positional_encoding_layer_type (str): Positional encoding layer
type in conformer block of text encoder.
text_encoder_self_attention_layer_type (str): Self-attention layer type in
conformer block of text encoder.
text_encoder_activation_type (str): Activation function type in conformer
block of text encoder.
text_encoder_normalize_before (bool): Whether to apply layer norm before
self-attention in conformer block of text encoder.
text_encoder_dropout_rate (float): Dropout rate in conformer block of
text encoder.
text_encoder_positional_dropout_rate (float): Dropout rate for positional
encoding in conformer block of text encoder.
text_encoder_attention_dropout_rate (float): Dropout rate for attention in
conformer block of text encoder.
            text_encoder_conformer_kernel_size (int): Conformer conv kernel size. It
                will be used only when use_conformer_conv_in_text_encoder = True.
use_macaron_style_in_text_encoder (bool): Whether to use macaron style FFN
in conformer block of text encoder.
            use_conformer_conv_in_text_encoder (bool): Whether to use convolution in
conformer block of text encoder.
decoder_kernel_size (int): Decoder kernel size.
decoder_channels (int): Number of decoder initial channels.
decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
decoder_upsample_kernel_sizes (List[int]): List of kernel size for
upsampling layers in decoder.
decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
in decoder.
decoder_resblock_dilations (List[List[int]]): List of list of dilations for
resblocks in decoder.
use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
decoder.
posterior_encoder_kernel_size (int): Posterior encoder kernel size.
posterior_encoder_layers (int): Number of layers of posterior encoder.
posterior_encoder_stacks (int): Number of stacks of posterior encoder.
posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
normalization in posterior encoder.
flow_flows (int): Number of flows in flow.
flow_kernel_size (int): Kernel size in flow.
flow_base_dilation (int): Base dilation in flow.
flow_layers (int): Number of layers in flow.
flow_dropout_rate (float): Dropout rate in flow
use_weight_norm_in_flow (bool): Whether to apply weight normalization in
flow.
use_only_mean_in_flow (bool): Whether to use only mean in flow.
stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
duration predictor.
stochastic_duration_predictor_dropout_rate (float): Dropout rate in
stochastic duration predictor.
stochastic_duration_predictor_flows (int): Number of flows in stochastic
duration predictor.
stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
layers in stochastic duration predictor.
"""
super().__init__()
self.segment_size = segment_size
self.text_encoder = TextEncoder(
vocabs=vocabs,
attention_dim=hidden_channels,
attention_heads=text_encoder_attention_heads,
linear_units=hidden_channels * text_encoder_ffn_expand,
blocks=text_encoder_blocks,
positionwise_layer_type=text_encoder_positionwise_layer_type,
positionwise_conv_kernel_size=text_encoder_positionwise_conv_kernel_size,
positional_encoding_layer_type=text_encoder_positional_encoding_layer_type,
self_attention_layer_type=text_encoder_self_attention_layer_type,
activation_type=text_encoder_activation_type,
normalize_before=text_encoder_normalize_before,
dropout_rate=text_encoder_dropout_rate,
positional_dropout_rate=text_encoder_positional_dropout_rate,
attention_dropout_rate=text_encoder_attention_dropout_rate,
conformer_kernel_size=text_encoder_conformer_kernel_size,
use_macaron_style=use_macaron_style_in_text_encoder,
use_conformer_conv=use_conformer_conv_in_text_encoder, )
self.decoder = HiFiGANGenerator(
in_channels=hidden_channels,
out_channels=1,
channels=decoder_channels,
global_channels=global_channels,
kernel_size=decoder_kernel_size,
upsample_scales=decoder_upsample_scales,
upsample_kernel_sizes=decoder_upsample_kernel_sizes,
resblock_kernel_sizes=decoder_resblock_kernel_sizes,
resblock_dilations=decoder_resblock_dilations,
use_weight_norm=use_weight_norm_in_decoder, )
self.posterior_encoder = PosteriorEncoder(
in_channels=aux_channels,
out_channels=hidden_channels,
hidden_channels=hidden_channels,
kernel_size=posterior_encoder_kernel_size,
layers=posterior_encoder_layers,
stacks=posterior_encoder_stacks,
base_dilation=posterior_encoder_base_dilation,
global_channels=global_channels,
dropout_rate=posterior_encoder_dropout_rate,
use_weight_norm=use_weight_norm_in_posterior_encoder, )
self.flow = ResidualAffineCouplingBlock(
in_channels=hidden_channels,
hidden_channels=hidden_channels,
flows=flow_flows,
kernel_size=flow_kernel_size,
base_dilation=flow_base_dilation,
layers=flow_layers,
global_channels=global_channels,
dropout_rate=flow_dropout_rate,
use_weight_norm=use_weight_norm_in_flow,
use_only_mean=use_only_mean_in_flow, )
# TODO: Add deterministic version as an option
self.duration_predictor = StochasticDurationPredictor(
channels=hidden_channels,
kernel_size=stochastic_duration_predictor_kernel_size,
dropout_rate=stochastic_duration_predictor_dropout_rate,
flows=stochastic_duration_predictor_flows,
dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
global_channels=global_channels, )
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
assert global_channels > 0
self.spks = spks
self.global_emb = nn.Embedding(spks, global_channels)
self.spk_embed_dim = None
if spk_embed_dim is not None and spk_embed_dim > 0:
assert global_channels > 0
self.spk_embed_dim = spk_embed_dim
self.spemb_proj = nn.Linear(spk_embed_dim, global_channels)
self.langs = None
if langs is not None and langs > 1:
assert global_channels > 0
self.langs = langs
self.lang_emb = nn.Embedding(langs, global_channels)
# delayed import
from paddlespeech.t2s.models.vits.monotonic_align import maximum_path
self.maximum_path = maximum_path
def forward(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
paddle.Tensor, paddle.Tensor,
Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
paddle.Tensor, paddle.Tensor, ], ]:
"""Calculate forward propagation.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
Tensor: Duration negative log-likelihood (NLL) tensor (B,).
Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
Tensor: Segments start index tensor (B,).
Tensor: Text mask tensor (B, 1, T_text).
Tensor: Feature mask tensor (B, 1, T_feats).
tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
- Tensor: Posterior encoder hidden representation (B, H, T_feats).
- Tensor: Flow hidden representation (B, H, T_feats).
- Tensor: Expanded text encoder projected mean (B, H, T_feats).
- Tensor: Expanded text encoder projected scale (B, H, T_feats).
- Tensor: Posterior encoder projected mean (B, H, T_feats).
- Tensor: Posterior encoder projected scale (B, H, T_feats).
"""
# forward text encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
# calculate global conditioning
g = None
if self.spks is not None:
# speaker one-hot vector embedding: (B, global_channels, 1)
g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1)
if self.spk_embed_dim is not None:
            # pretrained speaker embedding, e.g., X-vector (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# language one-hot vector embedding: (B, global_channels, 1)
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(
feats, feats_lengths, g=g)
# forward flow
# (B, H, T_feats)
z_p = self.flow(z, y_mask, g=g)
# monotonic alignment search
with paddle.no_grad():
# negative cross-entropy
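            # The four terms below add up to sum_h log N(z_p; m_p, exp(logs_p)),
            # i.e. -0.5*log(2*pi) - logs_p - 0.5*(z_p - m_p)^2 * exp(-2*logs_p)
            # summed over the hidden dimension. The square is expanded so that the
            # text-only terms (1, 4) and the cross terms (2, 3) can be combined by
            # broadcasting / matmul into one (B, T_feats, T_text) score matrix for
            # monotonic alignment search.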
# (B, H, T_text)
s_p_sq_r = paddle.exp(-2 * logs_p)
# (B, 1, T_text)
neg_x_ent_1 = paddle.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = paddle.matmul(
-0.5 * (z_p**2).transpose([0, 2, 1]),
s_p_sq_r, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = paddle.matmul(
z_p.transpose([0, 2, 1]),
(m_p * s_p_sq_r), )
# (B, 1, T_text)
neg_x_ent_4 = paddle.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True, )
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
# (B, 1, T_feats, T_text)
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
-1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = (self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1), ).unsqueeze(1).detach())
# forward duration predictor
# (B, 1, T_text)
w = attn.sum(2)
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
dur_nll = dur_nll / paddle.sum(x_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = paddle.matmul(attn.squeeze(1),
m_p.transpose([0, 2, 1])).transpose([0, 2, 1])
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = paddle.matmul(attn.squeeze(1),
logs_p.transpose([0, 2, 1])).transpose([0, 2, 1])
# get random segments
z_segments, z_start_idxs = get_random_segments(
z,
feats_lengths,
self.segment_size, )
# forward decoder with random segments
wav = self.decoder(z_segments, g=g)
return (wav, dur_nll, attn, z_start_idxs, x_mask, y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q), )
def inference(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: Optional[paddle.Tensor]=None,
feats_lengths: Optional[paddle.Tensor]=None,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
dur: Optional[paddle.Tensor]=None,
noise_scale: float=0.667,
noise_scale_dur: float=0.8,
alpha: float=1.0,
max_len: Optional[int]=None,
use_teacher_forcing: bool=False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Run inference.
Args:
text (Tensor): Input text index tensor (B, T_text,).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
skip the prediction of durations (i.e., teacher forcing).
noise_scale (float): Noise scale parameter for flow.
noise_scale_dur (float): Noise scale parameter for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length of acoustic feature sequence.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
Tensor: Generated waveform tensor (B, T_wav).
Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
Tensor: Duration tensor (B, T_text).
"""
# encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
g = None
if self.spks is not None:
# (B, global_channels, 1)
g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1)
if self.spk_embed_dim is not None:
# (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# (B, global_channels, 1)
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if use_teacher_forcing:
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(
feats, feats_lengths, g=g)
# forward flow
# (B, H, T_feats)
z_p = self.flow(z, y_mask, g=g)
# monotonic alignment search
# (B, H, T_text)
s_p_sq_r = paddle.exp(-2 * logs_p)
# (B, 1, T_text)
neg_x_ent_1 = paddle.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = paddle.matmul(
-0.5 * (z_p**2).transpose([0, 2, 1]),
s_p_sq_r, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = paddle.matmul(
z_p.transpose([0, 2, 1]),
(m_p * s_p_sq_r), )
# (B, 1, T_text)
neg_x_ent_4 = paddle.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True, )
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
# (B, 1, T_feats, T_text)
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
-1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1), ).unsqueeze(1)
# (B, 1, T_text)
dur = attn.sum(2)
            # forward decoder with the full-length latent sequence
wav = self.decoder(z * y_mask, g=g)
else:
# duration
if dur is None:
logw = self.duration_predictor(
x,
x_mask,
g=g,
inverse=True,
noise_scale=noise_scale_dur, )
w = paddle.exp(logw) * x_mask * alpha
dur = paddle.ceil(w)
y_lengths = paddle.cast(
paddle.clip(paddle.sum(dur, [1, 2]), min=1), dtype='int64')
y_mask = make_non_pad_mask(y_lengths).unsqueeze(1)
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
-1)
attn = self._generate_path(dur, attn_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = paddle.matmul(
attn.squeeze(1),
m_p.transpose([0, 2, 1]), ).transpose([0, 2, 1])
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = paddle.matmul(
attn.squeeze(1),
logs_p.transpose([0, 2, 1]), ).transpose([0, 2, 1])
# decoder
z_p = m_p + paddle.randn(
paddle.shape(m_p)) * paddle.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, inverse=True)
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
def _generate_path(self, dur: paddle.Tensor,
mask: paddle.Tensor) -> paddle.Tensor:
"""Generate path a.k.a. monotonic attention.
Args:
dur (Tensor): Duration tensor (B, 1, T_text).
mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
Returns:
Tensor: Path tensor (B, 1, T_feats, T_text).
"""
b, _, t_y, t_x = paddle.shape(mask)
cum_dur = paddle.cumsum(dur, -1)
cum_dur_flat = paddle.reshape(cum_dur, [b * t_x])
path = paddle.arange(t_y, dtype=dur.dtype)
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
path = paddle.reshape(path, [b, t_x, t_y])
'''
path will be like (t_x = 3, t_y = 5):
[[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
[1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
[1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
'''
path = paddle.cast(path, dtype='float32')
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
return path.unsqueeze(1).transpose([0, 1, 3, 2]) * mask
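
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module): build a
# small VITSGenerator and run a training-style forward pass and an inference
# pass on random data. The vocabulary size, batch size and sequence lengths
# below are assumptions chosen just for this example.
if __name__ == "__main__":
    vocab_size, batch, t_text, t_feats = 40, 2, 10, 64
    generator = VITSGenerator(vocabs=vocab_size)
    text = paddle.randint(0, vocab_size, [batch, t_text], dtype="int64")
    text_lengths = paddle.to_tensor([t_text, t_text - 2], dtype="int64")
    feats = paddle.randn([batch, 513, t_feats])  # (B, aux_channels, T_feats)
    feats_lengths = paddle.to_tensor([t_feats, t_feats - 8], dtype="int64")
    # training-style forward: random waveform segments plus alignment statistics
    wav, dur_nll, attn, start_idxs, x_mask, y_mask, stats = generator(
        text, text_lengths, feats, feats_lengths)
    print(wav.shape)  # (B, 1, segment_size * upsample_factor)
    # inference: text only, durations sampled from the stochastic predictor
    wav, attn, dur = generator.inference(text[:1], text_lengths[:1])
    print(wav.shape, dur.shape)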
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Maximum path calculation module.
This code is based on https://github.com/jaywalnut310/vits.
"""
import warnings
import numpy as np
import paddle
from numba import njit
from numba import prange
try:
from .core import maximum_path_c
is_cython_avalable = True
except ImportError:
is_cython_avalable = False
warnings.warn(
"Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
"If you want to use the cython version, please build it as follows: "
"`cd paddlespeech/t2s/models/vits/monotonic_align; python setup.py build_ext --inplace`"
)
def maximum_path(neg_x_ent: paddle.Tensor,
attn_mask: paddle.Tensor) -> paddle.Tensor:
"""Calculate maximum path.
Args:
neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
attn_mask (Tensor): Attention mask (B, T_feats, T_text).
Returns:
Tensor: Maximum path tensor (B, T_feats, T_text).
"""
dtype = neg_x_ent.dtype
neg_x_ent = neg_x_ent.numpy().astype(np.float32)
path = np.zeros(neg_x_ent.shape, dtype=np.int32)
t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
if is_cython_avalable:
maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
else:
maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
return paddle.cast(paddle.to_tensor(path), dtype=dtype)
@njit
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
"""Calculate a single maximum path with numba."""
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or
value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@njit(parallel=True)
def maximum_path_numba(paths, values, t_ys, t_xs):
"""Calculate batch maximum path with numba."""
for i in prange(paths.shape[0]):
maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
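
# ---------------------------------------------------------------------------
# Minimal sketch (illustrative only): run maximum_path on a tiny score matrix.
# The (1, 4, 3) shapes below are assumptions chosen to keep the example small.
if __name__ == "__main__":
    scores = paddle.to_tensor(
        [[[0.0, -1.0, -2.0],
          [-1.0, 0.0, -2.0],
          [-2.0, 0.0, -1.0],
          [-2.0, -1.0, 0.0]]])  # (B=1, T_feats=4, T_text=3)
    mask = paddle.ones_like(scores)
    path = maximum_path(scores, mask)
    # Each frame selects exactly one token and the selection is monotonic.
    print(path.numpy())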
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Maximum path calculation module with cython optimization.
This code is copied from https://github.com/jaywalnut310/vits with the code format modified.
"""
cimport cython
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
cdef int x
cdef int y
cdef float v_prev
cdef float v_cur
cdef float tmp
cdef int index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
cdef int b = paths.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
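
# Note: `prange(b, nogil=True)` releases the GIL so the per-sample dynamic
# programs run in parallel across the batch, which is why both helpers are
# declared `nogil` and operate only on typed memoryviews.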
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup cython code."""
from Cython.Build import cythonize
from setuptools import Extension
from setuptools import setup
from setuptools.command.build_ext import build_ext as _build_ext
class build_ext(_build_ext):
"""Overwrite build_ext."""
def finalize_options(self):
"""Prevent numpy from thinking it is still in its setup process."""
_build_ext.finalize_options(self)
__builtins__.__NUMPY_SETUP__ = False
import numpy
self.include_dirs.append(numpy.get_include())
exts = [Extension(
name="core",
sources=["core.pyx"], )]
setup(
name="monotonic_align",
ext_modules=cythonize(exts, language_level=3),
cmdclass={"build_ext": build_ext}, )
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional
from typing import Tuple
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
class PosteriorEncoder(nn.Layer):
"""Posterior encoder module in VITS.
This is a module of posterior encoder described in `Conditional Variational
Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int=513,
out_channels: int=192,
hidden_channels: int=192,
kernel_size: int=5,
layers: int=16,
stacks: int=1,
base_dilation: int=1,
global_channels: int=-1,
dropout_rate: float=0.0,
bias: bool=True,
use_weight_norm: bool=True, ):
"""Initilialize PosteriorEncoder module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size in WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of repeat stacking of WaveNet.
base_dilation (int): Base dilation factor.
global_channels (int): Number of global conditioning channels.
dropout_rate (float): Dropout rate.
bias (bool): Whether to use bias parameters in conv.
use_weight_norm (bool): Whether to apply weight norm.
"""
super().__init__()
# define modules
self.input_conv = nn.Conv1D(in_channels, hidden_channels, 1)
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True, )
self.proj = nn.Conv1D(hidden_channels, out_channels * 2, 1)
def forward(
self,
x: paddle.Tensor,
x_lengths: paddle.Tensor,
g: Optional[paddle.Tensor]=None
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_feats).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
Tensor: Projected mean tensor (B, out_channels, T_feats).
Tensor: Projected scale tensor (B, out_channels, T_feats).
Tensor: Mask tensor for input tensor (B, 1, T_feats).
"""
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
x = self.input_conv(x) * x_mask
x = self.encoder(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = paddle.split(stats, 2, axis=1)
z = (m + paddle.randn(paddle.shape(m)) * paddle.exp(logs)) * x_mask
return z, m, logs, x_mask
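
# ---------------------------------------------------------------------------
# Minimal sketch (illustrative only): encode a random "linear spectrogram"
# batch. The shapes below are assumptions; the 513 input channels match the
# default aux_channels used elsewhere in this model.
if __name__ == "__main__":
    encoder = PosteriorEncoder()
    feats = paddle.randn([2, 513, 50])  # (B, in_channels, T_feats)
    feats_lengths = paddle.to_tensor([50, 42], dtype="int64")
    z, m, logs, mask = encoder(feats, feats_lengths)
    print(z.shape)  # (B, out_channels=192, T_feats)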
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Residual affine coupling modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.flow import FlipFlow
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
class ResidualAffineCouplingBlock(nn.Layer):
"""Residual affine coupling block module.
    This is a module of the residual affine coupling block, which is used as the "Flow" in
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int=192,
hidden_channels: int=192,
flows: int=4,
kernel_size: int=5,
base_dilation: int=1,
layers: int=4,
global_channels: int=-1,
dropout_rate: float=0.0,
use_weight_norm: bool=True,
bias: bool=True,
use_only_mean: bool=True, ):
"""Initilize ResidualAffineCouplingBlock module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
flows (int): Number of flows.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
super().__init__()
self.flows = nn.LayerList()
for i in range(flows):
self.flows.append(
ResidualAffineCouplingLayer(
in_channels=in_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
base_dilation=base_dilation,
layers=layers,
stacks=1,
global_channels=global_channels,
dropout_rate=dropout_rate,
use_weight_norm=use_weight_norm,
bias=bias,
use_only_mean=use_only_mean, ))
self.flows.append(FlipFlow())
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None,
inverse: bool=False, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
"""
if not inverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, inverse=inverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, inverse=inverse)
return x
class ResidualAffineCouplingLayer(nn.Layer):
"""Residual affine coupling layer."""
def __init__(
self,
in_channels: int=192,
hidden_channels: int=192,
kernel_size: int=5,
base_dilation: int=1,
layers: int=5,
stacks: int=1,
global_channels: int=-1,
dropout_rate: float=0.0,
use_weight_norm: bool=True,
bias: bool=True,
use_only_mean: bool=True, ):
"""Initialzie ResidualAffineCouplingLayer module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of stacks of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
assert in_channels % 2 == 0, "in_channels should be divisible by 2"
super().__init__()
self.half_channels = in_channels // 2
self.use_only_mean = use_only_mean
# define modules
self.input_conv = nn.Conv1D(
self.half_channels,
hidden_channels,
1, )
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True, )
if use_only_mean:
self.proj = nn.Conv1D(
hidden_channels,
self.half_channels,
1, )
else:
self.proj = nn.Conv1D(
hidden_channels,
self.half_channels * 2,
1, )
weight = paddle.zeros(paddle.shape(self.proj.weight))
self.proj.weight = paddle.create_parameter(
shape=weight.shape,
dtype=str(weight.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(weight))
bias = paddle.zeros(paddle.shape(self.proj.bias))
self.proj.bias = paddle.create_parameter(
shape=bias.shape,
dtype=str(bias.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(bias))
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None,
inverse: bool=False,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = paddle.split(x, 2, axis=1)
h = self.input_conv(xa) * x_mask
h = self.encoder(h, x_mask, g=g)
stats = self.proj(h) * x_mask
if not self.use_only_mean:
m, logs = paddle.split(stats, 2, axis=1)
else:
m = stats
logs = paddle.zeros(paddle.shape(m))
if not inverse:
xb = m + xb * paddle.exp(logs) * x_mask
x = paddle.concat([xa, xb], 1)
logdet = paddle.sum(logs, [1, 2])
return x, logdet
else:
xb = (xb - m) * paddle.exp(-logs) * x_mask
x = paddle.concat([xa, xb], 1)
return x
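
# ---------------------------------------------------------------------------
# Minimal sketch (illustrative only): the coupling block is an invertible flow,
# so running it forward and then with inverse=True should recover the input up
# to numerical error. The shapes below are assumptions.
if __name__ == "__main__":
    flow = ResidualAffineCouplingBlock(in_channels=192, hidden_channels=192)
    flow.eval()
    x = paddle.randn([1, 192, 30])
    x_mask = paddle.ones([1, 1, 30])
    z = flow(x, x_mask)                      # forward direction
    x_rec = flow(z, x_mask, inverse=True)    # inverse direction
    print(float(paddle.abs(x - x_rec).max()))  # should be close to 0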
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Tuple
import paddle
from paddle import nn
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder as Encoder
class TextEncoder(nn.Layer):
"""Text encoder module in VITS.
This is a module of text encoder described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`_.
Instead of the relative positional Transformer, we use conformer architecture as
the encoder module, which contains additional convolution layers.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
vocabs: int,
attention_dim: int=192,
attention_heads: int=2,
linear_units: int=768,
blocks: int=6,
positionwise_layer_type: str="conv1d",
positionwise_conv_kernel_size: int=3,
positional_encoding_layer_type: str="rel_pos",
self_attention_layer_type: str="rel_selfattn",
activation_type: str="swish",
normalize_before: bool=True,
use_macaron_style: bool=False,
use_conformer_conv: bool=False,
conformer_kernel_size: int=7,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.0,
attention_dropout_rate: float=0.0, ):
"""Initialize TextEncoder module.
Args:
vocabs (int): Vocabulary size.
attention_dim (int): Attention dimension.
attention_heads (int): Number of attention heads.
linear_units (int): Number of linear units of positionwise layers.
blocks (int): Number of encoder blocks.
positionwise_layer_type (str): Positionwise layer type.
positionwise_conv_kernel_size (int): Positionwise layer's kernel size.
positional_encoding_layer_type (str): Positional encoding layer type.
self_attention_layer_type (str): Self-attention layer type.
activation_type (str): Activation function type.
normalize_before (bool): Whether to apply LayerNorm before attention.
use_macaron_style (bool): Whether to use macaron style components.
use_conformer_conv (bool): Whether to use conformer conv layers.
conformer_kernel_size (int): Conformer's conv kernel size.
dropout_rate (float): Dropout rate.
positional_dropout_rate (float): Dropout rate for positional encoding.
attention_dropout_rate (float): Dropout rate for attention.
"""
super().__init__()
# store for forward
self.attention_dim = attention_dim
# define modules
self.emb = nn.Embedding(vocabs, attention_dim)
dist = paddle.distribution.Normal(loc=0.0, scale=attention_dim**-0.5)
w = dist.sample(self.emb.weight.shape)
self.emb.weight.set_value(w)
self.encoder = Encoder(
idim=-1,
input_layer=None,
attention_dim=attention_dim,
attention_heads=attention_heads,
linear_units=linear_units,
num_blocks=blocks,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
normalize_before=normalize_before,
positionwise_layer_type=positionwise_layer_type,
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
macaron_style=use_macaron_style,
pos_enc_layer_type=positional_encoding_layer_type,
selfattention_layer_type=self_attention_layer_type,
activation_type=activation_type,
use_cnn_module=use_conformer_conv,
cnn_module_kernel=conformer_kernel_size, )
self.proj = nn.Conv1D(attention_dim, attention_dim * 2, 1)
def forward(
self,
x: paddle.Tensor,
x_lengths: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input index tensor (B, T_text).
x_lengths (Tensor): Length tensor (B,).
Returns:
Tensor: Encoded hidden representation (B, attention_dim, T_text).
Tensor: Projected mean tensor (B, attention_dim, T_text).
Tensor: Projected scale tensor (B, attention_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
x = self.emb(x) * math.sqrt(self.attention_dim)
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
# encoder assume the channel last (B, T_text, attention_dim)
        # but mask shape should be (B, 1, T_text)
x, _ = self.encoder(x, x_mask)
# convert the channel first (B, attention_dim, T_text)
x = paddle.transpose(x, [0, 2, 1])
stats = self.proj(x) * x_mask
m, logs = paddle.split(stats, 2, axis=1)
return x, m, logs, x_mask
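
# ---------------------------------------------------------------------------
# Minimal sketch (illustrative only): encode a random token sequence. The
# vocabulary size and lengths below are assumptions chosen for the example.
if __name__ == "__main__":
    encoder = TextEncoder(vocabs=40)
    tokens = paddle.randint(0, 40, [2, 12], dtype="int64")
    lengths = paddle.to_tensor([12, 9], dtype="int64")
    x, m, logs, mask = encoder(tokens, lengths)
    print(x.shape, m.shape)  # (B, attention_dim=192, T_text) each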
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flow-related transformation.
This code is based on https://github.com/bayesiains/nflows.
"""
import numpy as np
import paddle
from paddle.nn import functional as F
from paddlespeech.t2s.modules.nets_utils import paddle_gather
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs)
return outputs, logabsdet
def mask_preprocess(x, mask):
B, C, T, bins = paddle.shape(x)
new_x = paddle.zeros([mask.sum(), bins])
for i in range(bins):
new_x[:, i] = x[:, :, :, i][mask]
return new_x
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = paddle.zeros(paddle.shape(inputs))
logabsdet = paddle.zeros(paddle.shape(inputs))
if tails == "linear":
unnormalized_derivatives = F.pad(
unnormalized_derivatives,
pad=[0] * (len(unnormalized_derivatives.shape) - 1) * 2 + [1, 1])
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
unnormalized_widths = mask_preprocess(unnormalized_widths,
inside_interval_mask)
unnormalized_heights = mask_preprocess(unnormalized_heights,
inside_interval_mask)
unnormalized_derivatives = mask_preprocess(unnormalized_derivatives,
inside_interval_mask)
(outputs[inside_interval_mask],
logabsdet[inside_interval_mask], ) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative, )
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
if paddle.min(inputs) < left or paddle.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, axis=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = paddle.cumsum(widths, axis=-1)
cumwidths = F.pad(
cumwidths,
pad=[0] * (len(cumwidths.shape) - 1) * 2 + [1, 0],
mode="constant",
value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, axis=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = paddle.cumsum(heights, axis=-1)
cumheights = F.pad(
cumheights,
pad=[0] * (len(cumheights.shape) - 1) * 2 + [1, 0],
mode="constant",
value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = _searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = _searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = paddle_gather(cumwidths, -1, bin_idx)[..., 0]
input_bin_widths = paddle_gather(widths, -1, bin_idx)[..., 0]
input_cumheights = paddle_gather(cumheights, -1, bin_idx)[..., 0]
delta = heights / widths
input_delta = paddle_gather(delta, -1, bin_idx)[..., 0]
input_derivatives = paddle_gather(derivatives, -1, bin_idx)[..., 0]
input_derivatives_plus_one = paddle_gather(derivatives[..., 1:], -1,
bin_idx)[..., 0]
input_heights = paddle_gather(heights, -1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - paddle.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
) * theta_one_minus_theta)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2) + 2 * input_delta *
theta_one_minus_theta + input_derivatives * (1 - root).pow(2))
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2) +
input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
) * theta_one_minus_theta)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2) + 2 * input_delta *
theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
denominator)
return outputs, logabsdet
def _searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1
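
# ---------------------------------------------------------------------------
# Minimal sketch (illustrative only): the piecewise rational-quadratic spline is
# invertible, so transforming and then inverting with the same parameters should
# recover the inputs, and the two log-determinants should cancel. The shapes and
# the number of bins below are assumptions.
if __name__ == "__main__":
    num_bins = 10
    inputs = paddle.uniform([1, 2, 5], min=-0.9, max=0.9)
    widths = paddle.randn([1, 2, 5, num_bins])
    heights = paddle.randn([1, 2, 5, num_bins])
    derivs = paddle.randn([1, 2, 5, num_bins - 1])
    y, logdet = piecewise_rational_quadratic_transform(
        inputs, widths, heights, derivs, tails="linear", tail_bound=1.0)
    x_rec, logdet_inv = piecewise_rational_quadratic_transform(
        y, widths, heights, derivs, inverse=True, tails="linear", tail_bound=1.0)
    print(float(paddle.abs(inputs - x_rec).max()))       # should be close to 0
    print(float(paddle.abs(logdet + logdet_inv).max()))  # should be close to 0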
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""VITS module"""
from typing import Any
from typing import Dict
from typing import Optional
import paddle
from paddle import nn
from typeguard import check_argument_types
from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
from paddlespeech.t2s.models.vits.generator import VITSGenerator
from paddlespeech.t2s.modules.nets_utils import initialize
AVAILABLE_GENERATERS = {
"vits_generator": VITSGenerator,
}
AVAILABLE_DISCRIMINATORS = {
"hifigan_period_discriminator":
HiFiGANPeriodDiscriminator,
"hifigan_scale_discriminator":
HiFiGANScaleDiscriminator,
"hifigan_multi_period_discriminator":
HiFiGANMultiPeriodDiscriminator,
"hifigan_multi_scale_discriminator":
HiFiGANMultiScaleDiscriminator,
"hifigan_multi_scale_multi_period_discriminator":
HiFiGANMultiScaleMultiPeriodDiscriminator,
}
class VITS(nn.Layer):
"""VITS module (generator + discriminator).
This is a module of VITS described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
# generator related
idim: int,
odim: int,
sampling_rate: int=22050,
generator_type: str="vits_generator",
generator_params: Dict[str, Any]={
"hidden_channels": 192,
"spks": None,
"langs": None,
"spk_embed_dim": None,
"global_channels": -1,
"segment_size": 32,
"text_encoder_attention_heads": 2,
"text_encoder_ffn_expand": 4,
"text_encoder_blocks": 6,
"text_encoder_positionwise_layer_type": "conv1d",
"text_encoder_positionwise_conv_kernel_size": 1,
"text_encoder_positional_encoding_layer_type": "rel_pos",
"text_encoder_self_attention_layer_type": "rel_selfattn",
"text_encoder_activation_type": "swish",
"text_encoder_normalize_before": True,
"text_encoder_dropout_rate": 0.1,
"text_encoder_positional_dropout_rate": 0.0,
"text_encoder_attention_dropout_rate": 0.0,
"text_encoder_conformer_kernel_size": 7,
"use_macaron_style_in_text_encoder": True,
"use_conformer_conv_in_text_encoder": True,
"decoder_kernel_size": 7,
"decoder_channels": 512,
"decoder_upsample_scales": [8, 8, 2, 2],
"decoder_upsample_kernel_sizes": [16, 16, 4, 4],
"decoder_resblock_kernel_sizes": [3, 7, 11],
"decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"use_weight_norm_in_decoder": True,
"posterior_encoder_kernel_size": 5,
"posterior_encoder_layers": 16,
"posterior_encoder_stacks": 1,
"posterior_encoder_base_dilation": 1,
"posterior_encoder_dropout_rate": 0.0,
"use_weight_norm_in_posterior_encoder": True,
"flow_flows": 4,
"flow_kernel_size": 5,
"flow_base_dilation": 1,
"flow_layers": 4,
"flow_dropout_rate": 0.0,
"use_weight_norm_in_flow": True,
"use_only_mean_in_flow": True,
"stochastic_duration_predictor_kernel_size": 3,
"stochastic_duration_predictor_dropout_rate": 0.5,
"stochastic_duration_predictor_flows": 4,
"stochastic_duration_predictor_dds_conv_layers": 3,
},
# discriminator related
discriminator_type: str="hifigan_multi_scale_multi_period_discriminator",
discriminator_params: Dict[str, Any]={
"scales": 1,
"scale_downsample_pooling": "AvgPool1D",
"scale_downsample_pooling_params": {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
"scale_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "leakyrelu",
"nonlinear_activation_params": {
"negative_slope": 0.1
},
"use_weight_norm": True,
"use_spectral_norm": False,
},
"follow_official_norm": False,
"periods": [2, 3, 5, 7, 11],
"period_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "leakyrelu",
"nonlinear_activation_params": {
"negative_slope": 0.1
},
"use_weight_norm": True,
"use_spectral_norm": False,
},
},
cache_generator_outputs: bool=True,
init_type: str="xavier_uniform", ):
"""Initialize VITS module.
Args:
            idim (int): Input vocabulary size.
            odim (int): Acoustic feature dimension. The actual output channels will
                be 1 since VITS is an end-to-end text-to-wave model, but odim is kept
                to indicate the acoustic feature dimension for compatibility.
            sampling_rate (int): Sampling rate. It is not used for training, but it
                is referred to when saving the waveform during inference.
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
cache_generator_outputs (bool): Whether to cache generator outputs.
"""
assert check_argument_types()
super().__init__()
# initialize parameters
initialize(self, init_type)
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":
            # NOTE: Update parameters for compatibility.
            # idim and odim are automatically decided from the input data,
            # where idim represents the vocabulary size and odim represents
            # the input acoustic feature dimension.
generator_params.update(vocabs=idim, aux_channels=odim)
self.generator = generator_class(
**generator_params, )
discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
self.discriminator = discriminator_class(
**discriminator_params, )
nn.initializer.set_global_initializer(None)
# cache
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
# store sampling rate for saving wav file
# (not used for the training)
self.fs = sampling_rate
# store parameters for test compatibility
self.spks = self.generator.spks
self.langs = self.generator.langs
self.spk_embed_dim = self.generator.spk_embed_dim
self.reuse_cache_gen = True
self.reuse_cache_dis = True
def forward(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
forward_generator: bool=True, ) -> Dict[str, Any]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
forward_generator (bool): Whether to forward generator.
        Returns:
            Tuple: Outputs of the generator (or the cached outputs) as returned
                by `VITSGenerator.forward`; the losses are computed from them in
                the updater (see `VITSUpdater`).
"""
if forward_generator:
return self._forward_generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids, )
else:
            return self._forward_discriminator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids, )
def _forward_generator(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
        Returns:
            Tuple: Generator outputs (see `VITSGenerator.forward`); the losses
                are computed from these in the updater.
"""
# setup
feats = feats.transpose([0, 2, 1])
# calculate generator outputs
self.reuse_cache_gen = True
if not self.cache_generator_outputs or self._cache is None:
self.reuse_cache_gen = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids, )
else:
outs = self._cache
# store cache
if self.training and self.cache_generator_outputs and not self.reuse_cache_gen:
self._cache = outs
return outs
    def _forward_discriminator(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
"""Perform discriminator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
        Returns:
            Tuple: Generator outputs (or cached outputs) reused for the
                discriminator loss computation in the updater.
"""
# setup
feats = feats.transpose([0, 2, 1])
# calculate generator outputs
self.reuse_cache_dis = True
if not self.cache_generator_outputs or self._cache is None:
self.reuse_cache_dis = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids, )
else:
outs = self._cache
# store cache
if self.cache_generator_outputs and not self.reuse_cache_dis:
self._cache = outs
return outs
def inference(
self,
text: paddle.Tensor,
feats: Optional[paddle.Tensor]=None,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
durations: Optional[paddle.Tensor]=None,
noise_scale: float=0.667,
noise_scale_dur: float=0.8,
alpha: float=1.0,
max_len: Optional[int]=None,
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
"""Run inference.
Args:
text (Tensor): Input text index tensor (T_text,).
feats (Tensor): Feature tensor (T_feats, aux_channels).
sids (Tensor): Speaker index tensor (1,).
spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
lids (Tensor): Language index tensor (1,).
durations (Tensor): Ground-truth duration tensor (T_text,).
noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
Dict[str, Tensor]:
* wav (Tensor): Generated waveform tensor (T_wav,).
* att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
* duration (Tensor): Predicted duration tensor (T_text,).
"""
# setup
text = text[None]
text_lengths = paddle.to_tensor(paddle.shape(text)[1])
if durations is not None:
durations = paddle.reshape(durations, [1, 1, -1])
# inference
if use_teacher_forcing:
assert feats is not None
feats = feats[None].transpose([0, 2, 1])
feats_lengths = paddle.to_tensor([paddle.shape(feats)[2]])
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
max_len=max_len,
use_teacher_forcing=use_teacher_forcing, )
else:
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
sids=sids,
spembs=spembs,
lids=lids,
dur=durations,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
max_len=max_len, )
return dict(
wav=paddle.reshape(wav, [-1]), att_w=att_w[0], duration=dur[0])
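
# Illustrative usage sketch (not part of the original file): `vits_model` is
# assumed to be an already trained instance of the model class above, and
# `text_ids` a hypothetical phone-id sequence; the returned keys follow the
# `inference` docstring.
def synthesize_once(vits_model, text_ids: paddle.Tensor) -> paddle.Tensor:
    """Run single-utterance inference and return the generated waveform."""
    vits_model.eval()
    with paddle.no_grad():
        out = vits_model.inference(
            text=text_ids, noise_scale=0.667, noise_scale_dur=0.8, alpha=1.0)
    return out["wav"]
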
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from paddlespeech.t2s.modules.nets_utils import get_segments
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class VITSUpdater(StandardUpdater):
def __init__(self,
model: Layer,
optimizers: Dict[str, Optimizer],
criterions: Dict[str, Layer],
schedulers: Dict[str, LRScheduler],
dataloader: DataLoader,
generator_train_start_steps: int=0,
discriminator_train_start_steps: int=100000,
lambda_adv: float=1.0,
lambda_mel: float=45.0,
lambda_feat_match: float=2.0,
lambda_dur: float=1.0,
lambda_kl: float=1.0,
generator_first: bool=False,
output_dir=None):
# `models` is designed to hold multiple models
# a single model is passed in here, and the parent class's init() is not used,
# so this part has to be rewritten
models = {"main": model}
self.models: Dict[str, Layer] = models
# self.model = model
self.model = model._layers if isinstance(model, paddle.DataParallel) else model
self.optimizers = optimizers
self.optimizer_g: Optimizer = optimizers['generator']
self.optimizer_d: Optimizer = optimizers['discriminator']
self.criterions = criterions
self.criterion_mel = criterions['mel']
self.criterion_feat_match = criterions['feat_match']
self.criterion_gen_adv = criterions["gen_adv"]
self.criterion_dis_adv = criterions["dis_adv"]
self.criterion_kl = criterions["kl"]
self.schedulers = schedulers
self.scheduler_g = schedulers['generator']
self.scheduler_d = schedulers['discriminator']
self.dataloader = dataloader
self.generator_train_start_steps = generator_train_start_steps
self.discriminator_train_start_steps = discriminator_train_start_steps
self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel
self.lambda_feat_match = lambda_feat_match
self.lambda_dur = lambda_dur
self.lambda_kl = lambda_kl
if generator_first:
self.turns = ["generator", "discriminator"]
else:
self.turns = ["discriminator", "generator"]
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
for turn in self.turns:
speech = batch["speech"]
speech = speech.unsqueeze(1)
outs = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
feats=batch["feats"],
feats_lengths=batch["feats_lengths"],
forward_generator=turn == "generator")
# Generator
if turn == "generator":
# parse outputs
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
_, z_p, m_p, logs_p, _, logs_q = outs_
speech_ = get_segments(
x=speech,
start_idxs=start_idxs *
self.model.generator.upsample_factor,
segment_size=self.model.generator.segment_size *
self.model.generator.upsample_factor, )
# calculate discriminator outputs
p_hat = self.model.discriminator(speech_hat_)
with paddle.no_grad():
# do not store discriminator gradient in generator turn
p = self.model.discriminator(speech_)
# calculate losses
mel_loss = self.criterion_mel(speech_hat_, speech_)
kl_loss = self.criterion_kl(z_p, logs_q, m_p, logs_p, z_mask)
dur_loss = paddle.sum(dur_nll)
adv_loss = self.criterion_gen_adv(p_hat)
feat_match_loss = self.criterion_feat_match(p_hat, p)
mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl
dur_loss = dur_loss * self.lambda_dur
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
gen_loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
report("train/generator_loss", float(gen_loss))
report("train/generator_mel_loss", float(mel_loss))
report("train/generator_kl_loss", float(kl_loss))
report("train/generator_dur_loss", float(dur_loss))
report("train/generator_adv_loss", float(adv_loss))
report("train/generator_feat_match_loss",
float(feat_match_loss))
losses_dict["generator_loss"] = float(gen_loss)
losses_dict["generator_mel_loss"] = float(mel_loss)
losses_dict["generator_kl_loss"] = float(kl_loss)
losses_dict["generator_dur_loss"] = float(dur_loss)
losses_dict["generator_adv_loss"] = float(adv_loss)
losses_dict["generator_feat_match_loss"] = float(
feat_match_loss)
self.optimizer_g.clear_grad()
gen_loss.backward()
self.optimizer_g.step()
self.scheduler_g.step()
# reset cache
if self.model.reuse_cache_gen or not self.model.training:
self.model._cache = None
# Discriminator
elif turn == "discriminator":
# parse outputs
speech_hat_, _, _, start_idxs, *_ = outs
speech_ = get_segments(
x=speech,
start_idxs=start_idxs *
self.model.generator.upsample_factor,
segment_size=self.model.generator.segment_size *
self.model.generator.upsample_factor, )
# calculate discriminator outputs
p_hat = self.model.discriminator(speech_hat_.detach())
p = self.model.discriminator(speech_)
# calculate losses
real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
dis_loss = real_loss + fake_loss
report("train/real_loss", float(real_loss))
report("train/fake_loss", float(fake_loss))
report("train/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.optimizer_d.clear_grad()
dis_loss.backward()
self.optimizer_d.step()
self.scheduler_d.step()
# reset cache
if self.model.reuse_cache_dis or not self.model.training:
self.model._cache = None
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class VITSEvaluator(StandardEvaluator):
def __init__(self,
model,
criterions: Dict[str, Layer],
dataloader: DataLoader,
lambda_adv: float=1.0,
lambda_mel: float=45.0,
lambda_feat_match: float=2.0,
lambda_dur: float=1.0,
lambda_kl: float=1.0,
generator_first: bool=False,
output_dir=None):
# a single model is passed in, and the parent class's init() is not used, so this part has to be rewritten
models = {"main": model}
self.models: Dict[str, Layer] = models
# self.model = model
self.model = model._layers if isinstance(model, paddle.DataParallel) else model
self.criterions = criterions
self.criterion_mel = criterions['mel']
self.criterion_feat_match = criterions['feat_match']
self.criterion_gen_adv = criterions["gen_adv"]
self.criterion_dis_adv = criterions["dis_adv"]
self.criterion_kl = criterions["kl"]
self.dataloader = dataloader
self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel
self.lambda_feat_match = lambda_feat_match
self.lambda_dur = lambda_dur
self.lambda_kl = lambda_kl
if generator_first:
self.turns = ["generator", "discriminator"]
else:
self.turns = ["discriminator", "generator"]
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
# logging.debug("Evaluate: ")
self.msg = "Evaluate: "
losses_dict = {}
for turn in self.turns:
speech = batch["speech"]
speech = speech.unsqueeze(1)
outs = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
feats=batch["feats"],
feats_lengths=batch["feats_lengths"],
forward_generator=turn == "generator")
# Generator
if turn == "generator":
# parse outputs
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
_, z_p, m_p, logs_p, _, logs_q = outs_
speech_ = get_segments(
x=speech,
start_idxs=start_idxs *
self.model.generator.upsample_factor,
segment_size=self.model.generator.segment_size *
self.model.generator.upsample_factor, )
# calculate discriminator outputs
p_hat = self.model.discriminator(speech_hat_)
with paddle.no_grad():
# do not store discriminator gradient in generator turn
p = self.model.discriminator(speech_)
# calculate losses
mel_loss = self.criterion_mel(speech_hat_, speech_)
kl_loss = self.criterion_kl(z_p, logs_q, m_p, logs_p, z_mask)
dur_loss = paddle.sum(dur_nll)
adv_loss = self.criterion_gen_adv(p_hat)
feat_match_loss = self.criterion_feat_match(p_hat, p)
mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl
dur_loss = dur_loss * self.lambda_dur
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
gen_loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
report("eval/generator_loss", float(gen_loss))
report("eval/generator_mel_loss", float(mel_loss))
report("eval/generator_kl_loss", float(kl_loss))
report("eval/generator_dur_loss", float(dur_loss))
report("eval/generator_adv_loss", float(adv_loss))
report("eval/generator_feat_match_loss", float(feat_match_loss))
losses_dict["generator_loss"] = float(gen_loss)
losses_dict["generator_mel_loss"] = float(mel_loss)
losses_dict["generator_kl_loss"] = float(kl_loss)
losses_dict["generator_dur_loss"] = float(dur_loss)
losses_dict["generator_adv_loss"] = float(adv_loss)
losses_dict["generator_feat_match_loss"] = float(
feat_match_loss)
# reset cache
if self.model.reuse_cache_gen or not self.model.training:
self.model._cache = None
# Discriminator
elif turn == "discriminator":
# parse outputs
speech_hat_, _, _, start_idxs, *_ = outs
speech_ = get_segments(
x=speech,
start_idxs=start_idxs *
self.model.generator.upsample_factor,
segment_size=self.model.generator.segment_size *
self.model.generator.upsample_factor, )
# calculate discriminator outputs
p_hat = self.model.discriminator(speech_hat_.detach())
p = self.model.discriminator(speech_)
# calculate losses
real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
dis_loss = real_loss + fake_loss
report("eval/real_loss", float(real_loss))
report("eval/fake_loss", float(fake_loss))
report("eval/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
# reset cache
if self.model.reuse_cache_dis or not self.model.training:
self.model._cache = None
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import math
from typing import Optional
from typing import Tuple
import paddle
import paddle.nn.functional as F
from paddle import nn
class ResidualBlock(nn.Layer):
"""Residual block module in WaveNet."""
def __init__(
self,
kernel_size: int=3,
residual_channels: int=64,
gate_channels: int=128,
skip_channels: int=64,
aux_channels: int=80,
global_channels: int=-1,
dropout_rate: float=0.0,
dilation: int=1,
bias: bool=True,
scale_residual: bool=False, ):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
residual_channels (int): Number of channels for residual connection.
gate_channels (int): Number of channels for the gated activation.
skip_channels (int): Number of channels for skip connection.
aux_channels (int): Number of local conditioning channels.
global_channels (int): Number of global conditioning channels.
dropout_rate (float): Dropout probability.
dilation (int): Dilation factor.
bias (bool): Whether to add bias parameter in convolution layers.
scale_residual (bool): Whether to scale the residual outputs.
"""
super().__init__()
self.dropout_rate = dropout_rate
self.residual_channels = residual_channels
self.skip_channels = skip_channels
self.scale_residual = scale_residual
# check
assert (
kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
assert gate_channels % 2 == 0
# dilation conv
padding = (kernel_size - 1) // 2 * dilation
self.conv = nn.Conv1D(
residual_channels,
gate_channels,
kernel_size,
padding=padding,
dilation=dilation,
bias_attr=bias, )
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = nn.Conv1D(
aux_channels, gate_channels, kernel_size=1, bias_attr=False)
else:
self.conv1x1_aux = None
# global conditioning
if global_channels > 0:
self.conv1x1_glo = nn.Conv1D(
global_channels, gate_channels, kernel_size=1, bias_attr=False)
else:
self.conv1x1_glo = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
# NOTE: the res 1x1 and skip 1x1 convs are concatenated into a single conv for efficiency
self.conv1x1_out = nn.Conv1D(
gate_out_channels,
residual_channels + skip_channels,
kernel_size=1,
bias_attr=bias)
def forward(
self,
x: paddle.Tensor,
x_mask: Optional[paddle.Tensor]=None,
c: Optional[paddle.Tensor]=None,
g: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor for residual connection (B, residual_channels, T).
Tensor: Output tensor for skip connection (B, skip_channels, T).
"""
residual = x
x = F.dropout(x, p=self.dropout_rate, training=self.training)
x = self.conv(x)
# split into two part for gated activation
splitdim = 1
xa, xb = paddle.split(x, 2, axis=splitdim)
# local conditioning
if c is not None:
c = self.conv1x1_aux(c)
ca, cb = paddle.split(c, 2, axis=splitdim)
xa, xb = xa + ca, xb + cb
# global conditioning
if g is not None:
g = self.conv1x1_glo(g)
ga, gb = paddle.split(g, 2, axis=splitdim)
xa, xb = xa + ga, xb + gb
x = paddle.tanh(xa) * F.sigmoid(xb)
# residual + skip 1x1 conv
x = self.conv1x1_out(x)
if x_mask is not None:
x = x * x_mask
# split integrated conv results
x, s = paddle.split(
x, [self.residual_channels, self.skip_channels], axis=1)
# for residual connection
x = x + residual
if self.scale_residual:
x = x * math.sqrt(0.5)
return x, s
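
# Illustrative usage sketch (not part of the original file): shapes follow the
# forward() docstring above; the tensors are random placeholders.
if __name__ == "__main__":
    block = ResidualBlock(
        kernel_size=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80)
    x = paddle.randn([2, 64, 100])  # (B, residual_channels, T)
    c = paddle.randn([2, 80, 100])  # (B, aux_channels, T)
    res, skip = block(x, c=c)       # both outputs are (2, 64, 100)
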
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import math
from typing import Optional
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.wavenet.residual_block import ResidualBlock
class WaveNet(nn.Layer):
"""WaveNet with global conditioning."""
def __init__(
self,
in_channels: int=1,
out_channels: int=1,
kernel_size: int=3,
layers: int=30,
stacks: int=3,
base_dilation: int=2,
residual_channels: int=64,
aux_channels: int=-1,
gate_channels: int=128,
skip_channels: int=64,
global_channels: int=-1,
dropout_rate: float=0.0,
bias: bool=True,
use_weight_norm: bool=True,
use_first_conv: bool=False,
use_last_conv: bool=False,
scale_residual: bool=False,
scale_skip_connect: bool=False, ):
"""Initialize WaveNet module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Kernel size of dilated convolution.
layers (int): Number of residual block layers.
stacks (int): Number of stacks i.e., dilation cycles.
base_dilation (int): Base dilation factor.
residual_channels (int): Number of channels in residual conv.
gate_channels (int): Number of channels in gated conv.
skip_channels (int): Number of channels in skip conv.
aux_channels (int): Number of channels for local conditioning feature.
global_channels (int): Number of channels for global conditioning feature.
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
bias (bool): Whether to use bias parameter in conv layer.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
use_first_conv (bool): Whether to use the first conv layers.
use_last_conv (bool): Whether to use the last conv layers.
scale_residual (bool): Whether to scale the residual outputs.
scale_skip_connect (bool): Whether to scale the skip connection outputs.
"""
super().__init__()
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
self.base_dilation = base_dilation
self.use_first_conv = use_first_conv
self.use_last_conv = use_last_conv
self.scale_skip_connect = scale_skip_connect
# check the number of layers and stacks
assert layers % stacks == 0
layers_per_stack = layers // stacks
# define first convolution
if self.use_first_conv:
self.first_conv = nn.Conv1D(
in_channels, residual_channels, kernel_size=1, bias_attr=True)
# define residual blocks
self.conv_layers = nn.LayerList()
for layer in range(layers):
dilation = base_dilation**(layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
global_channels=global_channels,
dilation=dilation,
dropout_rate=dropout_rate,
bias=bias,
scale_residual=scale_residual, )
self.conv_layers.append(conv)
# define output layers
if self.use_last_conv:
self.last_conv = nn.Sequential(
nn.ReLU(),
nn.Conv1D(
skip_channels, skip_channels, kernel_size=1,
bias_attr=True),
nn.ReLU(),
nn.Conv1D(
skip_channels, out_channels, kernel_size=1, bias_attr=True),
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(
self,
x: paddle.Tensor,
x_mask: Optional[paddle.Tensor]=None,
c: Optional[paddle.Tensor]=None,
g: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
(B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
(B, residual_channels, T).
"""
# encode to hidden representation
if self.use_first_conv:
x = self.first_conv(x)
# residual block
skips = 0.0
for f in self.conv_layers:
x, h = f(x, x_mask=x_mask, c=c, g=g)
skips = skips + h
x = skips
if self.scale_skip_connect:
x = x * math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
if self.use_last_conv:
x = self.last_conv(x)
return x
def apply_weight_norm(self):
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
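
# Illustrative usage sketch (not part of the original file): with use_first_conv
# and use_last_conv left at their defaults (False), the module maps an already
# projected (B, residual_channels, T) tensor to a (B, skip_channels, T) tensor.
if __name__ == "__main__":
    wavenet = WaveNet(
        layers=6,
        stacks=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64)
    y = wavenet(paddle.randn([2, 64, 100]))  # (2, 64, 100)
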
@@ -17,7 +17,6 @@ import librosa
import numpy as np
import paddle
from paddle import nn
-from paddle.fluid.layers import sequence_mask
from paddle.nn import functional as F
from scipy import signal
@@ -160,7 +159,7 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
return x
-# Loss for new Tacotron2
+# Loss for Tacotron2
class GuidedAttentionLoss(nn.Layer):
"""Guided attention loss function module.
@@ -428,41 +427,6 @@ class Tacotron2Loss(nn.Layer):
return l1_loss, mse_loss, bce_loss
# Loss for Tacotron2
def attention_guide(dec_lens, enc_lens, N, T, g, dtype=None):
"""Build that W matrix. shape(B, T_dec, T_enc)
W[i, n, t] = 1 - exp(-(n/dec_lens[i] - t/enc_lens[i])**2 / (2g**2))
See also:
Tachibana, Hideyuki, Katsuya Uenoyama, and Shunsuke Aihara. 2017. “Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention.” ArXiv:1710.08969 [Cs, Eess], October. http://arxiv.org/abs/1710.08969.
"""
dtype = dtype or paddle.get_default_dtype()
dec_pos = paddle.arange(0, N).astype(dtype) / dec_lens.unsqueeze(
-1) # n/N # shape(B, T_dec)
enc_pos = paddle.arange(0, T).astype(dtype) / enc_lens.unsqueeze(
-1) # t/T # shape(B, T_enc)
W = 1 - paddle.exp(-(dec_pos.unsqueeze(-1) - enc_pos.unsqueeze(1))**2 /
(2 * g**2))
dec_mask = sequence_mask(dec_lens, maxlen=N)
enc_mask = sequence_mask(enc_lens, maxlen=T)
mask = dec_mask.unsqueeze(-1) * enc_mask.unsqueeze(1)
mask = paddle.cast(mask, W.dtype)
W *= mask
return W
def guided_attention_loss(attention_weight, dec_lens, enc_lens, g):
"""Guided attention loss, masked to excluded padding parts."""
_, N, T = attention_weight.shape
W = attention_guide(dec_lens, enc_lens, N, T, g, attention_weight.dtype)
total_tokens = (dec_lens * enc_lens).astype(W.dtype)
loss = paddle.mean(paddle.sum(W * attention_weight, [1, 2]) / total_tokens)
return loss
# Losses for GAN Vocoder
def stft(x,
fft_size,
@@ -1006,3 +970,40 @@ class FeatureMatchLoss(nn.Layer):
feat_match_loss /= i + 1
return feat_match_loss
# loss for VITS
class KLDivergenceLoss(nn.Layer):
"""KL divergence loss."""
def forward(
self,
z_p: paddle.Tensor,
logs_q: paddle.Tensor,
m_p: paddle.Tensor,
logs_p: paddle.Tensor,
z_mask: paddle.Tensor,
) -> paddle.Tensor:
"""Calculate KL divergence loss.
Args:
z_p (Tensor): Flow hidden representation (B, H, T_feats).
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
z_mask (Tensor): Mask tensor (B, 1, T_feats).
Returns:
Tensor: KL divergence loss.
"""
z_p = paddle.cast(z_p, 'float32')
logs_q = paddle.cast(logs_q, 'float32')
m_p = paddle.cast(m_p, 'float32')
logs_p = paddle.cast(logs_p, 'float32')
z_mask = paddle.cast(z_mask, 'float32')
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * paddle.exp(-2.0 * logs_p)
kl = paddle.sum(kl * z_mask)
loss = kl / paddle.sum(z_mask)
return loss
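
# Illustrative usage sketch (not part of the original file): random tensors with
# the shapes documented above; in VITS training these come from the generator.
if __name__ == "__main__":
    B, H, T_feats = 2, 192, 50
    criterion_kl = KLDivergenceLoss()
    kl = criterion_kl(
        z_p=paddle.randn([B, H, T_feats]),
        logs_q=paddle.randn([B, H, T_feats]),
        m_p=paddle.randn([B, H, T_feats]),
        logs_p=paddle.randn([B, H, T_feats]),
        z_mask=paddle.ones([B, 1, T_feats]))
    print(float(kl))
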
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
+from typing import Tuple
import paddle
from paddle import nn
from typeguard import check_argument_types
@@ -129,3 +131,66 @@ def initialize(model: nn.Layer, init: str):
nn.initializer.Constant())
else:
raise ValueError("Unknown initialization: " + init)
# for VITS
def get_random_segments(
x: paddle.Tensor,
x_lengths: paddle.Tensor,
segment_size: int, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Get random segments.
Args:
x (Tensor): Input tensor (B, C, T).
x_lengths (Tensor): Length tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
Tensor: Start index tensor (B,).
"""
b, c, t = paddle.shape(x)
max_start_idx = x_lengths - segment_size
start_idxs = paddle.cast(paddle.rand([b]) * max_start_idx, 'int64')
segments = get_segments(x, start_idxs, segment_size)
return segments, start_idxs
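
# Illustrative usage sketch (not part of the original file): draw random 4-frame
# windows from a padded (B, C, T) batch while respecting per-utterance lengths.
def _demo_get_random_segments():
    x = paddle.randn([2, 3, 10])
    x_lengths = paddle.to_tensor([10, 8], dtype='int64')
    segments, start_idxs = get_random_segments(x, x_lengths, segment_size=4)
    return segments, start_idxs  # (2, 3, 4) and (2,)
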
def get_segments(
x: paddle.Tensor,
start_idxs: paddle.Tensor,
segment_size: int, ) -> paddle.Tensor:
"""Get segments.
Args:
x (Tensor): Input tensor (B, C, T).
start_idxs (Tensor): Start index tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
"""
b, c, t = paddle.shape(x)
segments = paddle.zeros([b, c, segment_size], dtype=x.dtype)
for i, start_idx in enumerate(start_idxs):
segments[i] = x[i, :, start_idx:start_idx + segment_size]
return segments
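
# Illustrative usage sketch (not part of the original file): cut fixed windows
# out of a (B, C, T) tensor given explicit per-sample start indices.
def _demo_get_segments():
    x = paddle.arange(2 * 3 * 10, dtype='float32').reshape([2, 3, 10])
    start_idxs = paddle.to_tensor([0, 5], dtype='int64')
    return get_segments(x, start_idxs, segment_size=4)  # (2, 3, 4)
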
# see https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
def paddle_gather(x, dim, index):
index_shape = index.shape
index_flatten = index.flatten()
if dim < 0:
dim = len(x.shape) + dim
nd_index = []
for k in range(len(x.shape)):
if k == dim:
nd_index.append(index_flatten)
else:
reshape_shape = [1] * len(x.shape)
reshape_shape[k] = x.shape[k]
x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
x_arange = x_arange.reshape(reshape_shape)
dim_index = paddle.expand(x_arange, index_shape).flatten()
nd_index.append(dim_index)
ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
return paddle_out
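
# Illustrative usage sketch (not part of the original file): paddle_gather
# mirrors torch.gather, picking index[i][j] along the given dim.
def _demo_paddle_gather():
    src = paddle.to_tensor([[1., 2.], [3., 4.]])
    index = paddle.to_tensor([[0, 0], [1, 0]], dtype='int64')
    return paddle_gather(src, dim=1, index=index)  # [[1., 1.], [4., 3.]]
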
@@ -36,4 +36,4 @@ def repeat(N, fn):
Returns:
MultiSequential: Repeated model instance.
"""
-return MultiSequential(*[fn(n) for n in range(N)])
+return MultiSequential(* [fn(n) for n in range(N)])
@@ -14,6 +14,14 @@
import paddle
from paddle import nn
scheduler_classes = dict(
ReduceOnPlateau=paddle.optimizer.lr.ReduceOnPlateau,
lambda_decay=paddle.optimizer.lr.LambdaDecay,
step_decay=paddle.optimizer.lr.StepDecay,
multistep_decay=paddle.optimizer.lr.MultiStepDecay,
exponential_decay=paddle.optimizer.lr.ExponentialDecay,
CosineAnnealingDecay=paddle.optimizer.lr.CosineAnnealingDecay, )
optim_classes = dict(
adadelta=paddle.optimizer.Adadelta,
adagrad=paddle.optimizer.Adagrad,
...
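
# Illustrative usage sketch (not part of the original file): look up a scheduler
# class from the registry added above by its config name; the arguments shown
# are ordinary paddle.optimizer.lr.ExponentialDecay parameters.
def _demo_build_scheduler():
    scheduler_class = scheduler_classes["exponential_decay"]
    return scheduler_class(learning_rate=2e-4, gamma=0.999875)
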
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
import six
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--profile_path',
type=str,
default='',
help='Input profile file name. If there are multiple file, the format '
'should be trainer1=file1,trainer2=file2,ps=file3')
parser.add_argument(
'--timeline_path', type=str, default='', help='Output timeline file name.')
args = parser.parse_args()
class _ChromeTraceFormatter(object):
def __init__(self):
self._events = []
self._metadata = []
def _create_event(self, ph, category, name, pid, tid, timestamp):
"""Creates a new Chrome Trace event.
For details of the file format, see:
https://github.com/catapult-project/catapult/blob/master/tracing/README.md
Args:
ph: The type of event - usually a single character.
category: The event category as a string.
name: The event name as a string.
pid: Identifier of the process generating this event as an integer.
tid: Identifier of the thread generating this event as an integer.
timestamp: The timestamp of this event as a long integer.
Returns:
A JSON compatible event object.
"""
event = {}
event['ph'] = ph
event['cat'] = category
event['name'] = name.replace("ParallelExecutor::Run/", "")
event['pid'] = pid
event['tid'] = tid
event['ts'] = timestamp
return event
def emit_pid(self, name, pid):
"""Adds a process metadata event to the trace.
Args:
name: The process name as a string.
pid: Identifier of the process as an integer.
"""
event = {}
event['name'] = 'process_name'
event['ph'] = 'M'
event['pid'] = pid
event['args'] = {'name': name}
self._metadata.append(event)
def emit_region(self, timestamp, duration, pid, tid, category, name, args):
"""Adds a region event to the trace.
Args:
timestamp: The start timestamp of this region as a long integer.
duration: The duration of this region as a long integer.
pid: Identifier of the process generating this event as an integer.
tid: Identifier of the thread generating this event as an integer.
category: The event category as a string.
name: The event name as a string.
args: A JSON-compatible dictionary of event arguments.
"""
event = self._create_event('X', category, name, pid, tid, timestamp)
event['dur'] = duration
event['args'] = args
self._events.append(event)
def emit_counter(self, category, name, pid, timestamp, counter, value):
"""Emits a record for a single counter.
Args:
category: The event category as string
name: The event name as string
pid: Identifier of the process generating this event as integer
timestamp: The timestamps of this event as long integer
counter: Name of the counter as string
value: Value of the counter as integer
"""
event = self._create_event('C', category, name, pid, 0, timestamp)
event['args'] = {counter: value}
self._events.append(event)
def format_to_string(self, pretty=False):
"""Formats the chrome trace to a string.
Args:
pretty: (Optional.) If True, produce human-readable JSON output.
Returns:
A JSON-formatted string in Chrome Trace format.
"""
trace = {}
trace['traceEvents'] = self._metadata + self._events
if pretty:
return json.dumps(trace, indent=4, separators=(',', ': '))
else:
return json.dumps(trace, separators=(',', ':'))
class Timeline(object):
def __init__(self, profile_dict):
self._profile_dict = profile_dict
self._pid = 0
self._devices = dict()
self._mem_devices = dict()
self._chrome_trace = _ChromeTraceFormatter()
def _allocate_pid(self):
cur_pid = self._pid
self._pid += 1
return cur_pid
def _allocate_pids(self):
for k, profile_pb in six.iteritems(self._profile_dict):
for event in profile_pb.events:
if event.type == profiler_pb2.Event.CPU:
if (k, event.device_id, "CPU") not in self._devices:
pid = self._allocate_pid()
self._devices[(k, event.device_id, "CPU")] = pid
# -1 device id represents CUDA API(RunTime) call.(e.g. cudaLaunch, cudaMemcpy)
if event.device_id == -1:
self._chrome_trace.emit_pid("%s:cuda_api" % k, pid)
else:
self._chrome_trace.emit_pid(
"%s:cpu:block:%d" % (k, event.device_id), pid)
elif event.type == profiler_pb2.Event.GPUKernel:
if (k, event.device_id, "GPUKernel") not in self._devices:
pid = self._allocate_pid()
self._devices[(k, event.device_id, "GPUKernel")] = pid
self._chrome_trace.emit_pid("%s:gpu:%d" %
(k, event.device_id), pid)
if not hasattr(profile_pb, "mem_events"):
continue
for mevent in profile_pb.mem_events:
if mevent.place == profiler_pb2.MemEvent.CUDAPlace:
if (k, mevent.device_id, "GPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, mevent.device_id, "GPU")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:gpu:%d" % (k, mevent.device_id),
pid)
elif mevent.place == profiler_pb2.MemEvent.CPUPlace:
if (k, mevent.device_id, "CPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, mevent.device_id, "CPU")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:cpu:%d" % (k, mevent.device_id),
pid)
elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace:
if (k, mevent.device_id,
"CUDAPinnedPlace") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, mevent.device_id,
"CUDAPinnedPlace")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:cudapinnedplace:%d" %
(k, mevent.device_id), pid)
elif mevent.place == profiler_pb2.MemEvent.NPUPlace:
if (k, mevent.device_id, "NPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, mevent.device_id, "NPU")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:npu:%d" % (k, mevent.device_id),
pid)
if (k, 0, "CPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, 0, "CPU")] = pid
self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" %
(k, 0), pid)
if (k, 0, "GPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, 0, "GPU")] = pid
self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" %
(k, 0), pid)
if (k, 0, "CUDAPinnedPlace") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, 0, "CUDAPinnedPlace")] = pid
self._chrome_trace.emit_pid(
"memory usage on %s:cudapinnedplace:%d" % (k, 0), pid)
if (k, 0, "NPU") not in self._mem_devices:
pid = self._allocate_pid()
self._mem_devices[(k, 0, "NPU")] = pid
self._chrome_trace.emit_pid("memory usage on %s:npu:%d" %
(k, 0), pid)
def _allocate_events(self):
for k, profile_pb in six.iteritems(self._profile_dict):
for event in profile_pb.events:
if event.type == profiler_pb2.Event.CPU:
type = "CPU"
elif event.type == profiler_pb2.Event.GPUKernel:
type = "GPUKernel"
pid = self._devices[(k, event.device_id, type)]
args = {'name': event.name}
if event.memcopy.bytes > 0:
args['mem_bytes'] = event.memcopy.bytes
if hasattr(event, "detail_info") and event.detail_info:
args['detail_info'] = event.detail_info
# TODO(panyx0718): Chrome tracing only handles ms. However, some
# ops take micro-seconds. Hence, we keep the ns here.
self._chrome_trace.emit_region(
event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
event.sub_device_id, 'Op', event.name, args)
def _allocate_memory_event(self):
if not hasattr(profiler_pb2, "MemEvent"):
return
place_to_str = {
profiler_pb2.MemEvent.CPUPlace: "CPU",
profiler_pb2.MemEvent.CUDAPlace: "GPU",
profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace",
profiler_pb2.MemEvent.NPUPlace: "NPU"
}
for k, profile_pb in six.iteritems(self._profile_dict):
mem_list = []
end_profiler = 0
for mevent in profile_pb.mem_events:
crt_info = dict()
crt_info['time'] = mevent.start_ns
crt_info['size'] = mevent.bytes
if mevent.place in place_to_str:
place = place_to_str[mevent.place]
else:
place = "UnDefine"
crt_info['place'] = place
pid = self._mem_devices[(k, mevent.device_id, place)]
crt_info['pid'] = pid
crt_info['thread_id'] = mevent.thread_id
crt_info['device_id'] = mevent.device_id
mem_list.append(crt_info)
crt_info = dict()
crt_info['place'] = place
crt_info['pid'] = pid
crt_info['thread_id'] = mevent.thread_id
crt_info['device_id'] = mevent.device_id
crt_info['time'] = mevent.end_ns
crt_info['size'] = -mevent.bytes
mem_list.append(crt_info)
end_profiler = max(end_profiler, crt_info['time'])
mem_list.sort(key=lambda tmp: (tmp.get('time', 0)))
i = 0
total_size = 0
while i < len(mem_list):
total_size += mem_list[i]['size']
while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[
i + 1]['time']:
total_size += mem_list[i + 1]['size']
i += 1
self._chrome_trace.emit_counter(
"Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'],
0, total_size)
i += 1
def generate_chrome_trace(self):
self._allocate_pids()
self._allocate_events()
self._allocate_memory_event()
return self._chrome_trace.format_to_string()
profile_path = '/tmp/profile'
if args.profile_path:
profile_path = args.profile_path
timeline_path = '/tmp/timeline'
if args.timeline_path:
timeline_path = args.timeline_path
profile_paths = profile_path.split(',')
profile_dict = dict()
if len(profile_paths) == 1:
with open(profile_path, 'rb') as f:
profile_s = f.read()
profile_pb = profiler_pb2.Profile()
profile_pb.ParseFromString(profile_s)
profile_dict['trainer'] = profile_pb
else:
for profile_path in profile_paths:
k, v = profile_path.split('=')
with open(v, 'rb') as f:
profile_s = f.read()
profile_pb = profiler_pb2.Profile()
profile_pb.ParseFromString(profile_s)
profile_dict[k] = profile_pb
tl = Timeline(profile_dict)
with open(timeline_path, 'w') as f:
f.write(tl.generate_chrome_trace())
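
# Illustrative usage sketch (not part of the original file): build a tiny trace
# by hand with _ChromeTraceFormatter; the timestamp/duration are passed through
# unchanged, just as _allocate_events passes the profiler's raw values.
def _demo_chrome_trace_formatter():
    fmt = _ChromeTraceFormatter()
    fmt.emit_pid("trainer:cpu:block:0", 0)
    fmt.emit_region(
        timestamp=1000, duration=500, pid=0, tid=0,
        category='Op', name='matmul', args={'name': 'matmul'})
    return fmt.format_to_string(pretty=True)
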
@@ -11,24 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import paddle
from paddle.framework import core
from paddle.framework import CUDAPlace
def synchronize():
"""Trigger cuda synchronization for better timing."""
place = paddle.fluid.framework._current_expected_place()
if isinstance(place, CUDAPlace):
paddle.fluid.core._cuda_synchronize(place)
@contextmanager
def nvtx_span(name):
try:
core.nvprof_nvtx_push(name)
yield
finally:
core.nvprof_nvtx_pop()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import importlib
__all__ = ["dynamic_import"]
def dynamic_import(import_path, alias=dict()):
"""dynamic import module and class
:param str import_path: syntax 'module_name:class_name'
e.g., 'paddlespeech.s2t.models.u2:U2Model'
:param dict alias: shortcut for registered class
:return: imported class
"""
if import_path not in alias and ":" not in import_path:
raise ValueError(
"import_path should be one of {} or "
'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
"{}".format(set(alias), import_path))
if ":" not in import_path:
import_path = alias[import_path]
module_name, objname = import_path.split(":")
m = importlib.import_module(module_name)
return getattr(m, objname)
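
# Illustrative usage sketch (not part of the original file): the module:class
# path is the one given in the docstring above; importing it requires
# paddlespeech to be installed in the current environment.
if __name__ == "__main__":
    U2Model = dynamic_import("paddlespeech.s2t.models.u2:U2Model")
    print(U2Model)
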
@@ -98,7 +98,6 @@ requirements = {
}
def check_call(cmd: str, shell=False, executable=None):
try:
sp.check_call(
@@ -112,6 +111,7 @@ def check_call(cmd: str, shell=False, executable=None):
file=sys.stderr)
raise e
def check_output(cmd: str, shell=False):
try:
out_bytes = sp.check_output(cmd.split())
@@ -146,6 +146,7 @@ def _remove(files: str):
for f in files:
f.unlink()
################################# Install ##################################
@@ -308,6 +309,5 @@ setup_info = dict(
]
})
with version_info():
setup(**setup_info)
@@ -24,35 +24,36 @@ trainer_list=$(func_parser_value "${lines[14]}")
if [ ${MODE} = "benchmark_train" ];then
curPath=$(readlink -f "$(dirname "$0")")
-echo "curPath:"${curPath}
+echo "curPath:"${curPath}  # /PaddleSpeech/tests/test_tipc
cd ${curPath}/../..
+echo "------------- install for speech "
apt-get install libsndfile1 -y
+pip install yacs -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install pytest-runner -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install kaldiio -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install setuptools_scm -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install . -i https://pypi.tuna.tsinghua.edu.cn/simple
+pip install jsonlines
+pip list
cd -
if [ ${model_name} == "conformer" ]; then
# set the URL for aishell_tiny dataset
-URL=${conformer_data_URL:-"None"}
-echo "URL:"${URL}
-if [ ${URL} == 'None' ];then
+conformer_aishell_URL=${conformer_aishell_URL:-"None"}
+if [ ${conformer_aishell_URL} == 'None' ];then
echo "please contact author to get the URL.\n"
exit
else
-wget -P ${curPath}/../../dataset/aishell/ ${URL}
+rm -rf ${curPath}/../../dataset/aishell/aishell.py
+rm -rf ${curPath}/../../dataset/aishell/data_aishell_tiny*
+wget -P ${curPath}/../../dataset/aishell/ ${conformer_aishell_URL}
fi
-sed -i "s#^URL_ROOT_TAG#URL_ROOT = '${URL}'#g" ${curPath}/conformer/scripts/aishell_tiny.py
-cp ${curPath}/conformer/scripts/aishell_tiny.py ${curPath}/../../dataset/aishell/
cd ${curPath}/../../examples/aishell/asr1
-source path.sh
-# download audio data
-sed -i "s#aishell.py#aishell_tiny.py#g" ./local/data.sh
+# Prepare the data
sed -i "s#python3#python#g" ./local/data.sh
-bash ./local/data.sh || exit -1
-if [ $? -ne 0 ]; then
-exit 1
-fi
+bash run.sh --stage 0 --stop_stage 0  # the first run occasionally reports an error
+bash run.sh --stage 0 --stop_stage 0
mkdir -p ${curPath}/conformer/benchmark_train/
cp -rf conf ${curPath}/conformer/benchmark_train/
cp -rf data ${curPath}/conformer/benchmark_train/
...
@@ -20,7 +20,6 @@ of each audio file in the data set.
"""
import argparse
import codecs
-import json
import os
from pathlib import Path
@@ -89,7 +88,7 @@ def create_manifest(data_dir, manifest_path_prefix):
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(audio_path)
-reference_lines.append(str(total_num+1) + "\t" + text)
+reference_lines.append(str(total_num + 1) + "\t" + text)
total_sec += duration
total_text += len(text)
@@ -106,6 +105,7 @@ def create_manifest(data_dir, manifest_path_prefix):
manifest_dir = os.path.dirname(manifest_path_prefix)
def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
...