提交 05d41523 编写于 作者: H huangyuxin

Merge branch 'develop' into webdataset

([简体中文](./README_cn.md)|English) ([简体中文](./README_cn.md)|English)
<p align="center"> <p align="center">
<img src="./docs/images/PaddleSpeech_logo.png" /> <img src="./docs/images/PaddleSpeech_logo.png" />
...@@ -494,6 +495,14 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r ...@@ -494,6 +495,14 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
<a href = "./examples/aishell3/vc1">ge2e-fastspeech2-aishell3</a> <a href = "./examples/aishell3/vc1">ge2e-fastspeech2-aishell3</a>
</td> </td>
</tr> </tr>
<tr>
<td rowspan="3">End-to-End</td>
<td>VITS</td>
<td >CSMSC</td>
<td>
<a href = "./examples/csmsc/vits">VITS-csmsc</a>
</td>
</tr>
</tbody> </tbody>
</table> </table>
......
(简体中文|[English](./README.md)) (简体中文|[English](./README.md))
<p align="center"> <p align="center">
<img src="./docs/images/PaddleSpeech_logo.png" /> <img src="./docs/images/PaddleSpeech_logo.png" />
...@@ -481,6 +482,15 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 ...@@ -481,6 +482,15 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
<a href = "./examples/aishell3/vc1">ge2e-fastspeech2-aishell3</a> <a href = "./examples/aishell3/vc1">ge2e-fastspeech2-aishell3</a>
</td> </td>
</tr> </tr>
</tr>
<tr>
<td rowspan="3">端到端</td>
<td>VITS</td>
<td >CSMSC</td>
<td>
<a href = "./examples/csmsc/vits">VITS-csmsc</a>
</td>
</tr>
</tbody> </tbody>
</table> </table>
......
# [Aidatatang_200zh](http://www.openslr.org/62/) # [Aidatatang_200zh](http://openslr.elda.org/62/)
Aidatatang_200zh is a free Chinese Mandarin speech corpus provided by Beijing DataTang Technology Co., Ltd under Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License. Aidatatang_200zh is a free Chinese Mandarin speech corpus provided by Beijing DataTang Technology Co., Ltd under Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License.
The contents and the corresponding descriptions of the corpus include: The contents and the corresponding descriptions of the corpus include:
......
# [Aishell1](http://www.openslr.org/33/) # [Aishell1](http://openslr.elda.org/33/)
This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long. It is a part of AISHELL-ASR0009, of which utterance contains 11 domains, including smart home, autonomous driving, and industrial production. The whole recording was put in quiet indoor environment, using 3 different devices at the same time: high fidelity microphone (44.1kHz, 16-bit,); Android-system mobile phone (16kHz, 16-bit), iOS-system mobile phone (16kHz, 16-bit). Audios in high fidelity were re-sampled to 16kHz to build AISHELL- ASR0009-OS1. 400 speakers from different accent areas in China were invited to participate in the recording. The manual transcription accuracy rate is above 95%, through professional speech annotation and strict quality inspection. The corpus is divided into training, development and testing sets. ( This database is free for academic research, not in the commerce, if without permission. ) This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long. It is a part of AISHELL-ASR0009, of which utterance contains 11 domains, including smart home, autonomous driving, and industrial production. The whole recording was put in quiet indoor environment, using 3 different devices at the same time: high fidelity microphone (44.1kHz, 16-bit,); Android-system mobile phone (16kHz, 16-bit), iOS-system mobile phone (16kHz, 16-bit). Audios in high fidelity were re-sampled to 16kHz to build AISHELL- ASR0009-OS1. 400 speakers from different accent areas in China were invited to participate in the recording. The manual transcription accuracy rate is above 95%, through professional speech annotation and strict quality inspection. The corpus is divided into training, development and testing sets. ( This database is free for academic research, not in the commerce, if without permission. )
...@@ -31,7 +31,7 @@ from utils.utility import unpack ...@@ -31,7 +31,7 @@ from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/33' URL_ROOT = 'http://openslr.elda.org/resources/33'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/33' # URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
DATA_URL = URL_ROOT + '/data_aishell.tgz' DATA_URL = URL_ROOT + '/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8' MD5_DATA = '2f494334227864a8a8fec932999db9d8'
......
...@@ -31,7 +31,7 @@ import soundfile ...@@ -31,7 +31,7 @@ import soundfile
from utils.utility import download from utils.utility import download
from utils.utility import unpack from utils.utility import unpack
URL_ROOT = "http://www.openslr.org/resources/12" URL_ROOT = "http://openslr.elda.org/resources/12"
#URL_ROOT = "https://openslr.magicdatatech.com/resources/12" #URL_ROOT = "https://openslr.magicdatatech.com/resources/12"
URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz" URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz"
URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz" URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz"
......
# [MagicData](http://www.openslr.org/68/) # [MagicData](http://openslr.elda.org/68/)
MAGICDATA Mandarin Chinese Read Speech Corpus was developed by MAGIC DATA Technology Co., Ltd. and freely published for non-commercial use. MAGICDATA Mandarin Chinese Read Speech Corpus was developed by MAGIC DATA Technology Co., Ltd. and freely published for non-commercial use.
The contents and the corresponding descriptions of the corpus include: The contents and the corresponding descriptions of the corpus include:
......
...@@ -30,7 +30,7 @@ import soundfile ...@@ -30,7 +30,7 @@ import soundfile
from utils.utility import download from utils.utility import download
from utils.utility import unpack from utils.utility import unpack
URL_ROOT = "http://www.openslr.org/resources/31" URL_ROOT = "http://openslr.elda.org/resources/31"
URL_TRAIN_CLEAN = URL_ROOT + "/train-clean-5.tar.gz" URL_TRAIN_CLEAN = URL_ROOT + "/train-clean-5.tar.gz"
URL_DEV_CLEAN = URL_ROOT + "/dev-clean-2.tar.gz" URL_DEV_CLEAN = URL_ROOT + "/dev-clean-2.tar.gz"
......
...@@ -34,7 +34,7 @@ from utils.utility import unpack ...@@ -34,7 +34,7 @@ from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'https://www.openslr.org/resources/17' URL_ROOT = 'https://openslr.elda.org/resources/17'
DATA_URL = URL_ROOT + '/musan.tar.gz' DATA_URL = URL_ROOT + '/musan.tar.gz'
MD5_DATA = '0c472d4fc0c5141eca47ad1ffeb2a7df' MD5_DATA = '0c472d4fc0c5141eca47ad1ffeb2a7df'
......
# [Primewords](http://www.openslr.org/47/) # [Primewords](http://openslr.elda.org/47/)
This free Chinese Mandarin speech corpus set is released by Shanghai Primewords Information Technology Co., Ltd. This free Chinese Mandarin speech corpus set is released by Shanghai Primewords Information Technology Co., Ltd.
The corpus is recorded by smart mobile phones from 296 native Chinese speakers. The transcription accuracy is larger than 98%, at the confidence level of 95%. It is free for academic use. The corpus is recorded by smart mobile phones from 296 native Chinese speakers. The transcription accuracy is larger than 98%, at the confidence level of 95%. It is free for academic use.
......
...@@ -34,7 +34,7 @@ from utils.utility import unzip ...@@ -34,7 +34,7 @@ from utils.utility import unzip
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = '--no-check-certificate http://www.openslr.org/resources/28' URL_ROOT = '--no-check-certificate https://us.openslr.org/resources/28/rirs_noises.zip'
DATA_URL = URL_ROOT + '/rirs_noises.zip' DATA_URL = URL_ROOT + '/rirs_noises.zip'
MD5_DATA = 'e6f48e257286e05de56413b4779d8ffb' MD5_DATA = 'e6f48e257286e05de56413b4779d8ffb'
......
# [FreeST](http://www.openslr.org/38/) # [FreeST](http://openslr.elda.org/38/)
# [THCHS30](http://www.openslr.org/18/) # [THCHS30](http://openslr.elda.org/18/)
This is the *data part* of the `THCHS30 2015` acoustic data This is the *data part* of the `THCHS30 2015` acoustic data
& scripts dataset. & scripts dataset.
......
...@@ -32,7 +32,7 @@ from utils.utility import unpack ...@@ -32,7 +32,7 @@ from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/18' URL_ROOT = 'http://openslr.elda.org/resources/18'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/18' # URL_ROOT = 'https://openslr.magicdatatech.com/resources/18'
DATA_URL = URL_ROOT + '/data_thchs30.tgz' DATA_URL = URL_ROOT + '/data_thchs30.tgz'
TEST_NOISE_URL = URL_ROOT + '/test-noise.tgz' TEST_NOISE_URL = URL_ROOT + '/test-noise.tgz'
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: zhendong.peng@mobvoi.com (Zhendong Peng)
import argparse
from flask import Flask
from flask import render_template
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--port', default=19999, type=int, help='port id')
args = parser.parse_args()
app = Flask(__name__)
@app.route('/')
def index():
return render_template('index.html')
if __name__ == '__main__':
app.run(host='0.0.0.0', port=args.port, debug=True)
此差异已折叠。
# paddlespeech serving 网页Demo # paddlespeech serving 网页Demo
- 感谢[wenet](https://github.com/wenet-e2e/wenet)团队的前端demo代码. ![图片](./paddle_web_demo.png)
step1: 开启流式语音识别服务器端
## 使用方法 ```
### 1. 在本地电脑启动网页服务 # 开启流式语音识别服务
``` cd PaddleSpeech/demos/streaming_asr_server
python app.py paddlespeech_server start --config_file conf/ws_conformer_wenetspeech_application_faster.yaml
```
``` step2: 谷歌游览器打开 `web`目录下`index.html`
### 2. 本地电脑浏览器 step3: 点击`连接`,验证WebSocket是否成功连接
step4:点击开始录音(弹窗询问,允许录音)
在浏览器中输入127.0.0.1:19999 即可看到相关网页Demo。
![图片](./paddle_web_demo.png)
/*
* @Author: baipengxia
* @Date: 2021-03-12 11:44:28
* @Last Modified by: baipengxia
* @Last Modified time: 2021-03-12 15:14:24
*/
/** COMMON RESET **/
* {
-webkit-tap-highlight-color: rgba(0, 0, 0, 0);
}
body,
h1,
h2,
h3,
h4,
h5,
h6,
hr,
p,
dl,
dt,
dd,
ul,
ol,
li,
fieldset,
lengend,
button,
input,
textarea,
th,
td {
margin: 0;
padding: 0;
color: #000;
}
body {
font-size: 14px;
}
html, body {
min-width: 1200px;
}
button,
input,
select,
textarea {
font-size: 14px;
}
h1 {
font-size: 18px;
}
h2 {
font-size: 14px;
}
h3 {
font-size: 14px;
}
ul,
ol,
li {
list-style: none;
}
a {
text-decoration: none;
}
a:hover {
text-decoration: none;
}
fieldset,
img {
border: none;
}
table {
border-collapse: collapse;
border-spacing: 0;
}
i {
font-style: normal;
}
label {
position: inherit;
}
.clearfix:after {
content: ".";
display: block;
height: 0;
clear: both;
visibility: hidden;
}
.clearfix {
zoom: 1;
display: block;
}
html,
body {
font-family: Tahoma, Arial, 'microsoft yahei', 'Roboto', 'Droid Sans', 'Helvetica Neue', 'Droid Sans Fallback', 'Heiti SC', 'Hiragino Sans GB', 'Simsun', 'sans-self';
}
.audio-banner {
width: 100%;
overflow: auto;
padding: 0;
background: url('../image/voice-dictation.svg');
background-size: cover;
}
.weaper {
width: 1200px;
height: 155px;
margin: 72px auto;
}
.text-content {
width: 670px;
height: 100%;
float: left;
}
.text-content .title {
font-size: 34px;
font-family: 'PingFangSC-Medium';
font-weight: 500;
color: rgba(255, 255, 255, 1);
line-height: 48px;
}
.text-content .con {
font-size: 16px;
font-family: PingFangSC-Light;
font-weight: 300;
color: rgba(255, 255, 255, 1);
line-height: 30px;
}
.img-con {
width: 416px;
height: 100%;
float: right;
}
.img-con img {
width: 100%;
height: 100%;
}
.con-container {
margin-top: 34px;
}
.audio-advantage {
background: #f8f9fa;
}
.asr-advantage {
width: 1200px;
margin: 0 auto;
}
.asr-advantage h2 {
text-align: center;
font-size: 22px;
padding: 30px 0 0 0;
}
.asr-advantage > ul > li {
box-sizing: border-box;
padding: 0 16px;
width: 33%;
text-align: center;
margin-bottom: 35px;
}
.asr-advantage > ul > li .icons{
margin-top: 10px;
margin-bottom: 20px;
width: 42px;
height: 42px;
}
.service-item-content {
margin-top: 35px;
display: flex;
justify-content: center;
flex-wrap: wrap;
}
.service-item-content img {
width: 160px;
vertical-align: bottom;
}
.service-item-content > li {
box-sizing: border-box;
padding: 0 16px;
width: 33%;
text-align: center;
margin-bottom: 35px;
}
.service-item-content > li .service-item-content-title {
line-height: 1.5;
font-weight: 700;
margin-top: 10px;
}
.service-item-content > li .service-item-content-desc {
margin-top: 5px;
line-height: 1.8;
color: #657384;
}
.audio-scene-con {
width: 100%;
padding-bottom: 84px;
background: #fff;
}
.audio-scene {
overflow: auto;
width: 1200px;
background: #fff;
text-align: center;
padding: 0;
margin: 0 auto;
}
.audio-scene h2 {
padding: 30px 0 0 0;
font-size: 22px;
text-align: center;
}
.audio-experience {
width: 100%;
height: 538px;
background: #fff;
padding: 0;
margin: 0;
overflow: auto;
}
.asr-box {
width: 1200px;
height: 394px;
margin: 64px auto;
}
.asr-box h2 {
font-size: 22px;
text-align: center;
margin-bottom: 64px;
}
.voice-container {
position: relative;
width: 1200px;
height: 308px;
background: rgba(255, 255, 255, 1);
border-radius: 8px;
border: 1px solid rgba(225, 225, 225, 1);
}
.voice-container .voice {
height: 236px;
width: 100%;
border-radius: 8px;
}
.voice-container .voice textarea {
height: 100%;
width: 100%;
border: none;
outline: none;
border-radius: 8px;
padding: 25px;
font-size: 14px;
box-sizing: border-box;
resize: none;
}
.voice-input {
width: 100%;
height: 72px;
box-sizing: border-box;
padding-left: 35px;
background: rgba(242, 244, 245, 1);
border-radius: 8px;
line-height: 72px;
}
.voice-input .el-select {
width: 492px;
}
.start-voice {
display: inline-block;
margin-left: 10px;
}
.start-voice .time {
margin-right: 25px;
}
.asr-advantage > ul > li {
margin-bottom: 77px;
}
#msg {
width: 100%;
line-height: 40px;
font-size: 14px;
margin-left: 330px;
}
#captcha {
margin-left: 350px !important;
display: inline-block;
position: relative;
}
.black {
position: fixed;
width: 100%;
height: 100%;
z-index: 5;
background: rgba(0, 0, 0, 0.5);
top: 0;
left: 0;
}
.container {
position: fixed;
z-index: 6;
top: 25%;
left: 10%;
}
.audio-scene-con {
width: 100%;
padding-bottom: 84px;
background: #fff;
}
#sound {
color: #fff;
cursor: pointer;
background: #147ede;
padding: 10px;
margin-top: 30px;
margin-left: 135px;
width: 176px;
height: 30px !important;
text-align: center;
line-height: 30px !important;
border-radius: 10px;
}
.con-ten {
position: absolute;
width: 100%;
height: 100%;
z-index: 5;
background: #fff;
opacity: 0.5;
top: 0;
left: 0;
}
.websocket-url {
width: 320px;
height: 20px;
border: 1px solid #dcdfe6;
line-height: 20px;
padding: 10px;
border-radius: 4px;
}
.voice-btn {
color: #fff;
background-color: #409eff;
font-weight: 500;
padding: 12px 20px;
font-size: 14px;
border-radius: 4px;
border: 0;
cursor: pointer;
}
.voice-btn.end {
display: none;
}
.result-text {
background: #fff;
padding: 20px;
}
.voice-footer {
border-top: 1px solid #dddede;
background: #f7f9fa;
text-align: center;
margin-bottom: 8px;
color: #333;
font-size: 12px;
padding: 20px 0;
}
/** line animate **/
.time-box {
display: none;
margin-left: 10px;
width: 300px;
}
.total-time {
font-size: 14px;
color: #545454;
}
.voice-btn.end.show,
.time-box.show {
display: inline;
}
.start-taste-line {
margin-right: 20px;
display: inline-block;
}
.start-taste-line hr {
background-color: #187cff;
width: 3px;
height: 8px;
margin: 0 3px;
display: inline-block;
border: none;
}
.hr {
animation: note 0.2s ease-in-out;
animation-iteration-count: infinite;
animation-direction: alternate;
}
.hr-one {
animation-delay: -0.9s;
}
.hr-two {
animation-delay: -0.8s;
}
.hr-three {
animation-delay: -0.7s;
}
.hr-four {
animation-delay: -0.6s;
}
.hr-five {
animation-delay: -0.5s;
}
.hr-six {
animation-delay: -0.4s;
}
.hr-seven {
animation-delay: -0.3s;
}
.hr-eight {
animation-delay: -0.2s;
}
.hr-nine {
animation-delay: -0.1s;
}
@keyframes note {
from {
transform: scaleY(1);
}
to {
transform: scaleY(4);
}
}
\ No newline at end of file
因为 它太大了无法显示 source diff 。你可以改为 查看blob
因为 它太大了无法显示 source diff 。你可以改为 查看blob
SoundRecognizer = {
rec: null,
wave: null,
SampleRate: 16000,
testBitRate: 16,
isCloseRecorder: false,
SendInterval: 300,
realTimeSendTryType: 'pcm',
realTimeSendTryEncBusy: 0,
realTimeSendTryTime: 0,
realTimeSendTryNumber: 0,
transferUploadNumberMax: 0,
realTimeSendTryChunk: null,
soundType: "pcm",
init: function (config) {
this.soundType = config.soundType || 'pcm';
this.SampleRate = config.sampleRate || 16000;
this.recwaveElm = config.recwaveElm || '';
this.TransferUpload = config.translerCallBack || this.TransferProcess;
this.initRecorder();
},
RealTimeSendTryReset: function (type) {
this.realTimeSendTryType = type;
this.realTimeSendTryTime = 0;
},
RealTimeSendTry: function (rec, isClose) {
var that = this;
var t1 = Date.now(), endT = 0, recImpl = Recorder.prototype;
if (this.realTimeSendTryTime == 0) {
this.realTimeSendTryTime = t1;
this.realTimeSendTryEncBusy = 0;
this.realTimeSendTryNumber = 0;
this.transferUploadNumberMax = 0;
this.realTimeSendTryChunk = null;
}
if (!isClose && t1 - this.realTimeSendTryTime < this.SendInterval) {
return;//控制缓冲达到指定间隔才进行传输
}
this.realTimeSendTryTime = t1;
var number = ++this.realTimeSendTryNumber;
//借用SampleData函数进行数据的连续处理,采样率转换是顺带的
var chunk = Recorder.SampleData(rec.buffers, rec.srcSampleRate, this.SampleRate, this.realTimeSendTryChunk, { frameType: isClose ? "" : this.realTimeSendTryType });
//清理已处理完的缓冲数据,释放内存以支持长时间录音,最后完成录音时不能调用stop,因为数据已经被清掉了
for (var i = this.realTimeSendTryChunk ? this.realTimeSendTryChunk.index : 0; i < chunk.index; i++) {
rec.buffers[i] = null;
}
this.realTimeSendTryChunk = chunk;
//没有新数据,或结束时的数据量太小,不能进行mock转码
if (chunk.data.length == 0 || isClose && chunk.data.length < 2000) {
this.TransferUpload(number, null, 0, null, isClose);
return;
}
//实时编码队列阻塞处理
if (!isClose) {
if (this.realTimeSendTryEncBusy >= 2) {
console.log("编码队列阻塞,已丢弃一帧", 1);
return;
}
}
this.realTimeSendTryEncBusy++;
//通过mock方法实时转码成mp3、wav
var encStartTime = Date.now();
var recMock = Recorder({
type: this.realTimeSendTryType
, sampleRate: this.SampleRate //采样率
, bitRate: this.testBitRate //比特率
});
recMock.mock(chunk.data, chunk.sampleRate);
recMock.stop(function (blob, duration) {
that.realTimeSendTryEncBusy && (that.realTimeSendTryEncBusy--);
blob.encTime = Date.now() - encStartTime;
//转码好就推入传输
that.TransferUpload(number, blob, duration, recMock, isClose);
}, function (msg) {
that.realTimeSendTryEncBusy && (that.realTimeSendTryEncBusy--);
//转码错误?没想到什么时候会产生错误!
console.log("不应该出现的错误:" + msg, 1);
});
},
recordClose: function () {
try {
this.rec.close(function () {
this.isCloseRecorder = true;
});
this.RealTimeSendTry(this.rec, true);//最后一次发送
} catch (ex) {
// recordClose();
}
},
recordEnd: function () {
try {
this.rec.stop(function (blob, time) {
this.recordClose();
}, function (s) {
this.recordClose();
});
} catch (ex) {
}
},
initRecorder: function () {
var that = this;
var rec = Recorder({
type: that.soundType
, bitRate: that.testBitRate
, sampleRate: that.SampleRate
, onProcess: function (buffers, level, time, sampleRate) {
that.wave.input(buffers[buffers.length - 1], level, sampleRate);
that.RealTimeSendTry(rec, false);//推入实时处理,因为是unknown格式,这里简化函数调用,没有用到buffers和bufferSampleRate,因为这些数据和rec.buffers是完全相同的。
}
});
rec.open(function () {
that.wave = Recorder.FrequencyHistogramView({
elem: that.recwaveElm, lineCount: 90
, position: 0
, minHeight: 1
, stripeEnable: false
});
rec.start();
that.isCloseRecorder = false;
that.RealTimeSendTryReset(that.soundType);//重置
});
this.rec = rec;
},
TransferProcess: function (number, blobOrNull, duration, blobRec, isClose) {
}
}
\ No newline at end of file
/*
录音
https://github.com/xiangyuecn/Recorder
src: engine/pcm.js
*/
!function(){"use strict";Recorder.prototype.enc_pcm={stable:!0,testmsg:"pcm为未封装的原始音频数据,pcm数据文件无法直接播放;支持位数8位、16位(填在比特率里面),采样率取值无限制"},Recorder.prototype.pcm=function(e,t,r){var a=this.set,n=e.length,o=8==a.bitRate?8:16,c=new ArrayBuffer(n*(o/8)),s=new DataView(c),l=0;if(8==o)for(var p=0;p<n;p++,l++){var i=128+(e[p]>>8);s.setInt8(l,i,!0)}else for(p=0;p<n;p++,l+=2)s.setInt16(l,e[p],!0);t(new Blob([s.buffer],{type:"audio/pcm"}))},Recorder.pcm2wav=function(e,a,n){e.slice&&null!=e.type&&(e={blob:e});var o=e.sampleRate||16e3,c=e.bitRate||16;if(e.sampleRate&&e.bitRate||console.warn("pcm2wav必须提供sampleRate和bitRate"),Recorder.prototype.wav){var s=new FileReader;s.onloadend=function(){var e;if(8==c){var t=new Uint8Array(s.result);e=new Int16Array(t.length);for(var r=0;r<t.length;r++)e[r]=t[r]-128<<8}else e=new Int16Array(s.result);Recorder({type:"wav",sampleRate:o,bitRate:c}).mock(e,o).stop(function(e,t){a(e,t)},n)},s.readAsArrayBuffer(e.blob)}else n("pcm2wav必须先加载wav编码器wav.js")}}();
\ No newline at end of file
/*
录音
https://github.com/xiangyuecn/Recorder
src: engine/wav.js
*/
!function(){"use strict";Recorder.prototype.enc_wav={stable:!0,testmsg:"支持位数8位、16位(填在比特率里面),采样率取值无限制"},Recorder.prototype.wav=function(t,e,n){var r=this.set,a=t.length,o=r.sampleRate,f=8==r.bitRate?8:16,i=a*(f/8),s=new ArrayBuffer(44+i),c=new DataView(s),u=0,v=function(t){for(var e=0;e<t.length;e++,u++)c.setUint8(u,t.charCodeAt(e))},w=function(t){c.setUint16(u,t,!0),u+=2},l=function(t){c.setUint32(u,t,!0),u+=4};if(v("RIFF"),l(36+i),v("WAVE"),v("fmt "),l(16),w(1),w(1),l(o),l(o*(f/8)),w(f/8),w(f),v("data"),l(i),8==f)for(var p=0;p<a;p++,u++){var d=128+(t[p]>>8);c.setInt8(u,d,!0)}else for(p=0;p<a;p++,u+=2)c.setInt16(u,t[p],!0);e(new Blob([c.buffer],{type:"audio/wav"}))}}();
\ No newline at end of file
/*
录音
https://github.com/xiangyuecn/Recorder
src: extensions/frequency.histogram.view.js
*/
!function(){"use strict";var t=function(t){return new e(t)},e=function(t){var e=this,r={scale:2,fps:20,lineCount:30,widthRatio:.6,spaceWidth:0,minHeight:0,position:-1,mirrorEnable:!1,stripeEnable:!0,stripeHeight:3,stripeMargin:6,fallDuration:1e3,stripeFallDuration:3500,linear:[0,"rgba(0,187,17,1)",.5,"rgba(255,215,0,1)",1,"rgba(255,102,0,1)"],stripeLinear:null,shadowBlur:0,shadowColor:"#bbb",stripeShadowBlur:-1,stripeShadowColor:"",onDraw:function(t,e){}};for(var a in t)r[a]=t[a];e.set=t=r;var i=t.elem;i&&("string"==typeof i?i=document.querySelector(i):i.length&&(i=i[0])),i&&(t.width=i.offsetWidth,t.height=i.offsetHeight);var o=t.scale,l=t.width*o,n=t.height*o,h=e.elem=document.createElement("div"),s=["","transform-origin:0 0;","transform:scale("+1/o+");"];h.innerHTML='<div style="width:'+t.width+"px;height:"+t.height+'px;overflow:hidden"><div style="width:'+l+"px;height:"+n+"px;"+s.join("-webkit-")+s.join("-ms-")+s.join("-moz-")+s.join("")+'"><canvas/></div></div>';var f=e.canvas=h.querySelector("canvas");e.ctx=f.getContext("2d");if(f.width=l,f.height=n,i&&(i.innerHTML="",i.appendChild(h)),!Recorder.LibFFT)throw new Error("需要lib.fft.js支持");e.fft=Recorder.LibFFT(1024),e.lastH=[],e.stripesH=[]};e.prototype=t.prototype={genLinear:function(t,e,r,a){for(var i=t.createLinearGradient(0,r,0,a),o=0;o<e.length;)i.addColorStop(e[o++],e[o++]);return i},input:function(t,e,r){var a=this;a.sampleRate=r,a.pcmData=t,a.pcmPos=0,a.inputTime=Date.now(),a.schedule()},schedule:function(){var t=this,e=t.set,r=Math.floor(1e3/e.fps);t.timer||(t.timer=setInterval(function(){t.schedule()},r));var a=Date.now(),i=t.drawTime||0;if(a-t.inputTime>1.3*e.stripeFallDuration)return clearInterval(t.timer),void(t.timer=0);if(!(a-i<r)){t.drawTime=a;for(var o=t.fft.bufferSize,l=t.pcmData,n=t.pcmPos,h=new Int16Array(o),s=0;s<o&&n<l.length;s++,n++)h[s]=l[n];t.pcmPos=n;var f=t.fft.transform(h);t.draw(f,t.sampleRate)}},draw:function(t,e){var r=this,a=r.set,i=r.ctx,o=a.scale,l=a.width*o,n=a.height*o,h=a.lineCount,s=r.fft.bufferSize,f=a.position,d=Math.abs(a.position),c=1==f?0:n,p=n;d<1&&(c=p/=2,p=Math.floor(p*(1+d)),c=Math.floor(0<f?c*(1-d):c*(1+d)));for(var u=r.lastH,v=r.stripesH,w=Math.ceil(p/(a.fallDuration/(1e3/a.fps))),g=Math.ceil(p/(a.stripeFallDuration/(1e3/a.fps))),m=a.stripeMargin*o,M=1<<(Math.round(Math.log(s)/Math.log(2)+3)<<1),b=Math.log(M)/Math.log(10),L=20*Math.log(32767)/Math.log(10),y=s/2,S=Math.min(y,Math.floor(5e3*y/(e/2))),C=S==y,H=C?h:Math.round(.8*h),R=S/H,D=C?0:(y-S)/(h-H),x=0,F=0;F<h;F++){var T=Math.ceil(x);x+=F<H?R:D;for(var B=Math.min(Math.ceil(x),y),E=0,j=T;j<B;j++)E=Math.max(E,Math.abs(t[j]));var I=M<E?Math.floor(17*(Math.log(E)/Math.log(10)-b)):0,q=p*Math.min(I/L,1);u[F]=(u[F]||0)-w,q<u[F]&&(q=u[F]),q<0&&(q=0),u[F]=q;var z=v[F]||0;if(q&&z<q+m)v[F]=q+m;else{var P=z-g;P<0&&(P=0),v[F]=P}}i.clearRect(0,0,l,n);var W=r.genLinear(i,a.linear,c,c-p),k=a.stripeLinear&&r.genLinear(i,a.stripeLinear,c,c-p)||W,A=r.genLinear(i,a.linear,c,c+p),G=a.stripeLinear&&r.genLinear(i,a.stripeLinear,c,c+p)||A;i.shadowBlur=a.shadowBlur*o,i.shadowColor=a.shadowColor;var V=a.mirrorEnable,J=V?2*h-1:h,K=a.widthRatio,N=a.spaceWidth*o;0!=N&&(K=(l-N*(J+1))/l);for(var O=Math.max(1*o,Math.floor(l*K/J)),Q=(l-J*O)/(J+1),U=a.minHeight*o,X=V?l/2-(Q+O/2):0,Y=(F=0,X);F<h;F++)Y+=Q,$=Math.floor(Y),q=Math.max(u[F],U),0!=c&&(_=c-q,i.fillStyle=W,i.fillRect($,_,O,q)),c!=n&&(i.fillStyle=A,i.fillRect($,c,O,q)),Y+=O;if(a.stripeEnable){var Z=a.stripeShadowBlur;i.shadowBlur=(-1==Z?a.shadowBlur:Z)*o,i.shadowColor=a.stripeShadowColor||a.shadowColor;var $,_,tt=a.stripeHeight*o;for(F=0,Y=X;F<h;F++)Y+=Q,$=Math.floor(Y),q=v[F],0!=c&&((_=c-q-tt)<0&&(_=0),i.fillStyle=k,i.fillRect($,_,O,tt)),c!=n&&(n<(_=c+q)+tt&&(_=n-tt),i.fillStyle=G,i.fillRect($,_,O,tt)),Y+=O}if(V){var et=Math.floor(l/2);i.save(),i.scale(-1,1),i.drawImage(r.canvas,Math.ceil(l/2),0,et,n,-et,0,et,n),i.restore()}a.onDraw(t,e)}},Recorder.FrequencyHistogramView=t}();
\ No newline at end of file
/*
录音
https://github.com/xiangyuecn/Recorder
src: extensions/lib.fft.js
*/
Recorder.LibFFT=function(r){"use strict";var s,v,d,l,F,b,g,m;return function(r){var o,t,a,f;for(s=Math.round(Math.log(r)/Math.log(2)),d=((v=1<<s)<<2)*Math.sqrt(2),l=[],F=[],b=[0],g=[0],m=[],o=0;o<v;o++){for(a=o,f=t=0;t!=s;t++)f<<=1,f|=1&a,a>>>=1;m[o]=f}var n,u=2*Math.PI/v;for(o=(v>>1)-1;0<o;o--)n=o*u,g[o]=Math.cos(n),b[o]=Math.sin(n)}(r),{transform:function(r){var o,t,a,f,n,u,e,h,M=1,i=s-1;for(o=0;o!=v;o++)l[o]=r[m[o]],F[o]=0;for(o=s;0!=o;o--){for(t=0;t!=M;t++)for(n=g[t<<i],u=b[t<<i],a=t;a<v;a+=M<<1)e=n*l[f=a+M]-u*F[f],h=n*F[f]+u*l[f],l[f]=l[a]-e,F[f]=F[a]-h,l[a]+=e,F[a]+=h;M<<=1,i--}t=v>>1;var c=new Float64Array(t);for(n=-(u=d),o=t;0!=o;o--)e=l[o],h=F[o],c[o-1]=n<e&&e<u&&n<h&&h<u?0:Math.round(e*e+h*h);return c},bufferSize:v}};
\ No newline at end of file
/*
录音
https://github.com/xiangyuecn/Recorder
src: recorder-core.js
*/
!function(y){"use strict";var h=function(){},A=function(e){return new t(e)};A.IsOpen=function(){var e=A.Stream;if(e){var t=e.getTracks&&e.getTracks()||e.audioTracks||[],n=t[0];if(n){var r=n.readyState;return"live"==r||r==n.LIVE}}return!1},A.BufferSize=4096,A.Destroy=function(){for(var e in M("Recorder Destroy"),g(),n)n[e]()};var n={};A.BindDestroy=function(e,t){n[e]=t},A.Support=function(){var e=y.AudioContext;if(e||(e=y.webkitAudioContext),!e)return!1;var t=navigator.mediaDevices||{};return t.getUserMedia||(t=navigator).getUserMedia||(t.getUserMedia=t.webkitGetUserMedia||t.mozGetUserMedia||t.msGetUserMedia),!!t.getUserMedia&&(A.Scope=t,A.Ctx&&"closed"!=A.Ctx.state||(A.Ctx=new e,A.BindDestroy("Ctx",function(){var e=A.Ctx;e&&e.close&&(e.close(),A.Ctx=0)})),!0)};var k="ConnectEnableWorklet";A[k]=!1;var d=function(e){var t=(e=e||A).BufferSize||A.BufferSize,r=A.Ctx,n=e.Stream,a=n._m=r.createMediaStreamSource(n),u=n._call,o=function(e,t){if(!t||h)for(var n in u){for(var r=t||e.inputBuffer.getChannelData(0),a=r.length,o=new Int16Array(a),s=0,i=0;i<a;i++){var c=Math.max(-1,Math.min(1,r[i]));c=c<0?32768*c:32767*c,o[i]=c,s+=Math.abs(c)}for(var f in u)u[f](o,s);return}else M(l+"多余回调",3)},s="ScriptProcessor",l="audioWorklet",i="Recorder",c=i+" "+l,f="RecProc",p=r.createScriptProcessor||r.createJavaScriptNode,v="。由于"+l+"内部1秒375次回调,在移动端可能会有性能问题导致回调丢失录音变短,PC端无影响,暂不建议开启"+l+"",m=function(){h=n.isWorklet=!1,I(n),M("Connect采用老的"+s+""+(A[k]?"但已":"")+"设置"+i+"."+k+"=true尝试启用"+l+v,3);var e=n._p=p.call(r,t,1,1);a.connect(e),e.connect(r.destination),e.onaudioprocess=function(e){o(e)}},h=n.isWorklet=!p||A[k],d=y.AudioWorkletNode;if(h&&r[l]&&d){var g,S=function(){return h&&n._na},_=n._na=function(){""!==g&&(clearTimeout(g),g=setTimeout(function(){g=0,M(l+"未返回任何音频,恢复使用"+s,3),S()&&p&&m()},500))},C=function(){if(S()){var e=n._n=new d(r,f,{processorOptions:{bufferSize:t}});a.connect(e),e.connect(r.destination),e.port.onmessage=function(e){g&&(clearTimeout(g),g=""),o(0,e.data.val)},M("Connect采用"+l+"方式,设置"+i+"."+k+"=false可恢复老式"+s+v,3)}};r.resume()[u&&"finally"](function(){if(S())if(r[f])C();else{var e,t,n=(t="class "+f+" extends AudioWorkletProcessor{",t+="constructor "+(e=function(e){return e.toString().replace(/^function|DEL_/g,"").replace(/\$RA/g,c)})(function(e){DEL_super(e);var t=this,n=e.processorOptions.bufferSize;t.bufferSize=n,t.buffer=new Float32Array(2*n),t.pos=0,t.port.onmessage=function(e){e.data.kill&&(t.kill=!0,console.log("$RA kill call"))},console.log("$RA .ctor call",e)}),t+="process "+e(function(e,t,n){var r=this,a=r.bufferSize,o=r.buffer,s=r.pos;if((e=(e[0]||[])[0]||[]).length){o.set(e,s);var i=~~((s+=e.length)/a)*a;if(i){this.port.postMessage({val:o.slice(0,i)});var c=o.subarray(i,s);(o=new Float32Array(2*a)).set(c),s=c.length,r.buffer=o}r.pos=s}return!r.kill}),t+='}try{registerProcessor("'+f+'", '+f+')}catch(e){console.error("'+c+'注册失败",e)}',"data:text/javascript;base64,"+btoa(unescape(encodeURIComponent(t))));r[l].addModule(n).then(function(e){S()&&(r[f]=1,C(),g&&_())})[u&&"catch"](function(e){M(l+".addModule失败",1,e),S()&&m()})}})}else m()},I=function(e){e._na=null,e._n&&(e._n.port.postMessage({kill:!0}),e._n.disconnect(),e._n=null)},g=function(e){var t=(e=e||A)==A,n=e.Stream;if(n&&(n._m&&(n._m.disconnect(),n._m=null),n._p&&(n._p.disconnect(),n._p.onaudioprocess=n._p=null),I(n),t)){for(var r=n.getTracks&&n.getTracks()||n.audioTracks||[],a=0;a<r.length;a++){var o=r[a];o.stop&&o.stop()}n.stop&&n.stop()}e.Stream=0};A.SampleData=function(e,t,n,r,a){r||(r={});var o=r.index||0,s=r.offset||0,i=r.frameNext||[];a||(a={});var c=a.frameSize||1;a.frameType&&(c="mp3"==a.frameType?1152:1);for(var f=0,u=o;u<e.length;u++)f+=e[u].length;f=Math.max(0,f-Math.floor(s));var l=t/n;1<l?f=Math.floor(f/l):(l=1,n=t),f+=i.length;for(var p=new Int16Array(f),v=0,u=0;u<i.length;u++)p[v]=i[u],v++;for(var m=e.length;o<m;o++){for(var h=e[o],u=s,d=h.length;u<d;){var g=Math.floor(u),S=Math.ceil(u),_=u-g,C=h[g],y=S<d?h[S]:(e[o+1]||[C])[0]||0;p[v]=C+(y-C)*_,v++,u+=l}s=u-d}i=null;var k=p.length%c;if(0<k){var I=2*(p.length-k);i=new Int16Array(p.buffer.slice(I)),p=new Int16Array(p.buffer.slice(0,I))}return{index:o,offset:s,frameNext:i,sampleRate:n,data:p}},A.PowerLevel=function(e,t){var n=e/t||0;return n<1251?Math.round(n/1250*10):Math.round(Math.min(100,Math.max(0,100*(1+Math.log(n/1e4)/Math.log(10)))))};var M=function(e,t){var n=new Date,r=("0"+n.getMinutes()).substr(-2)+":"+("0"+n.getSeconds()).substr(-2)+"."+("00"+n.getMilliseconds()).substr(-3),a=this&&this.envIn&&this.envCheck&&this.id,o=["["+r+" Recorder"+(a?":"+a:"")+"]"+e],s=arguments,i=y.console||{},c=2,f=i.log;for("number"==typeof t?f=1==t?i.error:3==t?i.warn:f:c=1;c<s.length;c++)o.push(s[c]);u?f&&f("[IsLoser]"+o[0],1<o.length?o:""):f.apply(i,o)},u=!0;try{u=!console.log.apply}catch(e){}A.CLog=M;var r=0;function t(e){this.id=++r,A.Traffic&&A.Traffic();var t={type:"mp3",bitRate:16,sampleRate:16e3,onProcess:h};for(var n in e)t[n]=e[n];this.set=t,this._S=9,this.Sync={O:9,C:9}}A.Sync={O:9,C:9},A.prototype=t.prototype={CLog:M,_streamStore:function(){return this.set.sourceStream?this:A},open:function(e,n){var r=this,t=r._streamStore();e=e||h;var a=function(e,t){t=!!t,r.CLog("录音open失败:"+e+",isUserNotAllow:"+t,1),n&&n(e,t)},o=function(){r.CLog("open ok id:"+r.id),e(),r._SO=0},s=t.Sync,i=++s.O,c=s.C;r._O=r._O_=i,r._SO=r._S;var f=function(){if(c!=s.C||!r._O){var e="open被取消";return i==s.O?r.close():e="open被中断",a(e),!0}},u=r.envCheck({envName:"H5",canProcess:!0});if(u)a("不能录音:"+u);else if(r.set.sourceStream){if(!A.Support())return void a("不支持此浏览器从流中获取录音");g(t),r.Stream=r.set.sourceStream,r.Stream._call={};try{d(t)}catch(e){return void a("从流中打开录音失败:"+e.message)}o()}else{var l=function(e,t){try{y.top.a}catch(e){return void a('无权录音(跨域,请尝试给iframe添加麦克风访问策略,如allow="camera;microphone")')}/Permission|Allow/i.test(e)?a("用户拒绝了录音权限",!0):!1===y.isSecureContext?a("无权录音(需https)"):/Found/i.test(e)?a(t+",无可用麦克风"):a(t)};if(A.IsOpen())o();else if(A.Support()){var p=function(e){(A.Stream=e)._call={},f()||setTimeout(function(){f()||(A.IsOpen()?(d(),o()):a("录音功能无效:无音频流"))},100)},v=function(e){var t=e.name||e.message||e.code+":"+e;r.CLog("请求录音权限错误",1,e),l(t,"无法录音:"+t)},m=A.Scope.getUserMedia({audio:r.set.audioTrackSet||!0},p,v);m&&m.then&&m.then(p)[e&&"catch"](v)}else l("","此浏览器不支持录音")}},close:function(e){e=e||h;var t=this,n=t._streamStore();t._stop();var r=n.Sync;if(t._O=0,t._O_!=r.O)return t.CLog("close被忽略(因为同时open了多个rec,只有最后一个会真正close)",3),void e();r.C++,g(n),t.CLog("close"),e()},mock:function(e,t){var n=this;return n._stop(),n.isMock=1,n.mockEnvInfo=null,n.buffers=[e],n.recSize=e.length,n.srcSampleRate=t,n},envCheck:function(e){var t,n=this.set;return t||(this[n.type+"_envCheck"]?t=this[n.type+"_envCheck"](e,n):n.takeoffEncodeChunk&&(t=n.type+"类型不支持设置takeoffEncodeChunk")),t||""},envStart:function(e,t){var n=this,r=n.set;if(n.isMock=e?1:0,n.mockEnvInfo=e,n.buffers=[],n.recSize=0,n.envInLast=0,n.envInFirst=0,n.envInFix=0,n.envInFixTs=[],r.sampleRate=Math.min(t,r.sampleRate),n.srcSampleRate=t,n.engineCtx=0,n[r.type+"_start"]){var a=n.engineCtx=n[r.type+"_start"](r);a&&(a.pcmDatas=[],a.pcmSize=0)}},envResume:function(){this.envInFixTs=[]},envIn:function(e,t){var a=this,o=a.set,s=a.engineCtx,n=a.srcSampleRate,r=e.length,i=A.PowerLevel(t,r),c=a.buffers,f=c.length;c.push(e);var u=c,l=f,p=Date.now(),v=Math.round(r/n*1e3);a.envInLast=p,1==a.buffers.length&&(a.envInFirst=p-v);var m=a.envInFixTs;m.splice(0,0,{t:p,d:v});for(var h=p,d=0,g=0;g<m.length;g++){var S=m[g];if(3e3<p-S.t){m.length=g;break}h=S.t,d+=S.d}var _=m[1],C=p-h;if(C/3<C-d&&(_&&1e3<C||6<=m.length)){var y=p-_.t-v;if(v/5<y){var k=!o.disableEnvInFix;if(a.CLog("["+p+"]"+(k?"":"")+"补偿"+y+"ms",3),a.envInFix+=y,k){var I=new Int16Array(y*n/1e3);r+=I.length,c.push(I)}}}var M=a.recSize,x=r,b=M+x;if(a.recSize=b,s){var R=A.SampleData(c,n,o.sampleRate,s.chunkInfo);s.chunkInfo=R,b=(M=s.pcmSize)+(x=R.data.length),s.pcmSize=b,c=s.pcmDatas,f=c.length,c.push(R.data),n=R.sampleRate}var L=Math.round(b/n*1e3),w=c.length,T=u.length,z=function(){for(var e=O?0:-x,t=null==c[0],n=f;n<w;n++){var r=c[n];null==r?t=1:(e+=r.length,s&&r.length&&a[o.type+"_encode"](s,r))}if(t&&s)for(n=l,u[0]&&(n=0);n<T;n++)u[n]=null;t&&(e=O?x:0,c[0]=null),s?s.pcmSize+=e:a.recSize+=e},O=o.onProcess(c,i,L,n,f,z);if(!0===O){var D=0;for(g=f;g<w;g++)null==c[g]?D=1:c[g]=new Int16Array(0);D?a.CLog("未进入异步前不能清除buffers",3):s?s.pcmSize-=x:a.recSize-=x}else z()},start:function(){var e=this,t=A.Ctx,n=1;if(e.set.sourceStream?e.Stream||(n=0):A.IsOpen()||(n=0),n)if(e.CLog("开始录音"),e._stop(),e.state=0,e.envStart(null,t.sampleRate),e._SO&&e._SO+1!=e._S)e.CLog("start被中断",3);else{e._SO=0;var r=function(){e.state=1,e.resume()};"suspended"==t.state?(e.CLog("wait ctx resume..."),e.state=3,t.resume().then(function(){e.CLog("ctx resume"),3==e.state&&r()})):r()}else e.CLog("未open",1)},pause:function(){var e=this;e.state&&(e.state=2,e.CLog("pause"),delete e._streamStore().Stream._call[e.id])},resume:function(){var e,n=this;if(n.state){n.state=1,n.CLog("resume"),n.envResume();var t=n._streamStore();t.Stream._call[n.id]=function(e,t){1==n.state&&n.envIn(e,t)},(e=(t||A).Stream)._na&&e._na()}},_stop:function(e){var t=this,n=t.set;t.isMock||t._S++,t.state&&(t.pause(),t.state=0),!e&&t[n.type+"_stop"]&&(t[n.type+"_stop"](t.engineCtx),t.engineCtx=0)},stop:function(n,t,e){var r,a=this,o=a.set;a.CLog("stop "+(a.envInLast?a.envInLast-a.envInFirst+"ms 补"+a.envInFix+"ms":"-"));var s=function(){a._stop(),e&&a.close()},i=function(e){a.CLog("结束录音失败:"+e,1),t&&t(e),s()},c=function(e,t){if(a.CLog("结束录音 编码"+(Date.now()-r)+"ms 音频"+t+"ms/"+e.size+"b"),o.takeoffEncodeChunk)a.CLog("启用takeoffEncodeChunk后stop返回的blob长度为0不提供音频数据",3);else if(e.size<Math.max(100,t/2))return void i("生成的"+o.type+"无效");n&&n(e,t),s()};if(!a.isMock){var f=3==a.state;if(!a.state||f)return void i("未开始录音"+(f?",开始录音前无用户交互导致AudioContext未运行":""));a._stop(!0)}var u=a.recSize;if(u)if(a.buffers[0])if(a[o.type]){if(a.isMock){var l=a.envCheck(a.mockEnvInfo||{envName:"mock",canProcess:!1});if(l)return void i("录音错误:"+l)}var p=a.engineCtx;if(a[o.type+"_complete"]&&p){var v=Math.round(p.pcmSize/o.sampleRate*1e3);return r=Date.now(),void a[o.type+"_complete"](p,function(e){c(e,v)},i)}r=Date.now();var m=A.SampleData(a.buffers,a.srcSampleRate,o.sampleRate);o.sampleRate=m.sampleRate;var h=m.data;v=Math.round(h.length/o.sampleRate*1e3),a.CLog("采样"+u+"->"+h.length+" 花:"+(Date.now()-r)+"ms"),setTimeout(function(){r=Date.now(),a[o.type](h,function(e){c(e,v)},function(e){i(e)})})}else i("未加载"+o.type+"编码器");else i("音频buffers被释放");else i("未采集到录音")}},y.Recorder&&y.Recorder.Destroy(),(y.Recorder=A).LM="2022-03-05 11:53:19",A.TrafficImgUrl="//ia.51.la/go1?id=20469973&pvFlag=1",A.Traffic=function(){var e=A.TrafficImgUrl;if(e){var t=A.Traffic,n=location.href.replace(/#.*/,"");if(0==e.indexOf("//")&&(e=/^https:/i.test(n)?"https:"+e:"http:"+e),!t[n]){t[n]=1;var r=new Image;r.src=e,M("Traffic Analysis Image: Recorder.TrafficImgUrl="+A.TrafficImgUrl)}}}}(window),"function"==typeof define&&define.amd&&define(function(){return Recorder}),"object"==typeof module&&module.exports&&(module.exports=Recorder);
\ No newline at end of file
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>PaddleSpeech Serving-语音实时转写</title>
<link rel="shortcut icon" href="./static/paddle.ico">
<script src="../static/js/jquery-3.2.1.min.js"></script>
<script src="../static/js/recorder/recorder-core.js"></script>
<script src="../static/js/recorder/extensions/lib.fft.js"></script>
<script src="../static/js/recorder/extensions/frequency.histogram.view.js"></script>
<script src="../static/js/recorder/engine/pcm.js"></script>
<script src="../static/js/SoundRecognizer.js"></script>
<link rel="stylesheet" href="../static/css/style.css">
<link rel="stylesheet" href="../static/css/font-awesome.min.css">
</head>
<body>
<div class="asr-content">
<div class="audio-banner">
<div class="weaper">
<div class="text-content">
<p><span class="title">PaddleSpeech Serving简介</span></p>
<p class="con-container">
<span class="con">PaddleSpeech 是基于飞桨 PaddlePaddle 的语音方向的开源模型库,用于语音和音频中的各种关键任务的开发。PaddleSpeech Serving是基于python + fastapi 的语音算法模型的C/S类型后端服务,旨在统一paddle speech下的各语音算子来对外提供后端服务。</span>
</p>
</div>
<div class="img-con">
<img src="../static/image/PaddleSpeech_logo.png" alt="" />
</div>
</div>
</div>
<div class="audio-experience">
<div class="asr-box">
<h2>产品体验</h2>
<div id="client-word-recorder" style="position: relative;">
<div class="pd">
<div style="text-align:center;height:20px;width:100%;
border:0px solid #bcbcbc;color:#000;box-sizing: border-box;display:inline-block"
class="recwave">
</div>
</div>
</div>
<div class="voice-container">
<div class="voice-input">
<span>WebSocket URL:</span>
<input type="text" id="socketUrl" class="websocket-url" value="ws://127.0.0.1:8091/ws/asr"
placeholder="请输入服务器地址,如:ws://127.0.0.1:8091/ws/asr">
<div class="start-voice">
<button type="primary" id="beginBtn" class="voice-btn">
<span class="fa fa-microphone"> 开始识别</span>
</button>
<button type="primary" id="endBtn" class="voice-btn end">
<span class="fa fa-microphone-slash"> 结束识别</span>
</button>
<div id="timeBox" class="time-box flex-display-1">
<span class="total-time">识别中,<i id="timeCount"></i> 秒后自动停止识别</span>
</div>
</div>
</div>
<div class="voice">
<div class="result-text" id="resultPanel">此处显示识别结果</div>
</div>
</div>
</div>
</div>
</div>
<script>
var wenetWs = null
var timeLoop = null
var result = ""
$(document).ready(function () {
$('#beginBtn').on('click', startRecording)
$('#endBtn').on('click', stopRecording)
})
function openWebSocket(url) {
if ("WebSocket" in window) {
wenetWs = new WebSocket(url)
wenetWs.onopen = function () {
console.log("Websocket 连接成功,开始识别")
wenetWs.send(JSON.stringify({
"signal": "start"
}))
}
wenetWs.onmessage = function (_msg) { parseResult(_msg.data) }
wenetWs.onclose = function () {
console.log("WebSocket 连接断开")
}
wenetWs.onerror = function () { console.log("WebSocket 连接失败") }
}
}
function parseResult(data) {
var data = JSON.parse(data)
console.log('result json:', data)
var result = data.result
console.log(result)
$("#resultPanel").html(result)
}
function TransferUpload(number, blobOrNull, duration, blobRec, isClose) {
if (blobOrNull) {
var blob = blobOrNull
var encTime = blob.encTime
var reader = new FileReader()
reader.onloadend = function () { wenetWs.send(reader.result) }
reader.readAsArrayBuffer(blob)
}
}
function startRecording() {
// Check socket url
var socketUrl = $('#socketUrl').val()
if (!socketUrl.trim()) {
alert('请输入 WebSocket 服务器地址,如:ws://127.0.0.1:8091/ws/asr')
$('#socketUrl').focus()
return
}
// init recorder
SoundRecognizer.init({
soundType: 'pcm',
sampleRate: 16000,
recwaveElm: '.recwave',
translerCallBack: TransferUpload
})
openWebSocket(socketUrl)
// Change button state
$('#beginBtn').hide()
$('#endBtn, #timeBox').addClass('show')
// Start countdown
var seconds = 180
$('#timeCount').text(seconds)
timeLoop = setInterval(function () {
seconds--
$('#timeCount').text(seconds)
if (seconds === 0) {
stopRecording()
}
}, 1000)
}
function stopRecording() {
wenetWs.send(JSON.stringify({ "signal": "end" }))
SoundRecognizer.recordClose()
$('#endBtn').add($('#timeBox')).removeClass('show')
$('#beginBtn').show()
$('#timeCount').text('')
clearInterval(timeLoop)
}
</script>
</body>
</html>
...@@ -22,6 +22,7 @@ onnxruntime ...@@ -22,6 +22,7 @@ onnxruntime
pandas pandas
paddlenlp paddlenlp
paddlespeech_feat paddlespeech_feat
Pillow>=9.0.0
praatio==5.0.0 praatio==5.0.0
pypinyin pypinyin
pypinyin-dict pypinyin-dict
......
...@@ -10,7 +10,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | ...@@ -10,7 +10,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python | [Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python | [Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python | [Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0464 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python | [Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python | [Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |
[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0) | inference/python | [Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0) | inference/python |
[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0338 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1) | python | [Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0338 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1) | python |
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
## Conformer ## Conformer
paddle version: 2.2.2 paddle version: 2.2.2
paddlespeech version: 0.2.0 paddlespeech version: 1.0.1
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER | | Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0530 | | conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0495 | | conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug| test | ctc_prefix_beam_search | - | 0.0494 | | conformer | 47.07M | conf/conformer.yaml | spec_aug| test | ctc_prefix_beam_search | - | 0.0480 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0464 | | conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
## Conformer Streaming ## Conformer Streaming
......
...@@ -57,7 +57,7 @@ feat_dim: 80 ...@@ -57,7 +57,7 @@ feat_dim: 80
stride_ms: 10.0 stride_ms: 10.0
window_ms: 25.0 window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 64 batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug minibatches: 0 # for debug
...@@ -73,10 +73,10 @@ num_encs: 1 ...@@ -73,10 +73,10 @@ num_encs: 1
########################################### ###########################################
# Training # # Training #
########################################### ###########################################
n_epoch: 240 n_epoch: 150
accum_grad: 2 accum_grad: 8
global_grad_clip: 5.0 global_grad_clip: 5.0
dist_sampler: True dist_sampler: False
optim: adam optim: adam
optim_conf: optim_conf:
lr: 0.002 lr: 0.002
......
...@@ -144,3 +144,34 @@ optional arguments: ...@@ -144,3 +144,34 @@ optional arguments:
6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. 6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model ## Pretrained Model
The pretrained model can be downloaded here:
- [vits_csmsc_ckpt_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/vits/vits_csmsc_ckpt_1.1.0.zip) (add_blank=true)
VITS checkpoint contains files listed below.
```text
vits_csmsc_ckpt_1.1.0
├── default.yaml # default config used to train vitx
├── phone_id_map.txt # phone vocabulary file when training vits
└── snapshot_iter_350000.pdz # model parameters and optimizer states
```
ps: This ckpt is not good enough, a better result is training
You can use the following scripts to synthesize for `${BIN_DIR}/../sentences.txt` using pretrained VITS.
```bash
source path.sh
add_blank=true
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize_e2e.py \
--config=vits_csmsc_ckpt_1.1.0/default.yaml \
--ckpt=vits_csmsc_ckpt_1.1.0/snapshot_iter_350000.pdz \
--phones_dict=vits_csmsc_ckpt_1.1.0/phone_id_map.txt \
--output_dir=exp/default/test_e2e \
--text=${BIN_DIR}/../sentences.txt \
--add-blank=${add_blank}
```
...@@ -15,4 +15,4 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then ...@@ -15,4 +15,4 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--phones_dict=dump/phone_id_map.txt \ --phones_dict=dump/phone_id_map.txt \
--test_metadata=dump/test/norm/metadata.jsonl \ --test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test --output_dir=${train_output_path}/test
fi fi
\ No newline at end of file
...@@ -3,6 +3,11 @@ ...@@ -3,6 +3,11 @@
config_path=$1 config_path=$1
train_output_path=$2 train_output_path=$2
# install monotonic_align
cd ${MAIN_ROOT}/paddlespeech/t2s/models/vits/monotonic_align
python3 setup.py build_ext --inplace
cd -
python3 ${BIN_DIR}/train.py \ python3 ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \ --train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \ --dev-metadata=dump/dev/norm/metadata.jsonl \
......
...@@ -74,7 +74,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then ...@@ -74,7 +74,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# convert the m4a to wav # convert the m4a to wav
# and we will not delete the original m4a file # and we will not delete the original m4a file
echo "start to convert the m4a to wav" echo "start to convert the m4a to wav"
bash local/convert.sh ${TARGET_DIR}/voxceleb/vox2/test/ || exit 1; bash local/convert.sh ${TARGET_DIR}/voxceleb/vox2/ || exit 1;
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Convert voxceleb2 dataset from m4a to wav failed. Terminated." echo "Convert voxceleb2 dataset from m4a to wav failed. Terminated."
......
...@@ -14,10 +14,8 @@ ...@@ -14,10 +14,8 @@
# Modified from espnet(https://github.com/espnet/espnet) # Modified from espnet(https://github.com/espnet/espnet)
"""Spec Augment module for preprocessing i.e., data augmentation""" """Spec Augment module for preprocessing i.e., data augmentation"""
import random import random
import numpy import numpy
from PIL import Image from PIL import Image
from PIL.Image import BICUBIC
from .functional import FuncTrans from .functional import FuncTrans
...@@ -46,9 +44,10 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): ...@@ -46,9 +44,10 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
warped = random.randrange(center - window, center + warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1 window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC) left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
Image.BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC) Image.BICUBIC)
if inplace: if inplace:
x[:warped] = left x[:warped] = left
x[warped:] = right x[warped:] = right
......
...@@ -133,11 +133,11 @@ class ASRExecutor(BaseExecutor): ...@@ -133,11 +133,11 @@ class ASRExecutor(BaseExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
logger.info("start to init the model") logger.debug("start to init the model")
# default max_len: unit:second # default max_len: unit:second
self.max_len = 50 self.max_len = 50
if hasattr(self, 'model'): if hasattr(self, 'model'):
logger.info('Model had been initialized.') logger.debug('Model had been initialized.')
return return
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
...@@ -151,15 +151,15 @@ class ASRExecutor(BaseExecutor): ...@@ -151,15 +151,15 @@ class ASRExecutor(BaseExecutor):
self.ckpt_path = os.path.join( self.ckpt_path = os.path.join(
self.res_path, self.res_path,
self.task_resource.res_dict['ckpt_path'] + ".pdparams") self.task_resource.res_dict['ckpt_path'] + ".pdparams")
logger.info(self.res_path) logger.debug(self.res_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path) logger.debug(self.cfg_path)
logger.info(self.ckpt_path) logger.debug(self.ckpt_path)
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
...@@ -216,7 +216,7 @@ class ASRExecutor(BaseExecutor): ...@@ -216,7 +216,7 @@ class ASRExecutor(BaseExecutor):
max_len = self.config.encoder_conf.max_len max_len = self.config.encoder_conf.max_len
self.max_len = frame_shift_ms * max_len * subsample_rate self.max_len = frame_shift_ms * max_len * subsample_rate
logger.info( logger.debug(
f"The asr server limit max duration len: {self.max_len}") f"The asr server limit max duration len: {self.max_len}")
def preprocess(self, model_type: str, input: Union[str, os.PathLike]): def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
...@@ -227,15 +227,15 @@ class ASRExecutor(BaseExecutor): ...@@ -227,15 +227,15 @@ class ASRExecutor(BaseExecutor):
audio_file = input audio_file = input
if isinstance(audio_file, (str, os.PathLike)): if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocess audio_file:" + audio_file) logger.debug("Preprocess audio_file:" + audio_file)
# Get the object for feature extraction # Get the object for feature extraction
if "deepspeech2" in model_type or "conformer" in model_type or "transformer" in model_type: if "deepspeech2" in model_type or "conformer" in model_type or "transformer" in model_type:
logger.info("get the preprocess conf") logger.debug("get the preprocess conf")
preprocess_conf = self.config.preprocess_config preprocess_conf = self.config.preprocess_config
preprocess_args = {"train": False} preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf) preprocessing = Transformation(preprocess_conf)
logger.info("read the audio file") logger.debug("read the audio file")
audio, audio_sample_rate = soundfile.read( audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True) audio_file, dtype="int16", always_2d=True)
if self.change_format: if self.change_format:
...@@ -255,7 +255,7 @@ class ASRExecutor(BaseExecutor): ...@@ -255,7 +255,7 @@ class ASRExecutor(BaseExecutor):
else: else:
audio = audio[:, 0] audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}") logger.debug(f"audio shape: {audio.shape}")
# fbank # fbank
audio = preprocessing(audio, **preprocess_args) audio = preprocessing(audio, **preprocess_args)
...@@ -264,19 +264,19 @@ class ASRExecutor(BaseExecutor): ...@@ -264,19 +264,19 @@ class ASRExecutor(BaseExecutor):
self._inputs["audio"] = audio self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}") logger.debug(f"audio feat shape: {audio.shape}")
else: else:
raise Exception("wrong type") raise Exception("wrong type")
logger.info("audio feat process success") logger.debug("audio feat process success")
@paddle.no_grad() @paddle.no_grad()
def infer(self, model_type: str): def infer(self, model_type: str):
""" """
Model inference and result stored in self.output. Model inference and result stored in self.output.
""" """
logger.info("start to infer the model to get the output") logger.debug("start to infer the model to get the output")
cfg = self.config.decode cfg = self.config.decode
audio = self._inputs["audio"] audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"] audio_len = self._inputs["audio_len"]
...@@ -293,7 +293,7 @@ class ASRExecutor(BaseExecutor): ...@@ -293,7 +293,7 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0] self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
logger.info( logger.debug(
f"we will use the transformer like model : {model_type}") f"we will use the transformer like model : {model_type}")
try: try:
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
...@@ -352,7 +352,7 @@ class ASRExecutor(BaseExecutor): ...@@ -352,7 +352,7 @@ class ASRExecutor(BaseExecutor):
logger.error("Please input the right audio file path") logger.error("Please input the right audio file path")
return False return False
logger.info("checking the audio file format......") logger.debug("checking the audio file format......")
try: try:
audio, audio_sample_rate = soundfile.read( audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True) audio_file, dtype="int16", always_2d=True)
...@@ -374,7 +374,7 @@ class ASRExecutor(BaseExecutor): ...@@ -374,7 +374,7 @@ class ASRExecutor(BaseExecutor):
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
") ")
return False return False
logger.info("The sample rate is %d" % audio_sample_rate) logger.debug("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate: if audio_sample_rate != self.sample_rate:
logger.warning("The sample rate of the input file is not {}.\n \ logger.warning("The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \ The program will resample the wav file to {}.\n \
...@@ -383,28 +383,28 @@ class ASRExecutor(BaseExecutor): ...@@ -383,28 +383,28 @@ class ASRExecutor(BaseExecutor):
".format(self.sample_rate, self.sample_rate)) ".format(self.sample_rate, self.sample_rate))
if force_yes is False: if force_yes is False:
while (True): while (True):
logger.info( logger.debug(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
) )
content = input("Input(Y/N):") content = input("Input(Y/N):")
if content.strip() == "Y" or content.strip( if content.strip() == "Y" or content.strip(
) == "y" or content.strip() == "yes" or content.strip( ) == "y" or content.strip() == "yes" or content.strip(
) == "Yes": ) == "Yes":
logger.info( logger.debug(
"change the sampele rate, channel to 16k and 1 channel" "change the sampele rate, channel to 16k and 1 channel"
) )
break break
elif content.strip() == "N" or content.strip( elif content.strip() == "N" or content.strip(
) == "n" or content.strip() == "no" or content.strip( ) == "n" or content.strip() == "no" or content.strip(
) == "No": ) == "No":
logger.info("Exit the program") logger.debug("Exit the program")
return False return False
else: else:
logger.warning("Not regular input, please input again") logger.warning("Not regular input, please input again")
self.change_format = True self.change_format = True
else: else:
logger.info("The audio file format is right") logger.debug("The audio file format is right")
self.change_format = False self.change_format = False
return True return True
......
...@@ -92,7 +92,7 @@ class CLSExecutor(BaseExecutor): ...@@ -92,7 +92,7 @@ class CLSExecutor(BaseExecutor):
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'model'): if hasattr(self, 'model'):
logger.info('Model had been initialized.') logger.debug('Model had been initialized.')
return return
if label_file is None or ckpt_path is None: if label_file is None or ckpt_path is None:
...@@ -135,14 +135,14 @@ class CLSExecutor(BaseExecutor): ...@@ -135,14 +135,14 @@ class CLSExecutor(BaseExecutor):
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
""" """
feat_conf = self._conf['feature'] feat_conf = self._conf['feature']
logger.info(feat_conf) logger.debug(feat_conf)
waveform, _ = load( waveform, _ = load(
file=audio_file, file=audio_file,
sr=feat_conf['sample_rate'], sr=feat_conf['sample_rate'],
mono=True, mono=True,
dtype='float32') dtype='float32')
if isinstance(audio_file, (str, os.PathLike)): if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocessing audio_file:" + audio_file) logger.debug("Preprocessing audio_file:" + audio_file)
# Feature extraction # Feature extraction
feature_extractor = LogMelSpectrogram( feature_extractor = LogMelSpectrogram(
......
...@@ -61,7 +61,7 @@ def _get_unique_endpoints(trainer_endpoints): ...@@ -61,7 +61,7 @@ def _get_unique_endpoints(trainer_endpoints):
continue continue
ips.add(ip) ips.add(ip)
unique_endpoints.add(endpoint) unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints)) logger.debug("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints return unique_endpoints
...@@ -96,7 +96,7 @@ def get_path_from_url(url, ...@@ -96,7 +96,7 @@ def get_path_from_url(url,
# data, and the same ip will only download data once. # data, and the same ip will only download data once.
unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:]) unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum): if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath)) logger.debug("Found {}".format(fullpath))
else: else:
if ParallelEnv().current_endpoint in unique_endpoints: if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum, method=method) fullpath = _download(url, root_dir, md5sum, method=method)
...@@ -118,7 +118,7 @@ def _get_download(url, fullname): ...@@ -118,7 +118,7 @@ def _get_download(url, fullname):
try: try:
req = requests.get(url, stream=True) req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format( logger.debug("Downloading {} from {} failed with exception {}".format(
fname, url, str(e))) fname, url, str(e)))
return False return False
...@@ -190,7 +190,7 @@ def _download(url, path, md5sum=None, method='get'): ...@@ -190,7 +190,7 @@ def _download(url, path, md5sum=None, method='get'):
fullname = osp.join(path, fname) fullname = osp.join(path, fname)
retry_cnt = 0 retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url)) logger.debug("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)): while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT: if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1 retry_cnt += 1
...@@ -209,7 +209,7 @@ def _md5check(fullname, md5sum=None): ...@@ -209,7 +209,7 @@ def _md5check(fullname, md5sum=None):
if md5sum is None: if md5sum is None:
return True return True
logger.info("File {} md5 checking...".format(fullname)) logger.debug("File {} md5 checking...".format(fullname))
md5 = hashlib.md5() md5 = hashlib.md5()
with open(fullname, 'rb') as f: with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""): for chunk in iter(lambda: f.read(4096), b""):
...@@ -217,8 +217,8 @@ def _md5check(fullname, md5sum=None): ...@@ -217,8 +217,8 @@ def _md5check(fullname, md5sum=None):
calc_md5sum = md5.hexdigest() calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum: if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != " logger.debug("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum)) "{}(base)".format(fullname, calc_md5sum, md5sum))
return False return False
return True return True
...@@ -227,7 +227,7 @@ def _decompress(fname): ...@@ -227,7 +227,7 @@ def _decompress(fname):
""" """
Decompress for zip and tar file Decompress for zip and tar file
""" """
logger.info("Decompressing {}...".format(fname)) logger.debug("Decompressing {}...".format(fname))
# For protecting decompressing interupted, # For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress # decompress to fpath_tmp directory firstly, if decompress
......
...@@ -217,7 +217,7 @@ class BaseExecutor(ABC): ...@@ -217,7 +217,7 @@ class BaseExecutor(ABC):
logging.getLogger(name) for name in logging.root.manager.loggerDict logging.getLogger(name) for name in logging.root.manager.loggerDict
] ]
for l in loggers: for l in loggers:
l.disabled = True l.setLevel(logging.ERROR)
def show_rtf(self, info: Dict[str, List[float]]): def show_rtf(self, info: Dict[str, List[float]]):
""" """
......
...@@ -88,7 +88,7 @@ class KWSExecutor(BaseExecutor): ...@@ -88,7 +88,7 @@ class KWSExecutor(BaseExecutor):
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'model'): if hasattr(self, 'model'):
logger.info('Model had been initialized.') logger.debug('Model had been initialized.')
return return
if ckpt_path is None: if ckpt_path is None:
...@@ -141,7 +141,7 @@ class KWSExecutor(BaseExecutor): ...@@ -141,7 +141,7 @@ class KWSExecutor(BaseExecutor):
assert os.path.isfile(audio_file) assert os.path.isfile(audio_file)
waveform, _ = load(audio_file) waveform, _ = load(audio_file)
if isinstance(audio_file, (str, os.PathLike)): if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocessing audio_file:" + audio_file) logger.debug("Preprocessing audio_file:" + audio_file)
# Feature extraction # Feature extraction
waveform = paddle.to_tensor(waveform).unsqueeze(0) waveform = paddle.to_tensor(waveform).unsqueeze(0)
......
...@@ -49,7 +49,7 @@ class Logger(object): ...@@ -49,7 +49,7 @@ class Logger(object):
self.handler.setFormatter(self.format) self.handler.setFormatter(self.format)
self.logger.addHandler(self.handler) self.logger.addHandler(self.handler)
self.logger.setLevel(logging.DEBUG) self.logger.setLevel(logging.INFO)
self.logger.propagate = False self.logger.propagate = False
def __call__(self, log_level: str, msg: str): def __call__(self, log_level: str, msg: str):
......
...@@ -110,7 +110,7 @@ class STExecutor(BaseExecutor): ...@@ -110,7 +110,7 @@ class STExecutor(BaseExecutor):
""" """
decompressed_path = download_and_decompress(self.kaldi_bins, MODEL_HOME) decompressed_path = download_and_decompress(self.kaldi_bins, MODEL_HOME)
decompressed_path = os.path.abspath(decompressed_path) decompressed_path = os.path.abspath(decompressed_path)
logger.info("Kaldi_bins stored in: {}".format(decompressed_path)) logger.debug("Kaldi_bins stored in: {}".format(decompressed_path))
if "LD_LIBRARY_PATH" in os.environ: if "LD_LIBRARY_PATH" in os.environ:
os.environ["LD_LIBRARY_PATH"] += f":{decompressed_path}" os.environ["LD_LIBRARY_PATH"] += f":{decompressed_path}"
else: else:
...@@ -128,7 +128,7 @@ class STExecutor(BaseExecutor): ...@@ -128,7 +128,7 @@ class STExecutor(BaseExecutor):
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'model'): if hasattr(self, 'model'):
logger.info('Model had been initialized.') logger.debug('Model had been initialized.')
return return
if cfg_path is None or ckpt_path is None: if cfg_path is None or ckpt_path is None:
...@@ -140,8 +140,8 @@ class STExecutor(BaseExecutor): ...@@ -140,8 +140,8 @@ class STExecutor(BaseExecutor):
self.ckpt_path = os.path.join( self.ckpt_path = os.path.join(
self.task_resource.res_dir, self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path']) self.task_resource.res_dict['ckpt_path'])
logger.info(self.cfg_path) logger.debug(self.cfg_path)
logger.info(self.ckpt_path) logger.debug(self.ckpt_path)
res_path = self.task_resource.res_dir res_path = self.task_resource.res_dir
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
...@@ -192,7 +192,7 @@ class STExecutor(BaseExecutor): ...@@ -192,7 +192,7 @@ class STExecutor(BaseExecutor):
Input content can be a file(wav). Input content can be a file(wav).
""" """
audio_file = os.path.abspath(wav_file) audio_file = os.path.abspath(wav_file)
logger.info("Preprocess audio_file:" + audio_file) logger.debug("Preprocess audio_file:" + audio_file)
if "fat_st" in model_type: if "fat_st" in model_type:
cmvn = self.config.cmvn_path cmvn = self.config.cmvn_path
......
...@@ -98,7 +98,7 @@ class TextExecutor(BaseExecutor): ...@@ -98,7 +98,7 @@ class TextExecutor(BaseExecutor):
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'model'): if hasattr(self, 'model'):
logger.info('Model had been initialized.') logger.debug('Model had been initialized.')
return return
self.task = task self.task = task
......
...@@ -173,16 +173,23 @@ class TTSExecutor(BaseExecutor): ...@@ -173,16 +173,23 @@ class TTSExecutor(BaseExecutor):
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'): if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'):
logger.info('Models had been initialized.') logger.debug('Models had been initialized.')
return return
# am # am
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
use_pretrained_am = True
else:
use_pretrained_am = False
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=am_tag, model_tag=am_tag,
model_type=0, # am model_type=0, # am
skip_download=not use_pretrained_am,
version=None, # default version version=None, # default version
) )
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: if use_pretrained_am:
self.am_res_path = self.task_resource.res_dir self.am_res_path = self.task_resource.res_dir
self.am_config = os.path.join(self.am_res_path, self.am_config = os.path.join(self.am_res_path,
self.task_resource.res_dict['config']) self.task_resource.res_dict['config'])
...@@ -193,9 +200,9 @@ class TTSExecutor(BaseExecutor): ...@@ -193,9 +200,9 @@ class TTSExecutor(BaseExecutor):
# must have phones_dict in acoustic # must have phones_dict in acoustic
self.phones_dict = os.path.join( self.phones_dict = os.path.join(
self.am_res_path, self.task_resource.res_dict['phones_dict']) self.am_res_path, self.task_resource.res_dict['phones_dict'])
logger.info(self.am_res_path) logger.debug(self.am_res_path)
logger.info(self.am_config) logger.debug(self.am_config)
logger.info(self.am_ckpt) logger.debug(self.am_ckpt)
else: else:
self.am_config = os.path.abspath(am_config) self.am_config = os.path.abspath(am_config)
self.am_ckpt = os.path.abspath(am_ckpt) self.am_ckpt = os.path.abspath(am_ckpt)
...@@ -220,13 +227,19 @@ class TTSExecutor(BaseExecutor): ...@@ -220,13 +227,19 @@ class TTSExecutor(BaseExecutor):
self.speaker_dict = speaker_dict self.speaker_dict = speaker_dict
# voc # voc
if voc_ckpt is None or voc_config is None or voc_stat is None:
use_pretrained_voc = True
else:
use_pretrained_voc = False
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=voc_tag, model_tag=voc_tag,
model_type=1, # vocoder model_type=1, # vocoder
skip_download=not use_pretrained_voc,
version=None, # default version version=None, # default version
) )
if voc_ckpt is None or voc_config is None or voc_stat is None: if use_pretrained_voc:
self.voc_res_path = self.task_resource.voc_res_dir self.voc_res_path = self.task_resource.voc_res_dir
self.voc_config = os.path.join( self.voc_config = os.path.join(
self.voc_res_path, self.task_resource.voc_res_dict['config']) self.voc_res_path, self.task_resource.voc_res_dict['config'])
...@@ -235,9 +248,9 @@ class TTSExecutor(BaseExecutor): ...@@ -235,9 +248,9 @@ class TTSExecutor(BaseExecutor):
self.voc_stat = os.path.join( self.voc_stat = os.path.join(
self.voc_res_path, self.voc_res_path,
self.task_resource.voc_res_dict['speech_stats']) self.task_resource.voc_res_dict['speech_stats'])
logger.info(self.voc_res_path) logger.debug(self.voc_res_path)
logger.info(self.voc_config) logger.debug(self.voc_config)
logger.info(self.voc_ckpt) logger.debug(self.voc_ckpt)
else: else:
self.voc_config = os.path.abspath(voc_config) self.voc_config = os.path.abspath(voc_config)
self.voc_ckpt = os.path.abspath(voc_ckpt) self.voc_ckpt = os.path.abspath(voc_ckpt)
...@@ -254,21 +267,18 @@ class TTSExecutor(BaseExecutor): ...@@ -254,21 +267,18 @@ class TTSExecutor(BaseExecutor):
with open(self.phones_dict, "r") as f: with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id) vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
tone_size = None tone_size = None
if self.tones_dict: if self.tones_dict:
with open(self.tones_dict, "r") as f: with open(self.tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()] tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id) tone_size = len(tone_id)
print("tone_size:", tone_size)
spk_num = None spk_num = None
if self.speaker_dict: if self.speaker_dict:
with open(self.speaker_dict, 'rt') as f: with open(self.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()] spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id) spk_num = len(spk_id)
print("spk_num:", spk_num)
# frontend # frontend
if lang == 'zh': if lang == 'zh':
...@@ -278,7 +288,6 @@ class TTSExecutor(BaseExecutor): ...@@ -278,7 +288,6 @@ class TTSExecutor(BaseExecutor):
elif lang == 'en': elif lang == 'en':
self.frontend = English(phone_vocab_path=self.phones_dict) self.frontend = English(phone_vocab_path=self.phones_dict)
print("frontend done!")
# acoustic model # acoustic model
odim = self.am_config.n_mels odim = self.am_config.n_mels
...@@ -311,7 +320,6 @@ class TTSExecutor(BaseExecutor): ...@@ -311,7 +320,6 @@ class TTSExecutor(BaseExecutor):
am_normalizer = ZScore(am_mu, am_std) am_normalizer = ZScore(am_mu, am_std)
self.am_inference = am_inference_class(am_normalizer, am) self.am_inference = am_inference_class(am_normalizer, am)
self.am_inference.eval() self.am_inference.eval()
print("acoustic model done!")
# vocoder # vocoder
# model: {model_name}_{dataset} # model: {model_name}_{dataset}
...@@ -334,7 +342,6 @@ class TTSExecutor(BaseExecutor): ...@@ -334,7 +342,6 @@ class TTSExecutor(BaseExecutor):
voc_normalizer = ZScore(voc_mu, voc_std) voc_normalizer = ZScore(voc_mu, voc_std)
self.voc_inference = voc_inference_class(voc_normalizer, voc) self.voc_inference = voc_inference_class(voc_normalizer, voc)
self.voc_inference.eval() self.voc_inference.eval()
print("voc done!")
def preprocess(self, input: Any, *args, **kwargs): def preprocess(self, input: Any, *args, **kwargs):
""" """
...@@ -375,7 +382,7 @@ class TTSExecutor(BaseExecutor): ...@@ -375,7 +382,7 @@ class TTSExecutor(BaseExecutor):
text, merge_sentences=merge_sentences) text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
else: else:
print("lang should in {'zh', 'en'}!") logger.error("lang should in {'zh', 'en'}!")
self.frontend_time = time.time() - frontend_st self.frontend_time = time.time() - frontend_st
self.am_time = 0 self.am_time = 0
......
...@@ -117,7 +117,7 @@ class VectorExecutor(BaseExecutor): ...@@ -117,7 +117,7 @@ class VectorExecutor(BaseExecutor):
# stage 2: read the input data and store them as a list # stage 2: read the input data and store them as a list
task_source = self.get_input_source(parser_args.input) task_source = self.get_input_source(parser_args.input)
logger.info(f"task source: {task_source}") logger.debug(f"task source: {task_source}")
# stage 3: process the audio one by one # stage 3: process the audio one by one
# we do action according the task type # we do action according the task type
...@@ -127,13 +127,13 @@ class VectorExecutor(BaseExecutor): ...@@ -127,13 +127,13 @@ class VectorExecutor(BaseExecutor):
try: try:
# extract the speaker audio embedding # extract the speaker audio embedding
if parser_args.task == "spk": if parser_args.task == "spk":
logger.info("do vector spk task") logger.debug("do vector spk task")
res = self(input_, model, sample_rate, config, ckpt_path, res = self(input_, model, sample_rate, config, ckpt_path,
device) device)
task_result[id_] = res task_result[id_] = res
elif parser_args.task == "score": elif parser_args.task == "score":
logger.info("do vector score task") logger.debug("do vector score task")
logger.info(f"input content {input_}") logger.debug(f"input content {input_}")
if len(input_.split()) != 2: if len(input_.split()) != 2:
logger.error( logger.error(
f"vector score task input {input_} wav num is not two," f"vector score task input {input_} wav num is not two,"
...@@ -142,7 +142,7 @@ class VectorExecutor(BaseExecutor): ...@@ -142,7 +142,7 @@ class VectorExecutor(BaseExecutor):
# get the enroll and test embedding # get the enroll and test embedding
enroll_audio, test_audio = input_.split() enroll_audio, test_audio = input_.split()
logger.info( logger.debug(
f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}" f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}"
) )
enroll_embedding = self(enroll_audio, model, sample_rate, enroll_embedding = self(enroll_audio, model, sample_rate,
...@@ -158,8 +158,8 @@ class VectorExecutor(BaseExecutor): ...@@ -158,8 +158,8 @@ class VectorExecutor(BaseExecutor):
has_exceptions = True has_exceptions = True
task_result[id_] = f'{e.__class__.__name__}: {e}' task_result[id_] = f'{e.__class__.__name__}: {e}'
logger.info("task result as follows: ") logger.debug("task result as follows: ")
logger.info(f"{task_result}") logger.debug(f"{task_result}")
# stage 4: process the all the task results # stage 4: process the all the task results
self.process_task_results(parser_args.input, task_result, self.process_task_results(parser_args.input, task_result,
...@@ -207,7 +207,7 @@ class VectorExecutor(BaseExecutor): ...@@ -207,7 +207,7 @@ class VectorExecutor(BaseExecutor):
""" """
if not hasattr(self, "score_func"): if not hasattr(self, "score_func"):
self.score_func = paddle.nn.CosineSimilarity(axis=0) self.score_func = paddle.nn.CosineSimilarity(axis=0)
logger.info("create the cosine score function ") logger.debug("create the cosine score function ")
score = self.score_func( score = self.score_func(
paddle.to_tensor(enroll_embedding), paddle.to_tensor(enroll_embedding),
...@@ -244,7 +244,7 @@ class VectorExecutor(BaseExecutor): ...@@ -244,7 +244,7 @@ class VectorExecutor(BaseExecutor):
sys.exit(-1) sys.exit(-1)
# stage 1: set the paddle runtime host device # stage 1: set the paddle runtime host device
logger.info(f"device type: {device}") logger.debug(f"device type: {device}")
paddle.device.set_device(device) paddle.device.set_device(device)
# stage 2: read the specific pretrained model # stage 2: read the specific pretrained model
...@@ -283,7 +283,7 @@ class VectorExecutor(BaseExecutor): ...@@ -283,7 +283,7 @@ class VectorExecutor(BaseExecutor):
# stage 0: avoid to init the mode again # stage 0: avoid to init the mode again
self.task = task self.task = task
if hasattr(self, "model"): if hasattr(self, "model"):
logger.info("Model has been initialized") logger.debug("Model has been initialized")
return return
# stage 1: get the model and config path # stage 1: get the model and config path
...@@ -294,7 +294,7 @@ class VectorExecutor(BaseExecutor): ...@@ -294,7 +294,7 @@ class VectorExecutor(BaseExecutor):
sample_rate_str = "16k" if sample_rate == 16000 else "8k" sample_rate_str = "16k" if sample_rate == 16000 else "8k"
tag = model_type + "-" + sample_rate_str tag = model_type + "-" + sample_rate_str
self.task_resource.set_task_model(tag, version=None) self.task_resource.set_task_model(tag, version=None)
logger.info(f"load the pretrained model: {tag}") logger.debug(f"load the pretrained model: {tag}")
# get the model from the pretrained list # get the model from the pretrained list
# we download the pretrained model and store it in the res_path # we download the pretrained model and store it in the res_path
self.res_path = self.task_resource.res_dir self.res_path = self.task_resource.res_dir
...@@ -312,19 +312,19 @@ class VectorExecutor(BaseExecutor): ...@@ -312,19 +312,19 @@ class VectorExecutor(BaseExecutor):
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(f"start to read the ckpt from {self.ckpt_path}") logger.debug(f"start to read the ckpt from {self.ckpt_path}")
logger.info(f"read the config from {self.cfg_path}") logger.debug(f"read the config from {self.cfg_path}")
logger.info(f"get the res path {self.res_path}") logger.debug(f"get the res path {self.res_path}")
# stage 2: read and config and init the model body # stage 2: read and config and init the model body
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
# stage 3: get the model name to instance the model network with dynamic_import # stage 3: get the model name to instance the model network with dynamic_import
logger.info("start to dynamic import the model class") logger.debug("start to dynamic import the model class")
model_name = model_type[:model_type.rindex('_')] model_name = model_type[:model_type.rindex('_')]
model_class = self.task_resource.get_model_class(model_name) model_class = self.task_resource.get_model_class(model_name)
logger.info(f"model name {model_name}") logger.debug(f"model name {model_name}")
model_conf = self.config.model model_conf = self.config.model
backbone = model_class(**model_conf) backbone = model_class(**model_conf)
model = SpeakerIdetification( model = SpeakerIdetification(
...@@ -333,11 +333,11 @@ class VectorExecutor(BaseExecutor): ...@@ -333,11 +333,11 @@ class VectorExecutor(BaseExecutor):
self.model.eval() self.model.eval()
# stage 4: load the model parameters # stage 4: load the model parameters
logger.info("start to set the model parameters to model") logger.debug("start to set the model parameters to model")
model_dict = paddle.load(self.ckpt_path) model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
logger.info("create the model instance success") logger.debug("create the model instance success")
@paddle.no_grad() @paddle.no_grad()
def infer(self, model_type: str): def infer(self, model_type: str):
...@@ -349,14 +349,14 @@ class VectorExecutor(BaseExecutor): ...@@ -349,14 +349,14 @@ class VectorExecutor(BaseExecutor):
# stage 0: get the feat and length from _inputs # stage 0: get the feat and length from _inputs
feats = self._inputs["feats"] feats = self._inputs["feats"]
lengths = self._inputs["lengths"] lengths = self._inputs["lengths"]
logger.info("start to do backbone network model forward") logger.debug("start to do backbone network model forward")
logger.info( logger.debug(
f"feats shape:{feats.shape}, lengths shape: {lengths.shape}") f"feats shape:{feats.shape}, lengths shape: {lengths.shape}")
# stage 1: get the audio embedding # stage 1: get the audio embedding
# embedding from (1, emb_size, 1) -> (emb_size) # embedding from (1, emb_size, 1) -> (emb_size)
embedding = self.model.backbone(feats, lengths).squeeze().numpy() embedding = self.model.backbone(feats, lengths).squeeze().numpy()
logger.info(f"embedding size: {embedding.shape}") logger.debug(f"embedding size: {embedding.shape}")
# stage 2: put the embedding and dim info to _outputs property # stage 2: put the embedding and dim info to _outputs property
# the embedding type is numpy.array # the embedding type is numpy.array
...@@ -380,12 +380,13 @@ class VectorExecutor(BaseExecutor): ...@@ -380,12 +380,13 @@ class VectorExecutor(BaseExecutor):
""" """
audio_file = input_file audio_file = input_file
if isinstance(audio_file, (str, os.PathLike)): if isinstance(audio_file, (str, os.PathLike)):
logger.info(f"Preprocess audio file: {audio_file}") logger.debug(f"Preprocess audio file: {audio_file}")
# stage 1: load the audio sample points # stage 1: load the audio sample points
# Note: this process must match the training process # Note: this process must match the training process
waveform, sr = load_audio(audio_file) waveform, sr = load_audio(audio_file)
logger.info(f"load the audio sample points, shape is: {waveform.shape}") logger.debug(
f"load the audio sample points, shape is: {waveform.shape}")
# stage 2: get the audio feat # stage 2: get the audio feat
# Note: Now we only support fbank feature # Note: Now we only support fbank feature
...@@ -396,9 +397,9 @@ class VectorExecutor(BaseExecutor): ...@@ -396,9 +397,9 @@ class VectorExecutor(BaseExecutor):
n_mels=self.config.n_mels, n_mels=self.config.n_mels,
window_size=self.config.window_size, window_size=self.config.window_size,
hop_length=self.config.hop_size) hop_length=self.config.hop_size)
logger.info(f"extract the audio feat, shape is: {feat.shape}") logger.debug(f"extract the audio feat, shape is: {feat.shape}")
except Exception as e: except Exception as e:
logger.info(f"feat occurs exception {e}") logger.debug(f"feat occurs exception {e}")
sys.exit(-1) sys.exit(-1)
feat = paddle.to_tensor(feat).unsqueeze(0) feat = paddle.to_tensor(feat).unsqueeze(0)
...@@ -411,11 +412,11 @@ class VectorExecutor(BaseExecutor): ...@@ -411,11 +412,11 @@ class VectorExecutor(BaseExecutor):
# stage 4: store the feat and length in the _inputs, # stage 4: store the feat and length in the _inputs,
# which will be used in other function # which will be used in other function
logger.info(f"feats shape: {feat.shape}") logger.debug(f"feats shape: {feat.shape}")
self._inputs["feats"] = feat self._inputs["feats"] = feat
self._inputs["lengths"] = lengths self._inputs["lengths"] = lengths
logger.info("audio extract the feat success") logger.debug("audio extract the feat success")
def _check(self, audio_file: str, sample_rate: int): def _check(self, audio_file: str, sample_rate: int):
"""Check if the model sample match the audio sample rate """Check if the model sample match the audio sample rate
...@@ -441,7 +442,7 @@ class VectorExecutor(BaseExecutor): ...@@ -441,7 +442,7 @@ class VectorExecutor(BaseExecutor):
logger.error("Please input the right audio file path") logger.error("Please input the right audio file path")
return False return False
logger.info("checking the aduio file format......") logger.debug("checking the aduio file format......")
try: try:
audio, audio_sample_rate = soundfile.read( audio, audio_sample_rate = soundfile.read(
audio_file, dtype="float32", always_2d=True) audio_file, dtype="float32", always_2d=True)
...@@ -458,7 +459,7 @@ class VectorExecutor(BaseExecutor): ...@@ -458,7 +459,7 @@ class VectorExecutor(BaseExecutor):
") ")
return False return False
logger.info(f"The sample rate is {audio_sample_rate}") logger.debug(f"The sample rate is {audio_sample_rate}")
if audio_sample_rate != self.sample_rate: if audio_sample_rate != self.sample_rate:
logger.error("The sample rate of the input file is not {}.\n \ logger.error("The sample rate of the input file is not {}.\n \
...@@ -468,6 +469,6 @@ class VectorExecutor(BaseExecutor): ...@@ -468,6 +469,6 @@ class VectorExecutor(BaseExecutor):
".format(self.sample_rate, self.sample_rate)) ".format(self.sample_rate, self.sample_rate))
sys.exit(-1) sys.exit(-1)
else: else:
logger.info("The audio file format is right") logger.debug("The audio file format is right")
return True return True
...@@ -60,6 +60,7 @@ class CommonTaskResource: ...@@ -60,6 +60,7 @@ class CommonTaskResource:
def set_task_model(self, def set_task_model(self,
model_tag: str, model_tag: str,
model_type: int=0, model_type: int=0,
skip_download: bool=False,
version: Optional[str]=None): version: Optional[str]=None):
"""Set model tag and version of current task. """Set model tag and version of current task.
...@@ -83,16 +84,18 @@ class CommonTaskResource: ...@@ -83,16 +84,18 @@ class CommonTaskResource:
self.version = version self.version = version
self.res_dict = self.pretrained_models[model_tag][version] self.res_dict = self.pretrained_models[model_tag][version]
self._format_path(self.res_dict) self._format_path(self.res_dict)
self.res_dir = self._fetch(self.res_dict, if not skip_download:
self._get_model_dir(model_type)) self.res_dir = self._fetch(self.res_dict,
self._get_model_dir(model_type))
else: else:
assert self.task == 'tts', 'Vocoder will only be used in tts task.' assert self.task == 'tts', 'Vocoder will only be used in tts task.'
self.voc_model_tag = model_tag self.voc_model_tag = model_tag
self.voc_version = version self.voc_version = version
self.voc_res_dict = self.pretrained_models[model_tag][version] self.voc_res_dict = self.pretrained_models[model_tag][version]
self._format_path(self.voc_res_dict) self._format_path(self.voc_res_dict)
self.voc_res_dir = self._fetch(self.voc_res_dict, if not skip_download:
self._get_model_dir(model_type)) self.voc_res_dir = self._fetch(self.voc_res_dict,
self._get_model_dir(model_type))
@staticmethod @staticmethod
def get_model_class(model_name) -> List[object]: def get_model_class(model_name) -> List[object]:
......
...@@ -35,12 +35,6 @@ if __name__ == "__main__": ...@@ -35,12 +35,6 @@ if __name__ == "__main__":
# save jit model to # save jit model to
parser.add_argument( parser.add_argument(
"--export_path", type=str, help="path of the jit model to save") "--export_path", type=str, help="path of the jit model to save")
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
......
...@@ -35,12 +35,6 @@ if __name__ == "__main__": ...@@ -35,12 +35,6 @@ if __name__ == "__main__":
# save asr result to # save asr result to
parser.add_argument( parser.add_argument(
"--result_file", type=str, help="path of save the asr result") "--result_file", type=str, help="path of save the asr result")
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
args = parser.parse_args() args = parser.parse_args()
print_arguments(args, globals()) print_arguments(args, globals())
......
...@@ -38,12 +38,6 @@ if __name__ == "__main__": ...@@ -38,12 +38,6 @@ if __name__ == "__main__":
#load jit model from #load jit model from
parser.add_argument( parser.add_argument(
"--export_path", type=str, help="path of the jit model to save") "--export_path", type=str, help="path of the jit model to save")
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
parser.add_argument( parser.add_argument(
"--enable-auto-log", action="store_true", help="use auto log") "--enable-auto-log", action="store_true", help="use auto log")
args = parser.parse_args() args = parser.parse_args()
......
...@@ -31,12 +31,6 @@ def main(config, args): ...@@ -31,12 +31,6 @@ def main(config, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = default_argument_parser() parser = default_argument_parser()
parser.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
args = parser.parse_args() args = parser.parse_args()
print_arguments(args, globals()) print_arguments(args, globals())
......
...@@ -16,7 +16,6 @@ import random ...@@ -16,7 +16,6 @@ import random
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from PIL.Image import BICUBIC
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
...@@ -164,9 +163,9 @@ class SpecAugmentor(AugmentorBase): ...@@ -164,9 +163,9 @@ class SpecAugmentor(AugmentorBase):
window) + 1 # 1 ... t - 1 window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
BICUBIC) Image.BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC) Image.BICUBIC)
if self.inplace: if self.inplace:
x[:warped] = left x[:warped] = left
x[warped:] = right x[warped:] = right
......
...@@ -226,10 +226,10 @@ class TextFeaturizer(): ...@@ -226,10 +226,10 @@ class TextFeaturizer():
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1 sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1 space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
logger.info(f"BLANK id: {blank_id}") logger.debug(f"BLANK id: {blank_id}")
logger.info(f"UNK id: {unk_id}") logger.debug(f"UNK id: {unk_id}")
logger.info(f"EOS id: {eos_id}") logger.debug(f"EOS id: {eos_id}")
logger.info(f"SOS id: {sos_id}") logger.debug(f"SOS id: {sos_id}")
logger.info(f"SPACE id: {space_id}") logger.debug(f"SPACE id: {space_id}")
logger.info(f"MASKCTC id: {maskctc_id}") logger.debug(f"MASKCTC id: {maskctc_id}")
return token2id, id2token, vocab_list, unk_id, eos_id, blank_id return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
...@@ -827,7 +827,7 @@ class U2Model(U2DecodeModel): ...@@ -827,7 +827,7 @@ class U2Model(U2DecodeModel):
# encoder # encoder
encoder_type = configs.get('encoder', 'transformer') encoder_type = configs.get('encoder', 'transformer')
logger.info(f"U2 Encoder type: {encoder_type}") logger.debug(f"U2 Encoder type: {encoder_type}")
if encoder_type == 'transformer': if encoder_type == 'transformer':
encoder = TransformerEncoder( encoder = TransformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
...@@ -894,7 +894,7 @@ class U2Model(U2DecodeModel): ...@@ -894,7 +894,7 @@ class U2Model(U2DecodeModel):
if checkpoint_path: if checkpoint_path:
infos = checkpoint.Checkpoint().load_parameters( infos = checkpoint.Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path) model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}") logger.debug(f"checkpoint info: {infos}")
layer_tools.summary(model) layer_tools.summary(model)
return model return model
......
...@@ -37,9 +37,9 @@ class CTCLoss(nn.Layer): ...@@ -37,9 +37,9 @@ class CTCLoss(nn.Layer):
self.loss = nn.CTCLoss(blank=blank, reduction=reduction) self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
self.batch_average = batch_average self.batch_average = batch_average
logger.info( logger.debug(
f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}") f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}") logger.debug(f"CTCLoss Grad Norm Type: {grad_norm_type}")
assert grad_norm_type in ('instance', 'batch', 'frame', None) assert grad_norm_type in ('instance', 'batch', 'frame', None)
self.norm_by_times = False self.norm_by_times = False
...@@ -70,7 +70,8 @@ class CTCLoss(nn.Layer): ...@@ -70,7 +70,8 @@ class CTCLoss(nn.Layer):
param = {} param = {}
self._kwargs = {k: v for k, v in kwargs.items() if k in param} self._kwargs = {k: v for k, v in kwargs.items() if k in param}
_notin = {k: v for k, v in kwargs.items() if k not in param} _notin = {k: v for k, v in kwargs.items() if k not in param}
logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}") logger.debug(
f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
def forward(self, logits, ys_pad, hlens, ys_lens): def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss. """Compute CTC loss.
......
...@@ -82,6 +82,12 @@ def default_argument_parser(parser=None): ...@@ -82,6 +82,12 @@ def default_argument_parser(parser=None):
type=int, type=int,
default=1, default=1,
help="number of parallel processes. 0 for cpu.") help="number of parallel processes. 0 for cpu.")
train_group.add_argument(
'--nxpu',
type=int,
default=0,
choices=[0, 1],
help="if nxpu == 0 and ngpu == 0, use cpu.")
train_group.add_argument( train_group.add_argument(
"--config", metavar="CONFIG_FILE", help="config file.") "--config", metavar="CONFIG_FILE", help="config file.")
train_group.add_argument( train_group.add_argument(
......
...@@ -94,7 +94,7 @@ def pad_sequence(sequences: List[paddle.Tensor], ...@@ -94,7 +94,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
for i, tensor in enumerate(sequences): for i, tensor in enumerate(sequences):
length = tensor.shape[0] length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor # use index notation to prevent duplicate references to the tensor
logger.info( logger.debug(
f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}" f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}"
) )
if batch_first: if batch_first:
......
...@@ -123,7 +123,6 @@ class TTSClientExecutor(BaseExecutor): ...@@ -123,7 +123,6 @@ class TTSClientExecutor(BaseExecutor):
time_end = time.time() time_end = time.time()
time_consume = time_end - time_start time_consume = time_end - time_start
response_dict = res.json() response_dict = res.json()
logger.info(response_dict["message"])
logger.info("Save synthesized audio successfully on %s." % (output)) logger.info("Save synthesized audio successfully on %s." % (output))
logger.info("Audio duration: %f s." % logger.info("Audio duration: %f s." %
(response_dict['result']['duration'])) (response_dict['result']['duration']))
...@@ -702,7 +701,6 @@ class VectorClientExecutor(BaseExecutor): ...@@ -702,7 +701,6 @@ class VectorClientExecutor(BaseExecutor):
test_audio=args.test, test_audio=args.test,
task=task) task=task)
time_end = time.time() time_end = time.time()
logger.info(f"The vector: {res}")
logger.info("Response time %f s." % (time_end - time_start)) logger.info("Response time %f s." % (time_end - time_start))
return True return True
except Exception as e: except Exception as e:
......
...@@ -30,7 +30,7 @@ class ACSEngine(BaseEngine): ...@@ -30,7 +30,7 @@ class ACSEngine(BaseEngine):
"""The ACSEngine Engine """The ACSEngine Engine
""" """
super(ACSEngine, self).__init__() super(ACSEngine, self).__init__()
logger.info("Create the ACSEngine Instance") logger.debug("Create the ACSEngine Instance")
self.word_list = [] self.word_list = []
def init(self, config: dict): def init(self, config: dict):
...@@ -42,7 +42,7 @@ class ACSEngine(BaseEngine): ...@@ -42,7 +42,7 @@ class ACSEngine(BaseEngine):
Returns: Returns:
bool: The engine instance flag bool: The engine instance flag
""" """
logger.info("Init the acs engine") logger.debug("Init the acs engine")
try: try:
self.config = config self.config = config
self.device = self.config.get("device", paddle.get_device()) self.device = self.config.get("device", paddle.get_device())
...@@ -50,7 +50,7 @@ class ACSEngine(BaseEngine): ...@@ -50,7 +50,7 @@ class ACSEngine(BaseEngine):
# websocket default ping timeout is 20 seconds # websocket default ping timeout is 20 seconds
self.ping_timeout = self.config.get("ping_timeout", 20) self.ping_timeout = self.config.get("ping_timeout", 20)
paddle.set_device(self.device) paddle.set_device(self.device)
logger.info(f"ACS Engine set the device: {self.device}") logger.debug(f"ACS Engine set the device: {self.device}")
except BaseException as e: except BaseException as e:
logger.error( logger.error(
...@@ -66,7 +66,9 @@ class ACSEngine(BaseEngine): ...@@ -66,7 +66,9 @@ class ACSEngine(BaseEngine):
self.url = "ws://" + self.config.asr_server_ip + ":" + str( self.url = "ws://" + self.config.asr_server_ip + ":" + str(
self.config.asr_server_port) + "/paddlespeech/asr/streaming" self.config.asr_server_port) + "/paddlespeech/asr/streaming"
logger.info("Init the acs engine successfully") logger.info("Initialize acs server engine successfully on device: %s." %
(self.device))
return True return True
def read_search_words(self): def read_search_words(self):
...@@ -95,12 +97,12 @@ class ACSEngine(BaseEngine): ...@@ -95,12 +97,12 @@ class ACSEngine(BaseEngine):
Returns: Returns:
_type_: _description_ _type_: _description_
""" """
logger.info("send a message to the server") logger.debug("send a message to the server")
if self.url is None: if self.url is None:
logger.error("No asr server, please input valid ip and port") logger.error("No asr server, please input valid ip and port")
return "" return ""
ws = websocket.WebSocket() ws = websocket.WebSocket()
logger.info(f"set the ping timeout: {self.ping_timeout} seconds") logger.debug(f"set the ping timeout: {self.ping_timeout} seconds")
ws.connect(self.url, ping_timeout=self.ping_timeout) ws.connect(self.url, ping_timeout=self.ping_timeout)
audio_info = json.dumps( audio_info = json.dumps(
{ {
...@@ -123,7 +125,7 @@ class ACSEngine(BaseEngine): ...@@ -123,7 +125,7 @@ class ACSEngine(BaseEngine):
logger.info(f"audio result: {msg}") logger.info(f"audio result: {msg}")
# 3. send chunk audio data to engine # 3. send chunk audio data to engine
logger.info("send the end signal") logger.debug("send the end signal")
audio_info = json.dumps( audio_info = json.dumps(
{ {
"name": "test.wav", "name": "test.wav",
...@@ -197,7 +199,7 @@ class ACSEngine(BaseEngine): ...@@ -197,7 +199,7 @@ class ACSEngine(BaseEngine):
start = max(time_stamp[m.start(0)]['bg'] - offset, 0) start = max(time_stamp[m.start(0)]['bg'] - offset, 0)
end = min(time_stamp[m.end(0) - 1]['ed'] + offset, max_ed) end = min(time_stamp[m.end(0) - 1]['ed'] + offset, max_ed)
logger.info(f'start: {start}, end: {end}') logger.debug(f'start: {start}, end: {end}')
acs_result.append({'w': w, 'bg': start, 'ed': end}) acs_result.append({'w': w, 'bg': start, 'ed': end})
return acs_result, asr_result return acs_result, asr_result
...@@ -212,7 +214,7 @@ class ACSEngine(BaseEngine): ...@@ -212,7 +214,7 @@ class ACSEngine(BaseEngine):
Returns: Returns:
acs_result, asr_result: the acs result and the asr result acs_result, asr_result: the acs result and the asr result
""" """
logger.info("start to process the audio content search") logger.debug("start to process the audio content search")
msg = self.get_asr_content(io.BytesIO(audio_data)) msg = self.get_asr_content(io.BytesIO(audio_data))
acs_result, asr_result = self.get_macthed_word(msg) acs_result, asr_result = self.get_macthed_word(msg)
......
...@@ -44,7 +44,7 @@ class PaddleASRConnectionHanddler: ...@@ -44,7 +44,7 @@ class PaddleASRConnectionHanddler:
asr_engine (ASREngine): the global asr engine asr_engine (ASREngine): the global asr engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"create an paddle asr connection handler to process the websocket connection" "create an paddle asr connection handler to process the websocket connection"
) )
self.config = asr_engine.config # server config self.config = asr_engine.config # server config
...@@ -152,12 +152,12 @@ class PaddleASRConnectionHanddler: ...@@ -152,12 +152,12 @@ class PaddleASRConnectionHanddler:
self.output_reset() self.output_reset()
def extract_feat(self, samples: ByteString): def extract_feat(self, samples: ByteString):
logger.info("Online ASR extract the feat") logger.debug("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16) samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1 assert samples.ndim == 1
self.num_samples += samples.shape[0] self.num_samples += samples.shape[0]
logger.info( logger.debug(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}" f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
) )
...@@ -168,7 +168,7 @@ class PaddleASRConnectionHanddler: ...@@ -168,7 +168,7 @@ class PaddleASRConnectionHanddler:
else: else:
assert self.remained_wav.ndim == 1 # (T,) assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples]) self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info( logger.debug(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
) )
...@@ -202,14 +202,14 @@ class PaddleASRConnectionHanddler: ...@@ -202,14 +202,14 @@ class PaddleASRConnectionHanddler:
# update remained wav # update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:] self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info( logger.debug(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
) )
logger.info( logger.debug(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
) )
logger.info(f"global samples: {self.num_samples}") logger.debug(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}") logger.debug(f"global frames: {self.num_frames}")
def decode(self, is_finished=False): def decode(self, is_finished=False):
"""advance decoding """advance decoding
...@@ -237,7 +237,7 @@ class PaddleASRConnectionHanddler: ...@@ -237,7 +237,7 @@ class PaddleASRConnectionHanddler:
return return
num_frames = self.cached_feat.shape[1] num_frames = self.cached_feat.shape[1]
logger.info( logger.debug(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
) )
...@@ -355,7 +355,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -355,7 +355,7 @@ class ASRServerExecutor(ASRExecutor):
lm_url = self.task_resource.res_dict['lm_url'] lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5'] lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}") logger.debug(f"Start to load language model {lm_url}")
self.download_lm( self.download_lm(
lm_url, lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5) os.path.dirname(self.config.decode.lang_model_path), lm_md5)
...@@ -367,7 +367,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -367,7 +367,7 @@ class ASRServerExecutor(ASRExecutor):
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor") logger.debug("ASR engine start to init the am predictor")
self.am_predictor = onnx_infer.get_sess( self.am_predictor = onnx_infer.get_sess(
model_path=self.am_model, sess_conf=self.am_predictor_conf) model_path=self.am_model, sess_conf=self.am_predictor_conf)
else: else:
...@@ -400,7 +400,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -400,7 +400,7 @@ class ASRServerExecutor(ASRExecutor):
self.num_decoding_left_chunks = num_decoding_left_chunks self.num_decoding_left_chunks = num_decoding_left_chunks
# conf for paddleinference predictor or onnx # conf for paddleinference predictor or onnx
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}") logger.debug(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
...@@ -422,12 +422,11 @@ class ASRServerExecutor(ASRExecutor): ...@@ -422,12 +422,11 @@ class ASRServerExecutor(ASRExecutor):
# self.res_path, self.task_resource.res_dict[ # self.res_path, self.task_resource.res_dict[
# 'params']) if am_params is None else os.path.abspath(am_params) # 'params']) if am_params is None else os.path.abspath(am_params)
logger.info("Load the pretrained model:") logger.debug("Load the pretrained model:")
logger.info(f" tag = {tag}") logger.debug(f" tag = {tag}")
logger.info(f" res_path: {self.res_path}") logger.debug(f" res_path: {self.res_path}")
logger.info(f" cfg path: {self.cfg_path}") logger.debug(f" cfg path: {self.cfg_path}")
logger.info(f" am_model path: {self.am_model}") logger.debug(f" am_model path: {self.am_model}")
# logger.info(f" am_params path: {self.am_params}")
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
...@@ -436,7 +435,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -436,7 +435,7 @@ class ASRServerExecutor(ASRExecutor):
if self.config.spm_model_prefix: if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join( self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix) self.res_path, self.config.spm_model_prefix)
logger.info(f"spm model path: {self.config.spm_model_prefix}") logger.debug(f"spm model path: {self.config.spm_model_prefix}")
self.vocab = self.config.vocab_filepath self.vocab = self.config.vocab_filepath
...@@ -450,7 +449,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -450,7 +449,7 @@ class ASRServerExecutor(ASRExecutor):
# AM predictor # AM predictor
self.init_model() self.init_model()
logger.info(f"create the {model_type} model success") logger.debug(f"create the {model_type} model success")
return True return True
...@@ -501,7 +500,7 @@ class ASREngine(BaseEngine): ...@@ -501,7 +500,7 @@ class ASREngine(BaseEngine):
"If all GPU or XPU is used, you can set the server to 'cpu'") "If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1) sys.exit(-1)
logger.info(f"paddlespeech_server set the device: {self.device}") logger.debug(f"paddlespeech_server set the device: {self.device}")
if not self.init_model(): if not self.init_model():
logger.error( logger.error(
...@@ -509,7 +508,8 @@ class ASREngine(BaseEngine): ...@@ -509,7 +508,8 @@ class ASREngine(BaseEngine):
) )
return False return False
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully on device: %s." %
(self.device))
return True return True
def new_handler(self): def new_handler(self):
......
...@@ -44,7 +44,7 @@ class PaddleASRConnectionHanddler: ...@@ -44,7 +44,7 @@ class PaddleASRConnectionHanddler:
asr_engine (ASREngine): the global asr engine asr_engine (ASREngine): the global asr engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"create an paddle asr connection handler to process the websocket connection" "create an paddle asr connection handler to process the websocket connection"
) )
self.config = asr_engine.config # server config self.config = asr_engine.config # server config
...@@ -157,7 +157,7 @@ class PaddleASRConnectionHanddler: ...@@ -157,7 +157,7 @@ class PaddleASRConnectionHanddler:
assert samples.ndim == 1 assert samples.ndim == 1
self.num_samples += samples.shape[0] self.num_samples += samples.shape[0]
logger.info( logger.debug(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}" f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
) )
...@@ -168,7 +168,7 @@ class PaddleASRConnectionHanddler: ...@@ -168,7 +168,7 @@ class PaddleASRConnectionHanddler:
else: else:
assert self.remained_wav.ndim == 1 # (T,) assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples]) self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info( logger.debug(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
) )
...@@ -202,14 +202,14 @@ class PaddleASRConnectionHanddler: ...@@ -202,14 +202,14 @@ class PaddleASRConnectionHanddler:
# update remained wav # update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:] self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info( logger.debug(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
) )
logger.info( logger.debug(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
) )
logger.info(f"global samples: {self.num_samples}") logger.debug(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}") logger.debug(f"global frames: {self.num_frames}")
def decode(self, is_finished=False): def decode(self, is_finished=False):
"""advance decoding """advance decoding
...@@ -237,13 +237,13 @@ class PaddleASRConnectionHanddler: ...@@ -237,13 +237,13 @@ class PaddleASRConnectionHanddler:
return return
num_frames = self.cached_feat.shape[1] num_frames = self.cached_feat.shape[1]
logger.info( logger.debug(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
) )
# the cached feat must be larger decoding_window # the cached feat must be larger decoding_window
if num_frames < decoding_window and not is_finished: if num_frames < decoding_window and not is_finished:
logger.info( logger.debug(
f"frame feat num is less than {decoding_window}, please input more pcm data" f"frame feat num is less than {decoding_window}, please input more pcm data"
) )
return None, None return None, None
...@@ -294,7 +294,7 @@ class PaddleASRConnectionHanddler: ...@@ -294,7 +294,7 @@ class PaddleASRConnectionHanddler:
Returns: Returns:
logprob: poster probability. logprob: poster probability.
""" """
logger.info("start to decoce one chunk for deepspeech2") logger.debug("start to decoce one chunk for deepspeech2")
input_names = self.am_predictor.get_input_names() input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0]) audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
...@@ -369,7 +369,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -369,7 +369,7 @@ class ASRServerExecutor(ASRExecutor):
lm_url = self.task_resource.res_dict['lm_url'] lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5'] lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}") logger.debug(f"Start to load language model {lm_url}")
self.download_lm( self.download_lm(
lm_url, lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5) os.path.dirname(self.config.decode.lang_model_path), lm_md5)
...@@ -381,7 +381,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -381,7 +381,7 @@ class ASRServerExecutor(ASRExecutor):
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor") logger.debug("ASR engine start to init the am predictor")
self.am_predictor = init_predictor( self.am_predictor = init_predictor(
model_file=self.am_model, model_file=self.am_model,
params_file=self.am_params, params_file=self.am_params,
...@@ -415,7 +415,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -415,7 +415,7 @@ class ASRServerExecutor(ASRExecutor):
self.num_decoding_left_chunks = num_decoding_left_chunks self.num_decoding_left_chunks = num_decoding_left_chunks
# conf for paddleinference predictor or onnx # conf for paddleinference predictor or onnx
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}") logger.debug(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
...@@ -437,12 +437,12 @@ class ASRServerExecutor(ASRExecutor): ...@@ -437,12 +437,12 @@ class ASRServerExecutor(ASRExecutor):
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info("Load the pretrained model:") logger.debug("Load the pretrained model:")
logger.info(f" tag = {tag}") logger.debug(f" tag = {tag}")
logger.info(f" res_path: {self.res_path}") logger.debug(f" res_path: {self.res_path}")
logger.info(f" cfg path: {self.cfg_path}") logger.debug(f" cfg path: {self.cfg_path}")
logger.info(f" am_model path: {self.am_model}") logger.debug(f" am_model path: {self.am_model}")
logger.info(f" am_params path: {self.am_params}") logger.debug(f" am_params path: {self.am_params}")
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
...@@ -451,7 +451,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -451,7 +451,7 @@ class ASRServerExecutor(ASRExecutor):
if self.config.spm_model_prefix: if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join( self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix) self.res_path, self.config.spm_model_prefix)
logger.info(f"spm model path: {self.config.spm_model_prefix}") logger.debug(f"spm model path: {self.config.spm_model_prefix}")
self.vocab = self.config.vocab_filepath self.vocab = self.config.vocab_filepath
...@@ -465,7 +465,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -465,7 +465,7 @@ class ASRServerExecutor(ASRExecutor):
# AM predictor # AM predictor
self.init_model() self.init_model()
logger.info(f"create the {model_type} model success") logger.debug(f"create the {model_type} model success")
return True return True
...@@ -516,7 +516,7 @@ class ASREngine(BaseEngine): ...@@ -516,7 +516,7 @@ class ASREngine(BaseEngine):
"If all GPU or XPU is used, you can set the server to 'cpu'") "If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1) sys.exit(-1)
logger.info(f"paddlespeech_server set the device: {self.device}") logger.debug(f"paddlespeech_server set the device: {self.device}")
if not self.init_model(): if not self.init_model():
logger.error( logger.error(
...@@ -524,7 +524,9 @@ class ASREngine(BaseEngine): ...@@ -524,7 +524,9 @@ class ASREngine(BaseEngine):
) )
return False return False
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully on device: %s." %
(self.device))
return True return True
def new_handler(self): def new_handler(self):
......
...@@ -49,7 +49,7 @@ class PaddleASRConnectionHanddler: ...@@ -49,7 +49,7 @@ class PaddleASRConnectionHanddler:
asr_engine (ASREngine): the global asr engine asr_engine (ASREngine): the global asr engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"create an paddle asr connection handler to process the websocket connection" "create an paddle asr connection handler to process the websocket connection"
) )
self.config = asr_engine.config # server config self.config = asr_engine.config # server config
...@@ -107,7 +107,7 @@ class PaddleASRConnectionHanddler: ...@@ -107,7 +107,7 @@ class PaddleASRConnectionHanddler:
# acoustic model # acoustic model
self.model = self.asr_engine.executor.model self.model = self.asr_engine.executor.model
self.continuous_decoding = self.config.continuous_decoding self.continuous_decoding = self.config.continuous_decoding
logger.info(f"continue decoding: {self.continuous_decoding}") logger.debug(f"continue decoding: {self.continuous_decoding}")
# ctc decoding config # ctc decoding config
self.ctc_decode_config = self.asr_engine.executor.config.decode self.ctc_decode_config = self.asr_engine.executor.config.decode
...@@ -207,7 +207,7 @@ class PaddleASRConnectionHanddler: ...@@ -207,7 +207,7 @@ class PaddleASRConnectionHanddler:
assert samples.ndim == 1 assert samples.ndim == 1
self.num_samples += samples.shape[0] self.num_samples += samples.shape[0]
logger.info( logger.debug(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}" f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
) )
...@@ -218,7 +218,7 @@ class PaddleASRConnectionHanddler: ...@@ -218,7 +218,7 @@ class PaddleASRConnectionHanddler:
else: else:
assert self.remained_wav.ndim == 1 # (T,) assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples]) self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info( logger.debug(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
) )
...@@ -252,14 +252,14 @@ class PaddleASRConnectionHanddler: ...@@ -252,14 +252,14 @@ class PaddleASRConnectionHanddler:
# update remained wav # update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:] self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info( logger.debug(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
) )
logger.info( logger.debug(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
) )
logger.info(f"global samples: {self.num_samples}") logger.debug(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}") logger.debug(f"global frames: {self.num_frames}")
def decode(self, is_finished=False): def decode(self, is_finished=False):
"""advance decoding """advance decoding
...@@ -283,24 +283,24 @@ class PaddleASRConnectionHanddler: ...@@ -283,24 +283,24 @@ class PaddleASRConnectionHanddler:
stride = subsampling * decoding_chunk_size stride = subsampling * decoding_chunk_size
if self.cached_feat is None: if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data") logger.debug("no audio feat, please input more pcm data")
return return
num_frames = self.cached_feat.shape[1] num_frames = self.cached_feat.shape[1]
logger.info( logger.debug(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
) )
# the cached feat must be larger decoding_window # the cached feat must be larger decoding_window
if num_frames < decoding_window and not is_finished: if num_frames < decoding_window and not is_finished:
logger.info( logger.debug(
f"frame feat num is less than {decoding_window}, please input more pcm data" f"frame feat num is less than {decoding_window}, please input more pcm data"
) )
return None, None return None, None
# if is_finished=True, we need at least context frames # if is_finished=True, we need at least context frames
if num_frames < context: if num_frames < context:
logger.info( logger.debug(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward" "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
) )
return None, None return None, None
...@@ -354,7 +354,7 @@ class PaddleASRConnectionHanddler: ...@@ -354,7 +354,7 @@ class PaddleASRConnectionHanddler:
Returns: Returns:
logprob: poster probability. logprob: poster probability.
""" """
logger.info("start to decoce one chunk for deepspeech2") logger.debug("start to decoce one chunk for deepspeech2")
input_names = self.am_predictor.get_input_names() input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0]) audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
...@@ -391,7 +391,7 @@ class PaddleASRConnectionHanddler: ...@@ -391,7 +391,7 @@ class PaddleASRConnectionHanddler:
self.decoder.next(output_chunk_probs, output_chunk_lens) self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode() trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}") logger.debug(f"decode one best result for deepspeech2: {trans_best[0]}")
return trans_best[0] return trans_best[0]
@paddle.no_grad() @paddle.no_grad()
...@@ -402,7 +402,7 @@ class PaddleASRConnectionHanddler: ...@@ -402,7 +402,7 @@ class PaddleASRConnectionHanddler:
# reset endpiont state # reset endpiont state
self.endpoint_state = False self.endpoint_state = False
logger.info( logger.debug(
"Conformer/Transformer: start to decode with advanced_decoding method" "Conformer/Transformer: start to decode with advanced_decoding method"
) )
cfg = self.ctc_decode_config cfg = self.ctc_decode_config
...@@ -427,25 +427,25 @@ class PaddleASRConnectionHanddler: ...@@ -427,25 +427,25 @@ class PaddleASRConnectionHanddler:
stride = subsampling * decoding_chunk_size stride = subsampling * decoding_chunk_size
if self.cached_feat is None: if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data") logger.debug("no audio feat, please input more pcm data")
return return
# (B=1,T,D) # (B=1,T,D)
num_frames = self.cached_feat.shape[1] num_frames = self.cached_feat.shape[1]
logger.info( logger.debug(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
) )
# the cached feat must be larger decoding_window # the cached feat must be larger decoding_window
if num_frames < decoding_window and not is_finished: if num_frames < decoding_window and not is_finished:
logger.info( logger.debug(
f"frame feat num is less than {decoding_window}, please input more pcm data" f"frame feat num is less than {decoding_window}, please input more pcm data"
) )
return None, None return None, None
# if is_finished=True, we need at least context frames # if is_finished=True, we need at least context frames
if num_frames < context: if num_frames < context:
logger.info( logger.debug(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward" "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
) )
return None, None return None, None
...@@ -489,7 +489,7 @@ class PaddleASRConnectionHanddler: ...@@ -489,7 +489,7 @@ class PaddleASRConnectionHanddler:
self.encoder_out = ys self.encoder_out = ys
else: else:
self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
logger.info( logger.debug(
f"This connection handler encoder out shape: {self.encoder_out.shape}" f"This connection handler encoder out shape: {self.encoder_out.shape}"
) )
...@@ -513,7 +513,8 @@ class PaddleASRConnectionHanddler: ...@@ -513,7 +513,8 @@ class PaddleASRConnectionHanddler:
if self.endpointer.endpoint_detected(ctc_probs.numpy(), if self.endpointer.endpoint_detected(ctc_probs.numpy(),
decoding_something): decoding_something):
self.endpoint_state = True self.endpoint_state = True
logger.info(f"Endpoint is detected at {self.num_frames} frame.") logger.debug(
f"Endpoint is detected at {self.num_frames} frame.")
# advance cache of feat # advance cache of feat
assert self.cached_feat.shape[0] == 1 #(B=1,T,D) assert self.cached_feat.shape[0] == 1 #(B=1,T,D)
...@@ -526,7 +527,7 @@ class PaddleASRConnectionHanddler: ...@@ -526,7 +527,7 @@ class PaddleASRConnectionHanddler:
def update_result(self): def update_result(self):
"""Conformer/Transformer hyps to result. """Conformer/Transformer hyps to result.
""" """
logger.info("update the final result") logger.debug("update the final result")
hyps = self.hyps hyps = self.hyps
# output results and tokenids # output results and tokenids
...@@ -560,16 +561,16 @@ class PaddleASRConnectionHanddler: ...@@ -560,16 +561,16 @@ class PaddleASRConnectionHanddler:
only for conformer and transformer model. only for conformer and transformer model.
""" """
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
logger.info("deepspeech2 not support rescoring decoding.") logger.debug("deepspeech2 not support rescoring decoding.")
return return
if "attention_rescoring" != self.ctc_decode_config.decoding_method: if "attention_rescoring" != self.ctc_decode_config.decoding_method:
logger.info( logger.debug(
f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring" f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
) )
return return
logger.info("rescoring the final result") logger.debug("rescoring the final result")
# last decoding for last audio # last decoding for last audio
self.searcher.finalize_search() self.searcher.finalize_search()
...@@ -685,7 +686,6 @@ class PaddleASRConnectionHanddler: ...@@ -685,7 +686,6 @@ class PaddleASRConnectionHanddler:
"bg": global_offset_in_sec + start, "bg": global_offset_in_sec + start,
"ed": global_offset_in_sec + end "ed": global_offset_in_sec + end
}) })
# logger.info(f"{word_time_stamp[-1]}")
self.word_time_stamp = word_time_stamp self.word_time_stamp = word_time_stamp
logger.info(f"word time stamp: {self.word_time_stamp}") logger.info(f"word time stamp: {self.word_time_stamp}")
...@@ -707,13 +707,13 @@ class ASRServerExecutor(ASRExecutor): ...@@ -707,13 +707,13 @@ class ASRServerExecutor(ASRExecutor):
lm_url = self.task_resource.res_dict['lm_url'] lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5'] lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}") logger.debug(f"Start to load language model {lm_url}")
self.download_lm( self.download_lm(
lm_url, lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5) os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in self.model_type or "transformer" in self.model_type: elif "conformer" in self.model_type or "transformer" in self.model_type:
with UpdateConfig(self.config): with UpdateConfig(self.config):
logger.info("start to create the stream conformer asr engine") logger.debug("start to create the stream conformer asr engine")
# update the decoding method # update the decoding method
if self.decode_method: if self.decode_method:
self.config.decode.decoding_method = self.decode_method self.config.decode.decoding_method = self.decode_method
...@@ -726,7 +726,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -726,7 +726,7 @@ class ASRServerExecutor(ASRExecutor):
if self.config.decode.decoding_method not in [ if self.config.decode.decoding_method not in [
"ctc_prefix_beam_search", "attention_rescoring" "ctc_prefix_beam_search", "attention_rescoring"
]: ]:
logger.info( logger.debug(
"we set the decoding_method to attention_rescoring") "we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring" self.config.decode.decoding_method = "attention_rescoring"
...@@ -739,7 +739,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -739,7 +739,7 @@ class ASRServerExecutor(ASRExecutor):
def init_model(self) -> None: def init_model(self) -> None:
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor") logger.debug("ASR engine start to init the am predictor")
self.am_predictor = init_predictor( self.am_predictor = init_predictor(
model_file=self.am_model, model_file=self.am_model,
params_file=self.am_params, params_file=self.am_params,
...@@ -748,7 +748,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -748,7 +748,7 @@ class ASRServerExecutor(ASRExecutor):
# load model # load model
# model_type: {model_name}_{dataset} # model_type: {model_name}_{dataset}
model_name = self.model_type[:self.model_type.rindex('_')] model_name = self.model_type[:self.model_type.rindex('_')]
logger.info(f"model name: {model_name}") logger.debug(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name) model_class = self.task_resource.get_model_class(model_name)
model = model_class.from_config(self.config) model = model_class.from_config(self.config)
self.model = model self.model = model
...@@ -782,7 +782,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -782,7 +782,7 @@ class ASRServerExecutor(ASRExecutor):
self.num_decoding_left_chunks = num_decoding_left_chunks self.num_decoding_left_chunks = num_decoding_left_chunks
# conf for paddleinference predictor or onnx # conf for paddleinference predictor or onnx
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}") logger.debug(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
...@@ -804,12 +804,12 @@ class ASRServerExecutor(ASRExecutor): ...@@ -804,12 +804,12 @@ class ASRServerExecutor(ASRExecutor):
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info("Load the pretrained model:") logger.debug("Load the pretrained model:")
logger.info(f" tag = {tag}") logger.debug(f" tag = {tag}")
logger.info(f" res_path: {self.res_path}") logger.debug(f" res_path: {self.res_path}")
logger.info(f" cfg path: {self.cfg_path}") logger.debug(f" cfg path: {self.cfg_path}")
logger.info(f" am_model path: {self.am_model}") logger.debug(f" am_model path: {self.am_model}")
logger.info(f" am_params path: {self.am_params}") logger.debug(f" am_params path: {self.am_params}")
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
...@@ -818,7 +818,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -818,7 +818,7 @@ class ASRServerExecutor(ASRExecutor):
if self.config.spm_model_prefix: if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join( self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix) self.res_path, self.config.spm_model_prefix)
logger.info(f"spm model path: {self.config.spm_model_prefix}") logger.debug(f"spm model path: {self.config.spm_model_prefix}")
self.vocab = self.config.vocab_filepath self.vocab = self.config.vocab_filepath
...@@ -832,7 +832,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -832,7 +832,7 @@ class ASRServerExecutor(ASRExecutor):
# AM predictor # AM predictor
self.init_model() self.init_model()
logger.info(f"create the {model_type} model success") logger.debug(f"create the {model_type} model success")
return True return True
...@@ -883,7 +883,7 @@ class ASREngine(BaseEngine): ...@@ -883,7 +883,7 @@ class ASREngine(BaseEngine):
"If all GPU or XPU is used, you can set the server to 'cpu'") "If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1) sys.exit(-1)
logger.info(f"paddlespeech_server set the device: {self.device}") logger.debug(f"paddlespeech_server set the device: {self.device}")
if not self.init_model(): if not self.init_model():
logger.error( logger.error(
...@@ -891,7 +891,9 @@ class ASREngine(BaseEngine): ...@@ -891,7 +891,9 @@ class ASREngine(BaseEngine):
) )
return False return False
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully on device: %s." %
(self.device))
return True return True
def new_handler(self): def new_handler(self):
......
...@@ -65,10 +65,10 @@ class ASRServerExecutor(ASRExecutor): ...@@ -65,10 +65,10 @@ class ASRServerExecutor(ASRExecutor):
self.task_resource.res_dict['model']) self.task_resource.res_dict['model'])
self.am_params = os.path.join(self.res_path, self.am_params = os.path.join(self.res_path,
self.task_resource.res_dict['params']) self.task_resource.res_dict['params'])
logger.info(self.res_path) logger.debug(self.res_path)
logger.info(self.cfg_path) logger.debug(self.cfg_path)
logger.info(self.am_model) logger.debug(self.am_model)
logger.info(self.am_params) logger.debug(self.am_params)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model) self.am_model = os.path.abspath(am_model)
...@@ -236,16 +236,16 @@ class PaddleASRConnectionHandler(ASRServerExecutor): ...@@ -236,16 +236,16 @@ class PaddleASRConnectionHandler(ASRServerExecutor):
if self._check( if self._check(
io.BytesIO(audio_data), self.asr_engine.config.sample_rate, io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
self.asr_engine.config.force_yes): self.asr_engine.config.force_yes):
logger.info("start running asr engine") logger.debug("start running asr engine")
self.preprocess(self.asr_engine.config.model_type, self.preprocess(self.asr_engine.config.model_type,
io.BytesIO(audio_data)) io.BytesIO(audio_data))
st = time.time() st = time.time()
self.infer(self.asr_engine.config.model_type) self.infer(self.asr_engine.config.model_type)
infer_time = time.time() - st infer_time = time.time() - st
self.output = self.postprocess() # Retrieve result of asr. self.output = self.postprocess() # Retrieve result of asr.
logger.info("end inferring asr engine") logger.debug("end inferring asr engine")
else: else:
logger.info("file check failed!") logger.error("file check failed!")
self.output = None self.output = None
logger.info("inference time: {}".format(infer_time)) logger.info("inference time: {}".format(infer_time))
......
...@@ -104,7 +104,7 @@ class PaddleASRConnectionHandler(ASRServerExecutor): ...@@ -104,7 +104,7 @@ class PaddleASRConnectionHandler(ASRServerExecutor):
if self._check( if self._check(
io.BytesIO(audio_data), self.asr_engine.config.sample_rate, io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
self.asr_engine.config.force_yes): self.asr_engine.config.force_yes):
logger.info("start run asr engine") logger.debug("start run asr engine")
self.preprocess(self.asr_engine.config.model, self.preprocess(self.asr_engine.config.model,
io.BytesIO(audio_data)) io.BytesIO(audio_data))
st = time.time() st = time.time()
...@@ -112,7 +112,7 @@ class PaddleASRConnectionHandler(ASRServerExecutor): ...@@ -112,7 +112,7 @@ class PaddleASRConnectionHandler(ASRServerExecutor):
infer_time = time.time() - st infer_time = time.time() - st
self.output = self.postprocess() # Retrieve result of asr. self.output = self.postprocess() # Retrieve result of asr.
else: else:
logger.info("file check failed!") logger.error("file check failed!")
self.output = None self.output = None
logger.info("inference time: {}".format(infer_time)) logger.info("inference time: {}".format(infer_time))
......
...@@ -67,22 +67,22 @@ class CLSServerExecutor(CLSExecutor): ...@@ -67,22 +67,22 @@ class CLSServerExecutor(CLSExecutor):
self.params_path = os.path.abspath(params_path) self.params_path = os.path.abspath(params_path)
self.label_file = os.path.abspath(label_file) self.label_file = os.path.abspath(label_file)
logger.info(self.cfg_path) logger.debug(self.cfg_path)
logger.info(self.model_path) logger.debug(self.model_path)
logger.info(self.params_path) logger.debug(self.params_path)
logger.info(self.label_file) logger.debug(self.label_file)
# config # config
with open(self.cfg_path, 'r') as f: with open(self.cfg_path, 'r') as f:
self._conf = yaml.safe_load(f) self._conf = yaml.safe_load(f)
logger.info("Read cfg file successfully.") logger.debug("Read cfg file successfully.")
# labels # labels
self._label_list = [] self._label_list = []
with open(self.label_file, 'r') as f: with open(self.label_file, 'r') as f:
for line in f: for line in f:
self._label_list.append(line.strip()) self._label_list.append(line.strip())
logger.info("Read label file successfully.") logger.debug("Read label file successfully.")
# Create predictor # Create predictor
self.predictor_conf = predictor_conf self.predictor_conf = predictor_conf
...@@ -90,7 +90,7 @@ class CLSServerExecutor(CLSExecutor): ...@@ -90,7 +90,7 @@ class CLSServerExecutor(CLSExecutor):
model_file=self.model_path, model_file=self.model_path,
params_file=self.params_path, params_file=self.params_path,
predictor_conf=self.predictor_conf) predictor_conf=self.predictor_conf)
logger.info("Create predictor successfully.") logger.debug("Create predictor successfully.")
@paddle.no_grad() @paddle.no_grad()
def infer(self): def infer(self):
...@@ -148,7 +148,8 @@ class CLSEngine(BaseEngine): ...@@ -148,7 +148,8 @@ class CLSEngine(BaseEngine):
logger.error(e) logger.error(e)
return False return False
logger.info("Initialize CLS server engine successfully.") logger.info("Initialize CLS server engine successfully on device: %s." %
(self.device))
return True return True
...@@ -160,7 +161,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor): ...@@ -160,7 +161,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
cls_engine (CLSEngine): The CLS engine cls_engine (CLSEngine): The CLS engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"Create PaddleCLSConnectionHandler to process the cls request") "Create PaddleCLSConnectionHandler to process the cls request")
self._inputs = OrderedDict() self._inputs = OrderedDict()
...@@ -183,7 +184,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor): ...@@ -183,7 +184,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
self.infer() self.infer()
infer_time = time.time() - st infer_time = time.time() - st
logger.info("inference time: {}".format(infer_time)) logger.debug("inference time: {}".format(infer_time))
logger.info("cls engine type: inference") logger.info("cls engine type: inference")
def postprocess(self, topk: int): def postprocess(self, topk: int):
......
...@@ -88,7 +88,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor): ...@@ -88,7 +88,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
cls_engine (CLSEngine): The CLS engine cls_engine (CLSEngine): The CLS engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"Create PaddleCLSConnectionHandler to process the cls request") "Create PaddleCLSConnectionHandler to process the cls request")
self._inputs = OrderedDict() self._inputs = OrderedDict()
...@@ -110,7 +110,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor): ...@@ -110,7 +110,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
self.infer() self.infer()
infer_time = time.time() - st infer_time = time.time() - st
logger.info("inference time: {}".format(infer_time)) logger.debug("inference time: {}".format(infer_time))
logger.info("cls engine type: python") logger.info("cls engine type: python")
def postprocess(self, topk: int): def postprocess(self, topk: int):
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import Text from typing import Text
from ..utils.log import logger from paddlespeech.cli.log import logger
__all__ = ['EngineFactory'] __all__ = ['EngineFactory']
......
...@@ -45,7 +45,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool: ...@@ -45,7 +45,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
logger.error("Please check tte engine type.") logger.error("Please check tte engine type.")
try: try:
logger.info("Start to warm up tts engine.") logger.debug("Start to warm up tts engine.")
for i in range(warm_up_time): for i in range(warm_up_time):
connection_handler = PaddleTTSConnectionHandler(tts_engine) connection_handler = PaddleTTSConnectionHandler(tts_engine)
if flag_online: if flag_online:
...@@ -53,7 +53,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool: ...@@ -53,7 +53,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
text=sentence, text=sentence,
lang=tts_engine.lang, lang=tts_engine.lang,
am=tts_engine.config.am): am=tts_engine.config.am):
logger.info( logger.debug(
f"The first response time of the {i} warm up: {connection_handler.first_response_time} s" f"The first response time of the {i} warm up: {connection_handler.first_response_time} s"
) )
break break
...@@ -62,7 +62,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool: ...@@ -62,7 +62,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
st = time.time() st = time.time()
connection_handler.infer(text=sentence) connection_handler.infer(text=sentence)
et = time.time() et = time.time()
logger.info( logger.debug(
f"The response time of the {i} warm up: {et - st} s") f"The response time of the {i} warm up: {et - st} s")
except Exception as e: except Exception as e:
logger.error("Failed to warm up on tts engine.") logger.error("Failed to warm up on tts engine.")
......
...@@ -28,7 +28,7 @@ class PaddleTextConnectionHandler: ...@@ -28,7 +28,7 @@ class PaddleTextConnectionHandler:
text_engine (TextEngine): The Text engine text_engine (TextEngine): The Text engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"Create PaddleTextConnectionHandler to process the text request") "Create PaddleTextConnectionHandler to process the text request")
self.text_engine = text_engine self.text_engine = text_engine
self.task = self.text_engine.executor.task self.task = self.text_engine.executor.task
...@@ -130,7 +130,7 @@ class TextEngine(BaseEngine): ...@@ -130,7 +130,7 @@ class TextEngine(BaseEngine):
"""The Text Engine """The Text Engine
""" """
super(TextEngine, self).__init__() super(TextEngine, self).__init__()
logger.info("Create the TextEngine Instance") logger.debug("Create the TextEngine Instance")
def init(self, config: dict): def init(self, config: dict):
"""Init the Text Engine """Init the Text Engine
...@@ -141,7 +141,7 @@ class TextEngine(BaseEngine): ...@@ -141,7 +141,7 @@ class TextEngine(BaseEngine):
Returns: Returns:
bool: The engine instance flag bool: The engine instance flag
""" """
logger.info("Init the text engine") logger.debug("Init the text engine")
try: try:
self.config = config self.config = config
if self.config.device: if self.config.device:
...@@ -150,7 +150,7 @@ class TextEngine(BaseEngine): ...@@ -150,7 +150,7 @@ class TextEngine(BaseEngine):
self.device = paddle.get_device() self.device = paddle.get_device()
paddle.set_device(self.device) paddle.set_device(self.device)
logger.info(f"Text Engine set the device: {self.device}") logger.debug(f"Text Engine set the device: {self.device}")
except BaseException as e: except BaseException as e:
logger.error( logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file" "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
...@@ -168,5 +168,6 @@ class TextEngine(BaseEngine): ...@@ -168,5 +168,6 @@ class TextEngine(BaseEngine):
ckpt_path=config.ckpt_path, ckpt_path=config.ckpt_path,
vocab_file=config.vocab_file) vocab_file=config.vocab_file)
logger.info("Init the text engine successfully") logger.info("Initialize Text server engine successfully on device: %s."
% (self.device))
return True return True
...@@ -62,7 +62,7 @@ class TTSServerExecutor(TTSExecutor): ...@@ -62,7 +62,7 @@ class TTSServerExecutor(TTSExecutor):
(hasattr(self, 'am_encoder_infer_sess') and (hasattr(self, 'am_encoder_infer_sess') and
hasattr(self, 'am_decoder_sess') and hasattr( hasattr(self, 'am_decoder_sess') and hasattr(
self, 'am_postnet_sess'))) and hasattr(self, 'voc_inference'): self, 'am_postnet_sess'))) and hasattr(self, 'voc_inference'):
logger.info('Models had been initialized.') logger.debug('Models had been initialized.')
return return
# am # am
am_tag = am + '-' + lang am_tag = am + '-' + lang
...@@ -85,8 +85,7 @@ class TTSServerExecutor(TTSExecutor): ...@@ -85,8 +85,7 @@ class TTSServerExecutor(TTSExecutor):
else: else:
self.am_ckpt = os.path.abspath(am_ckpt[0]) self.am_ckpt = os.path.abspath(am_ckpt[0])
self.phones_dict = os.path.abspath(phones_dict) self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname( self.am_res_path = os.path.dirname(os.path.abspath(am_ckpt))
os.path.abspath(am_ckpt))
# create am sess # create am sess
self.am_sess = get_sess(self.am_ckpt, am_sess_conf) self.am_sess = get_sess(self.am_ckpt, am_sess_conf)
...@@ -119,8 +118,7 @@ class TTSServerExecutor(TTSExecutor): ...@@ -119,8 +118,7 @@ class TTSServerExecutor(TTSExecutor):
self.am_postnet = os.path.abspath(am_ckpt[2]) self.am_postnet = os.path.abspath(am_ckpt[2])
self.phones_dict = os.path.abspath(phones_dict) self.phones_dict = os.path.abspath(phones_dict)
self.am_stat = os.path.abspath(am_stat) self.am_stat = os.path.abspath(am_stat)
self.am_res_path = os.path.dirname( self.am_res_path = os.path.dirname(os.path.abspath(am_ckpt[0]))
os.path.abspath(am_ckpt[0]))
# create am sess # create am sess
self.am_encoder_infer_sess = get_sess(self.am_encoder_infer, self.am_encoder_infer_sess = get_sess(self.am_encoder_infer,
...@@ -130,13 +128,13 @@ class TTSServerExecutor(TTSExecutor): ...@@ -130,13 +128,13 @@ class TTSServerExecutor(TTSExecutor):
self.am_mu, self.am_std = np.load(self.am_stat) self.am_mu, self.am_std = np.load(self.am_stat)
logger.info(f"self.phones_dict: {self.phones_dict}") logger.debug(f"self.phones_dict: {self.phones_dict}")
logger.info(f"am model dir: {self.am_res_path}") logger.debug(f"am model dir: {self.am_res_path}")
logger.info("Create am sess successfully.") logger.debug("Create am sess successfully.")
# voc model info # voc model info
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
if voc_ckpt is None: if voc_ckpt is None:
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=voc_tag, model_tag=voc_tag,
...@@ -149,16 +147,16 @@ class TTSServerExecutor(TTSExecutor): ...@@ -149,16 +147,16 @@ class TTSServerExecutor(TTSExecutor):
else: else:
self.voc_ckpt = os.path.abspath(voc_ckpt) self.voc_ckpt = os.path.abspath(voc_ckpt)
self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt)) self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
logger.info(self.voc_res_path) logger.debug(self.voc_res_path)
# create voc sess # create voc sess
self.voc_sess = get_sess(self.voc_ckpt, voc_sess_conf) self.voc_sess = get_sess(self.voc_ckpt, voc_sess_conf)
logger.info("Create voc sess successfully.") logger.debug("Create voc sess successfully.")
with open(self.phones_dict, "r") as f: with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
self.vocab_size = len(phn_id) self.vocab_size = len(phn_id)
logger.info(f"vocab_size: {self.vocab_size}") logger.debug(f"vocab_size: {self.vocab_size}")
# frontend # frontend
self.tones_dict = None self.tones_dict = None
...@@ -169,7 +167,7 @@ class TTSServerExecutor(TTSExecutor): ...@@ -169,7 +167,7 @@ class TTSServerExecutor(TTSExecutor):
elif lang == 'en': elif lang == 'en':
self.frontend = English(phone_vocab_path=self.phones_dict) self.frontend = English(phone_vocab_path=self.phones_dict)
logger.info("frontend done!") logger.debug("frontend done!")
class TTSEngine(BaseEngine): class TTSEngine(BaseEngine):
...@@ -267,7 +265,7 @@ class PaddleTTSConnectionHandler: ...@@ -267,7 +265,7 @@ class PaddleTTSConnectionHandler:
tts_engine (TTSEngine): The TTS engine tts_engine (TTSEngine): The TTS engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"Create PaddleTTSConnectionHandler to process the tts request") "Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine self.tts_engine = tts_engine
......
...@@ -102,16 +102,22 @@ class TTSServerExecutor(TTSExecutor): ...@@ -102,16 +102,22 @@ class TTSServerExecutor(TTSExecutor):
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'): if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'):
logger.info('Models had been initialized.') logger.debug('Models had been initialized.')
return return
# am model info # am model info
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
use_pretrained_am = True
else:
use_pretrained_am = False
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=am_tag, model_tag=am_tag,
model_type=0, # am model_type=0, # am
skip_download=not use_pretrained_am,
version=None, # default version version=None, # default version
) )
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: if use_pretrained_am:
self.am_res_path = self.task_resource.res_dir self.am_res_path = self.task_resource.res_dir
self.am_config = os.path.join(self.am_res_path, self.am_config = os.path.join(self.am_res_path,
self.task_resource.res_dict['config']) self.task_resource.res_dict['config'])
...@@ -122,29 +128,33 @@ class TTSServerExecutor(TTSExecutor): ...@@ -122,29 +128,33 @@ class TTSServerExecutor(TTSExecutor):
# must have phones_dict in acoustic # must have phones_dict in acoustic
self.phones_dict = os.path.join( self.phones_dict = os.path.join(
self.am_res_path, self.task_resource.res_dict['phones_dict']) self.am_res_path, self.task_resource.res_dict['phones_dict'])
print("self.phones_dict:", self.phones_dict) logger.debug(self.am_res_path)
logger.info(self.am_res_path) logger.debug(self.am_config)
logger.info(self.am_config) logger.debug(self.am_ckpt)
logger.info(self.am_ckpt)
else: else:
self.am_config = os.path.abspath(am_config) self.am_config = os.path.abspath(am_config)
self.am_ckpt = os.path.abspath(am_ckpt) self.am_ckpt = os.path.abspath(am_ckpt)
self.am_stat = os.path.abspath(am_stat) self.am_stat = os.path.abspath(am_stat)
self.phones_dict = os.path.abspath(phones_dict) self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname(os.path.abspath(self.am_config)) self.am_res_path = os.path.dirname(os.path.abspath(self.am_config))
print("self.phones_dict:", self.phones_dict)
self.tones_dict = None self.tones_dict = None
self.speaker_dict = None self.speaker_dict = None
# voc model info # voc model info
if voc_ckpt is None or voc_config is None or voc_stat is None:
use_pretrained_voc = True
else:
use_pretrained_voc = False
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=voc_tag, model_tag=voc_tag,
model_type=1, # vocoder model_type=1, # vocoder
skip_download=not use_pretrained_voc,
version=None, # default version version=None, # default version
) )
if voc_ckpt is None or voc_config is None or voc_stat is None: if use_pretrained_voc:
self.voc_res_path = self.task_resource.voc_res_dir self.voc_res_path = self.task_resource.voc_res_dir
self.voc_config = os.path.join( self.voc_config = os.path.join(
self.voc_res_path, self.task_resource.voc_res_dict['config']) self.voc_res_path, self.task_resource.voc_res_dict['config'])
...@@ -153,9 +163,9 @@ class TTSServerExecutor(TTSExecutor): ...@@ -153,9 +163,9 @@ class TTSServerExecutor(TTSExecutor):
self.voc_stat = os.path.join( self.voc_stat = os.path.join(
self.voc_res_path, self.voc_res_path,
self.task_resource.voc_res_dict['speech_stats']) self.task_resource.voc_res_dict['speech_stats'])
logger.info(self.voc_res_path) logger.debug(self.voc_res_path)
logger.info(self.voc_config) logger.debug(self.voc_config)
logger.info(self.voc_ckpt) logger.debug(self.voc_ckpt)
else: else:
self.voc_config = os.path.abspath(voc_config) self.voc_config = os.path.abspath(voc_config)
self.voc_ckpt = os.path.abspath(voc_ckpt) self.voc_ckpt = os.path.abspath(voc_ckpt)
...@@ -172,7 +182,6 @@ class TTSServerExecutor(TTSExecutor): ...@@ -172,7 +182,6 @@ class TTSServerExecutor(TTSExecutor):
with open(self.phones_dict, "r") as f: with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
self.vocab_size = len(phn_id) self.vocab_size = len(phn_id)
print("vocab_size:", self.vocab_size)
# frontend # frontend
if lang == 'zh': if lang == 'zh':
...@@ -182,7 +191,6 @@ class TTSServerExecutor(TTSExecutor): ...@@ -182,7 +191,6 @@ class TTSServerExecutor(TTSExecutor):
elif lang == 'en': elif lang == 'en':
self.frontend = English(phone_vocab_path=self.phones_dict) self.frontend = English(phone_vocab_path=self.phones_dict)
print("frontend done!")
# am infer info # am infer info
self.am_name = am[:am.rindex('_')] self.am_name = am[:am.rindex('_')]
...@@ -197,7 +205,6 @@ class TTSServerExecutor(TTSExecutor): ...@@ -197,7 +205,6 @@ class TTSServerExecutor(TTSExecutor):
self.am_name + '_inference') self.am_name + '_inference')
self.am_inference = am_inference_class(am_normalizer, am) self.am_inference = am_inference_class(am_normalizer, am)
self.am_inference.eval() self.am_inference.eval()
print("acoustic model done!")
# voc infer info # voc infer info
self.voc_name = voc[:voc.rindex('_')] self.voc_name = voc[:voc.rindex('_')]
...@@ -208,7 +215,6 @@ class TTSServerExecutor(TTSExecutor): ...@@ -208,7 +215,6 @@ class TTSServerExecutor(TTSExecutor):
'_inference') '_inference')
self.voc_inference = voc_inference_class(voc_normalizer, voc) self.voc_inference = voc_inference_class(voc_normalizer, voc)
self.voc_inference.eval() self.voc_inference.eval()
print("voc done!")
class TTSEngine(BaseEngine): class TTSEngine(BaseEngine):
...@@ -297,7 +303,7 @@ class PaddleTTSConnectionHandler: ...@@ -297,7 +303,7 @@ class PaddleTTSConnectionHandler:
tts_engine (TTSEngine): The TTS engine tts_engine (TTSEngine): The TTS engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"Create PaddleTTSConnectionHandler to process the tts request") "Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine self.tts_engine = tts_engine
...@@ -357,7 +363,7 @@ class PaddleTTSConnectionHandler: ...@@ -357,7 +363,7 @@ class PaddleTTSConnectionHandler:
text, merge_sentences=merge_sentences) text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
else: else:
print("lang should in {'zh', 'en'}!") logger.error("lang should in {'zh', 'en'}!")
frontend_et = time.time() frontend_et = time.time()
self.frontend_time = frontend_et - frontend_st self.frontend_time = frontend_et - frontend_st
......
...@@ -65,16 +65,22 @@ class TTSServerExecutor(TTSExecutor): ...@@ -65,16 +65,22 @@ class TTSServerExecutor(TTSExecutor):
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if hasattr(self, 'am_predictor') and hasattr(self, 'voc_predictor'): if hasattr(self, 'am_predictor') and hasattr(self, 'voc_predictor'):
logger.info('Models had been initialized.') logger.debug('Models had been initialized.')
return return
# am # am
if am_model is None or am_params is None or phones_dict is None:
use_pretrained_am = True
else:
use_pretrained_am = False
am_tag = am + '-' + lang am_tag = am + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=am_tag, model_tag=am_tag,
model_type=0, # am model_type=0, # am
skip_download=not use_pretrained_am,
version=None, # default version version=None, # default version
) )
if am_model is None or am_params is None or phones_dict is None: if use_pretrained_am:
self.am_res_path = self.task_resource.res_dir self.am_res_path = self.task_resource.res_dir
self.am_model = os.path.join(self.am_res_path, self.am_model = os.path.join(self.am_res_path,
self.task_resource.res_dict['model']) self.task_resource.res_dict['model'])
...@@ -85,16 +91,16 @@ class TTSServerExecutor(TTSExecutor): ...@@ -85,16 +91,16 @@ class TTSServerExecutor(TTSExecutor):
self.am_res_path, self.task_resource.res_dict['phones_dict']) self.am_res_path, self.task_resource.res_dict['phones_dict'])
self.am_sample_rate = self.task_resource.res_dict['sample_rate'] self.am_sample_rate = self.task_resource.res_dict['sample_rate']
logger.info(self.am_res_path) logger.debug(self.am_res_path)
logger.info(self.am_model) logger.debug(self.am_model)
logger.info(self.am_params) logger.debug(self.am_params)
else: else:
self.am_model = os.path.abspath(am_model) self.am_model = os.path.abspath(am_model)
self.am_params = os.path.abspath(am_params) self.am_params = os.path.abspath(am_params)
self.phones_dict = os.path.abspath(phones_dict) self.phones_dict = os.path.abspath(phones_dict)
self.am_sample_rate = am_sample_rate self.am_sample_rate = am_sample_rate
self.am_res_path = os.path.dirname(os.path.abspath(self.am_model)) self.am_res_path = os.path.dirname(os.path.abspath(self.am_model))
logger.info("self.phones_dict: {}".format(self.phones_dict)) logger.debug("self.phones_dict: {}".format(self.phones_dict))
# for speedyspeech # for speedyspeech
self.tones_dict = None self.tones_dict = None
...@@ -113,13 +119,19 @@ class TTSServerExecutor(TTSExecutor): ...@@ -113,13 +119,19 @@ class TTSServerExecutor(TTSExecutor):
self.speaker_dict = speaker_dict self.speaker_dict = speaker_dict
# voc # voc
if voc_model is None or voc_params is None:
use_pretrained_voc = True
else:
use_pretrained_voc = False
voc_tag = voc + '-' + lang voc_tag = voc + '-' + lang
self.task_resource.set_task_model( self.task_resource.set_task_model(
model_tag=voc_tag, model_tag=voc_tag,
model_type=1, # vocoder model_type=1, # vocoder
skip_download=not use_pretrained_voc,
version=None, # default version version=None, # default version
) )
if voc_model is None or voc_params is None: if use_pretrained_voc:
self.voc_res_path = self.task_resource.voc_res_dir self.voc_res_path = self.task_resource.voc_res_dir
self.voc_model = os.path.join( self.voc_model = os.path.join(
self.voc_res_path, self.task_resource.voc_res_dict['model']) self.voc_res_path, self.task_resource.voc_res_dict['model'])
...@@ -127,9 +139,9 @@ class TTSServerExecutor(TTSExecutor): ...@@ -127,9 +139,9 @@ class TTSServerExecutor(TTSExecutor):
self.voc_res_path, self.task_resource.voc_res_dict['params']) self.voc_res_path, self.task_resource.voc_res_dict['params'])
self.voc_sample_rate = self.task_resource.voc_res_dict[ self.voc_sample_rate = self.task_resource.voc_res_dict[
'sample_rate'] 'sample_rate']
logger.info(self.voc_res_path) logger.debug(self.voc_res_path)
logger.info(self.voc_model) logger.debug(self.voc_model)
logger.info(self.voc_params) logger.debug(self.voc_params)
else: else:
self.voc_model = os.path.abspath(voc_model) self.voc_model = os.path.abspath(voc_model)
self.voc_params = os.path.abspath(voc_params) self.voc_params = os.path.abspath(voc_params)
...@@ -144,21 +156,21 @@ class TTSServerExecutor(TTSExecutor): ...@@ -144,21 +156,21 @@ class TTSServerExecutor(TTSExecutor):
with open(self.phones_dict, "r") as f: with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id) vocab_size = len(phn_id)
logger.info("vocab_size: {}".format(vocab_size)) logger.debug("vocab_size: {}".format(vocab_size))
tone_size = None tone_size = None
if self.tones_dict: if self.tones_dict:
with open(self.tones_dict, "r") as f: with open(self.tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()] tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id) tone_size = len(tone_id)
logger.info("tone_size: {}".format(tone_size)) logger.debug("tone_size: {}".format(tone_size))
spk_num = None spk_num = None
if self.speaker_dict: if self.speaker_dict:
with open(self.speaker_dict, 'rt') as f: with open(self.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()] spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id) spk_num = len(spk_id)
logger.info("spk_num: {}".format(spk_num)) logger.debug("spk_num: {}".format(spk_num))
# frontend # frontend
if lang == 'zh': if lang == 'zh':
...@@ -168,7 +180,7 @@ class TTSServerExecutor(TTSExecutor): ...@@ -168,7 +180,7 @@ class TTSServerExecutor(TTSExecutor):
elif lang == 'en': elif lang == 'en':
self.frontend = English(phone_vocab_path=self.phones_dict) self.frontend = English(phone_vocab_path=self.phones_dict)
logger.info("frontend done!") logger.debug("frontend done!")
# Create am predictor # Create am predictor
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
...@@ -176,7 +188,7 @@ class TTSServerExecutor(TTSExecutor): ...@@ -176,7 +188,7 @@ class TTSServerExecutor(TTSExecutor):
model_file=self.am_model, model_file=self.am_model,
params_file=self.am_params, params_file=self.am_params,
predictor_conf=self.am_predictor_conf) predictor_conf=self.am_predictor_conf)
logger.info("Create AM predictor successfully.") logger.debug("Create AM predictor successfully.")
# Create voc predictor # Create voc predictor
self.voc_predictor_conf = voc_predictor_conf self.voc_predictor_conf = voc_predictor_conf
...@@ -184,7 +196,7 @@ class TTSServerExecutor(TTSExecutor): ...@@ -184,7 +196,7 @@ class TTSServerExecutor(TTSExecutor):
model_file=self.voc_model, model_file=self.voc_model,
params_file=self.voc_params, params_file=self.voc_params,
predictor_conf=self.voc_predictor_conf) predictor_conf=self.voc_predictor_conf)
logger.info("Create Vocoder predictor successfully.") logger.debug("Create Vocoder predictor successfully.")
@paddle.no_grad() @paddle.no_grad()
def infer(self, def infer(self,
...@@ -316,7 +328,8 @@ class TTSEngine(BaseEngine): ...@@ -316,7 +328,8 @@ class TTSEngine(BaseEngine):
logger.error(e) logger.error(e)
return False return False
logger.info("Initialize TTS server engine successfully.") logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
return True return True
...@@ -328,7 +341,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -328,7 +341,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
tts_engine (TTSEngine): The TTS engine tts_engine (TTSEngine): The TTS engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"Create PaddleTTSConnectionHandler to process the tts request") "Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine self.tts_engine = tts_engine
...@@ -366,23 +379,23 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -366,23 +379,23 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
if target_fs == 0 or target_fs > original_fs: if target_fs == 0 or target_fs > original_fs:
target_fs = original_fs target_fs = original_fs
wav_tar_fs = wav wav_tar_fs = wav
logger.info( logger.debug(
"The sample rate of synthesized audio is the same as model, which is {}Hz". "The sample rate of synthesized audio is the same as model, which is {}Hz".
format(original_fs)) format(original_fs))
else: else:
wav_tar_fs = librosa.resample( wav_tar_fs = librosa.resample(
np.squeeze(wav), original_fs, target_fs) np.squeeze(wav), original_fs, target_fs)
logger.info( logger.debug(
"The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.". "The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
format(original_fs, target_fs)) format(original_fs, target_fs))
# transform volume # transform volume
wav_vol = wav_tar_fs * volume wav_vol = wav_tar_fs * volume
logger.info("Transform the volume of the audio successfully.") logger.debug("Transform the volume of the audio successfully.")
# transform speed # transform speed
try: # windows not support soxbindings try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs) wav_speed = change_speed(wav_vol, speed, target_fs)
logger.info("Transform the speed of the audio successfully.") logger.debug("Transform the speed of the audio successfully.")
except ServerBaseException: except ServerBaseException:
raise ServerBaseException( raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR, ErrorCode.SERVER_INTERNAL_ERR,
...@@ -399,7 +412,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -399,7 +412,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
wavfile.write(buf, target_fs, wav_speed) wavfile.write(buf, target_fs, wav_speed)
base64_bytes = base64.b64encode(buf.read()) base64_bytes = base64.b64encode(buf.read())
wav_base64 = base64_bytes.decode('utf-8') wav_base64 = base64_bytes.decode('utf-8')
logger.info("Audio to string successfully.") logger.debug("Audio to string successfully.")
# save audio # save audio
if audio_path is not None: if audio_path is not None:
...@@ -487,15 +500,15 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -487,15 +500,15 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
logger.error(e) logger.error(e)
sys.exit(-1) sys.exit(-1)
logger.info("AM model: {}".format(self.config.am)) logger.debug("AM model: {}".format(self.config.am))
logger.info("Vocoder model: {}".format(self.config.voc)) logger.debug("Vocoder model: {}".format(self.config.voc))
logger.info("Language: {}".format(lang)) logger.debug("Language: {}".format(lang))
logger.info("tts engine type: python") logger.info("tts engine type: python")
logger.info("audio duration: {}".format(duration)) logger.info("audio duration: {}".format(duration))
logger.info("frontend inference time: {}".format(self.frontend_time)) logger.debug("frontend inference time: {}".format(self.frontend_time))
logger.info("AM inference time: {}".format(self.am_time)) logger.debug("AM inference time: {}".format(self.am_time))
logger.info("Vocoder inference time: {}".format(self.voc_time)) logger.debug("Vocoder inference time: {}".format(self.voc_time))
logger.info("total inference time: {}".format(infer_time)) logger.info("total inference time: {}".format(infer_time))
logger.info( logger.info(
"postprocess (change speed, volume, target sample rate) time: {}". "postprocess (change speed, volume, target sample rate) time: {}".
...@@ -503,6 +516,6 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -503,6 +516,6 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
logger.info("total generate audio time: {}".format(infer_time + logger.info("total generate audio time: {}".format(infer_time +
postprocess_time)) postprocess_time))
logger.info("RTF: {}".format(rtf)) logger.info("RTF: {}".format(rtf))
logger.info("device: {}".format(self.tts_engine.device)) logger.debug("device: {}".format(self.tts_engine.device))
return lang, target_sample_rate, duration, wav_base64 return lang, target_sample_rate, duration, wav_base64
...@@ -105,7 +105,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -105,7 +105,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
tts_engine (TTSEngine): The TTS engine tts_engine (TTSEngine): The TTS engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"Create PaddleTTSConnectionHandler to process the tts request") "Create PaddleTTSConnectionHandler to process the tts request")
self.tts_engine = tts_engine self.tts_engine = tts_engine
...@@ -143,23 +143,23 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -143,23 +143,23 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
if target_fs == 0 or target_fs > original_fs: if target_fs == 0 or target_fs > original_fs:
target_fs = original_fs target_fs = original_fs
wav_tar_fs = wav wav_tar_fs = wav
logger.info( logger.debug(
"The sample rate of synthesized audio is the same as model, which is {}Hz". "The sample rate of synthesized audio is the same as model, which is {}Hz".
format(original_fs)) format(original_fs))
else: else:
wav_tar_fs = librosa.resample( wav_tar_fs = librosa.resample(
np.squeeze(wav), original_fs, target_fs) np.squeeze(wav), original_fs, target_fs)
logger.info( logger.debug(
"The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.". "The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
format(original_fs, target_fs)) format(original_fs, target_fs))
# transform volume # transform volume
wav_vol = wav_tar_fs * volume wav_vol = wav_tar_fs * volume
logger.info("Transform the volume of the audio successfully.") logger.debug("Transform the volume of the audio successfully.")
# transform speed # transform speed
try: # windows not support soxbindings try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs) wav_speed = change_speed(wav_vol, speed, target_fs)
logger.info("Transform the speed of the audio successfully.") logger.debug("Transform the speed of the audio successfully.")
except ServerBaseException: except ServerBaseException:
raise ServerBaseException( raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR, ErrorCode.SERVER_INTERNAL_ERR,
...@@ -176,7 +176,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -176,7 +176,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
wavfile.write(buf, target_fs, wav_speed) wavfile.write(buf, target_fs, wav_speed)
base64_bytes = base64.b64encode(buf.read()) base64_bytes = base64.b64encode(buf.read())
wav_base64 = base64_bytes.decode('utf-8') wav_base64 = base64_bytes.decode('utf-8')
logger.info("Audio to string successfully.") logger.debug("Audio to string successfully.")
# save audio # save audio
if audio_path is not None: if audio_path is not None:
...@@ -264,15 +264,15 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -264,15 +264,15 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
logger.error(e) logger.error(e)
sys.exit(-1) sys.exit(-1)
logger.info("AM model: {}".format(self.config.am)) logger.debug("AM model: {}".format(self.config.am))
logger.info("Vocoder model: {}".format(self.config.voc)) logger.debug("Vocoder model: {}".format(self.config.voc))
logger.info("Language: {}".format(lang)) logger.debug("Language: {}".format(lang))
logger.info("tts engine type: python") logger.info("tts engine type: python")
logger.info("audio duration: {}".format(duration)) logger.info("audio duration: {}".format(duration))
logger.info("frontend inference time: {}".format(self.frontend_time)) logger.debug("frontend inference time: {}".format(self.frontend_time))
logger.info("AM inference time: {}".format(self.am_time)) logger.debug("AM inference time: {}".format(self.am_time))
logger.info("Vocoder inference time: {}".format(self.voc_time)) logger.debug("Vocoder inference time: {}".format(self.voc_time))
logger.info("total inference time: {}".format(infer_time)) logger.info("total inference time: {}".format(infer_time))
logger.info( logger.info(
"postprocess (change speed, volume, target sample rate) time: {}". "postprocess (change speed, volume, target sample rate) time: {}".
...@@ -280,6 +280,6 @@ class PaddleTTSConnectionHandler(TTSServerExecutor): ...@@ -280,6 +280,6 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
logger.info("total generate audio time: {}".format(infer_time + logger.info("total generate audio time: {}".format(infer_time +
postprocess_time)) postprocess_time))
logger.info("RTF: {}".format(rtf)) logger.info("RTF: {}".format(rtf))
logger.info("device: {}".format(self.tts_engine.device)) logger.debug("device: {}".format(self.tts_engine.device))
return lang, target_sample_rate, duration, wav_base64 return lang, target_sample_rate, duration, wav_base64
...@@ -33,7 +33,7 @@ class PaddleVectorConnectionHandler: ...@@ -33,7 +33,7 @@ class PaddleVectorConnectionHandler:
vector_engine (VectorEngine): The Vector engine vector_engine (VectorEngine): The Vector engine
""" """
super().__init__() super().__init__()
logger.info( logger.debug(
"Create PaddleVectorConnectionHandler to process the vector request") "Create PaddleVectorConnectionHandler to process the vector request")
self.vector_engine = vector_engine self.vector_engine = vector_engine
self.executor = self.vector_engine.executor self.executor = self.vector_engine.executor
...@@ -54,7 +54,7 @@ class PaddleVectorConnectionHandler: ...@@ -54,7 +54,7 @@ class PaddleVectorConnectionHandler:
Returns: Returns:
str: the punctuation text str: the punctuation text
""" """
logger.info( logger.debug(
f"start to extract the do vector {self.task} from the http request") f"start to extract the do vector {self.task} from the http request")
if self.task == "spk" and task == "spk": if self.task == "spk" and task == "spk":
embedding = self.extract_audio_embedding(audio_data) embedding = self.extract_audio_embedding(audio_data)
...@@ -81,17 +81,17 @@ class PaddleVectorConnectionHandler: ...@@ -81,17 +81,17 @@ class PaddleVectorConnectionHandler:
Returns: Returns:
float: the score between enroll and test audio float: the score between enroll and test audio
""" """
logger.info("start to extract the enroll audio embedding") logger.debug("start to extract the enroll audio embedding")
enroll_emb = self.extract_audio_embedding(enroll_audio) enroll_emb = self.extract_audio_embedding(enroll_audio)
logger.info("start to extract the test audio embedding") logger.debug("start to extract the test audio embedding")
test_emb = self.extract_audio_embedding(test_audio) test_emb = self.extract_audio_embedding(test_audio)
logger.info( logger.debug(
"start to get the score between the enroll and test embedding") "start to get the score between the enroll and test embedding")
score = self.executor.get_embeddings_score(enroll_emb, test_emb) score = self.executor.get_embeddings_score(enroll_emb, test_emb)
logger.info(f"get the enroll vs test score: {score}") logger.debug(f"get the enroll vs test score: {score}")
return score return score
@paddle.no_grad() @paddle.no_grad()
...@@ -106,11 +106,12 @@ class PaddleVectorConnectionHandler: ...@@ -106,11 +106,12 @@ class PaddleVectorConnectionHandler:
# because the soundfile will change the io.BytesIO(audio) to the end # because the soundfile will change the io.BytesIO(audio) to the end
# thus we should convert the base64 string to io.BytesIO when we need the audio data # thus we should convert the base64 string to io.BytesIO when we need the audio data
if not self.executor._check(io.BytesIO(audio), sample_rate): if not self.executor._check(io.BytesIO(audio), sample_rate):
logger.info("check the audio sample rate occurs error") logger.debug("check the audio sample rate occurs error")
return np.array([0.0]) return np.array([0.0])
waveform, sr = load_audio(io.BytesIO(audio)) waveform, sr = load_audio(io.BytesIO(audio))
logger.info(f"load the audio sample points, shape is: {waveform.shape}") logger.debug(
f"load the audio sample points, shape is: {waveform.shape}")
# stage 2: get the audio feat # stage 2: get the audio feat
# Note: Now we only support fbank feature # Note: Now we only support fbank feature
...@@ -121,9 +122,9 @@ class PaddleVectorConnectionHandler: ...@@ -121,9 +122,9 @@ class PaddleVectorConnectionHandler:
n_mels=self.config.n_mels, n_mels=self.config.n_mels,
window_size=self.config.window_size, window_size=self.config.window_size,
hop_length=self.config.hop_size) hop_length=self.config.hop_size)
logger.info(f"extract the audio feats, shape is: {feats.shape}") logger.debug(f"extract the audio feats, shape is: {feats.shape}")
except Exception as e: except Exception as e:
logger.info(f"feats occurs exception {e}") logger.error(f"feats occurs exception {e}")
sys.exit(-1) sys.exit(-1)
feats = paddle.to_tensor(feats).unsqueeze(0) feats = paddle.to_tensor(feats).unsqueeze(0)
...@@ -159,7 +160,7 @@ class VectorEngine(BaseEngine): ...@@ -159,7 +160,7 @@ class VectorEngine(BaseEngine):
"""The Vector Engine """The Vector Engine
""" """
super(VectorEngine, self).__init__() super(VectorEngine, self).__init__()
logger.info("Create the VectorEngine Instance") logger.debug("Create the VectorEngine Instance")
def init(self, config: dict): def init(self, config: dict):
"""Init the Vector Engine """Init the Vector Engine
...@@ -170,7 +171,7 @@ class VectorEngine(BaseEngine): ...@@ -170,7 +171,7 @@ class VectorEngine(BaseEngine):
Returns: Returns:
bool: The engine instance flag bool: The engine instance flag
""" """
logger.info("Init the vector engine") logger.debug("Init the vector engine")
try: try:
self.config = config self.config = config
if self.config.device: if self.config.device:
...@@ -179,7 +180,7 @@ class VectorEngine(BaseEngine): ...@@ -179,7 +180,7 @@ class VectorEngine(BaseEngine):
self.device = paddle.get_device() self.device = paddle.get_device()
paddle.set_device(self.device) paddle.set_device(self.device)
logger.info(f"Vector Engine set the device: {self.device}") logger.debug(f"Vector Engine set the device: {self.device}")
except BaseException as e: except BaseException as e:
logger.error( logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file" "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
...@@ -196,5 +197,7 @@ class VectorEngine(BaseEngine): ...@@ -196,5 +197,7 @@ class VectorEngine(BaseEngine):
ckpt_path=config.ckpt_path, ckpt_path=config.ckpt_path,
task=config.task) task=config.task)
logger.info("Init the Vector engine successfully") logger.info(
"Initialize Vector server engine successfully on device: %s." %
(self.device))
return True return True
...@@ -138,7 +138,7 @@ class ASRWsAudioHandler: ...@@ -138,7 +138,7 @@ class ASRWsAudioHandler:
Returns: Returns:
str: the final asr result str: the final asr result
""" """
logging.info("send a message to the server") logging.debug("send a message to the server")
if self.url is None: if self.url is None:
logger.error("No asr server, please input valid ip and port") logger.error("No asr server, please input valid ip and port")
...@@ -160,7 +160,7 @@ class ASRWsAudioHandler: ...@@ -160,7 +160,7 @@ class ASRWsAudioHandler:
separators=(',', ': ')) separators=(',', ': '))
await ws.send(audio_info) await ws.send(audio_info)
msg = await ws.recv() msg = await ws.recv()
logger.info("client receive msg={}".format(msg)) logger.debug("client receive msg={}".format(msg))
# 3. send chunk audio data to engine # 3. send chunk audio data to engine
for chunk_data in self.read_wave(wavfile_path): for chunk_data in self.read_wave(wavfile_path):
...@@ -170,7 +170,7 @@ class ASRWsAudioHandler: ...@@ -170,7 +170,7 @@ class ASRWsAudioHandler:
if self.punc_server and len(msg["result"]) > 0: if self.punc_server and len(msg["result"]) > 0:
msg["result"] = self.punc_server.run(msg["result"]) msg["result"] = self.punc_server.run(msg["result"])
logger.info("client receive msg={}".format(msg)) logger.debug("client receive msg={}".format(msg))
# 4. we must send finished signal to the server # 4. we must send finished signal to the server
audio_info = json.dumps( audio_info = json.dumps(
...@@ -310,7 +310,7 @@ class TTSWsHandler: ...@@ -310,7 +310,7 @@ class TTSWsHandler:
start_request = json.dumps({"task": "tts", "signal": "start"}) start_request = json.dumps({"task": "tts", "signal": "start"})
await ws.send(start_request) await ws.send(start_request)
msg = await ws.recv() msg = await ws.recv()
logger.info(f"client receive msg={msg}") logger.debug(f"client receive msg={msg}")
msg = json.loads(msg) msg = json.loads(msg)
session = msg["session"] session = msg["session"]
...@@ -319,7 +319,7 @@ class TTSWsHandler: ...@@ -319,7 +319,7 @@ class TTSWsHandler:
request = json.dumps({"text": text_base64}) request = json.dumps({"text": text_base64})
st = time.time() st = time.time()
await ws.send(request) await ws.send(request)
logging.info("send a message to the server") logging.debug("send a message to the server")
# 4. Process the received response # 4. Process the received response
message = await ws.recv() message = await ws.recv()
...@@ -543,7 +543,6 @@ class VectorHttpHandler: ...@@ -543,7 +543,6 @@ class VectorHttpHandler:
"sample_rate": sample_rate, "sample_rate": sample_rate,
} }
logger.info(self.url)
res = requests.post(url=self.url, data=json.dumps(data)) res = requests.post(url=self.url, data=json.dumps(data))
return res.json() return res.json()
......
...@@ -169,7 +169,7 @@ def save_audio(bytes_data, audio_path, sample_rate: int=24000) -> bool: ...@@ -169,7 +169,7 @@ def save_audio(bytes_data, audio_path, sample_rate: int=24000) -> bool:
sample_rate=sample_rate) sample_rate=sample_rate)
os.remove("./tmp.pcm") os.remove("./tmp.pcm")
else: else:
print("Only supports saved audio format is pcm or wav") logger.error("Only supports saved audio format is pcm or wav")
return False return False
return True return True
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
__all__ = [
'logger',
]
class Logger(object):
def __init__(self, name: str=None):
name = 'PaddleSpeech' if not name else name
self.logger = logging.getLogger(name)
log_config = {
'DEBUG': 10,
'INFO': 20,
'TRAIN': 21,
'EVAL': 22,
'WARNING': 30,
'ERROR': 40,
'CRITICAL': 50,
'EXCEPTION': 100,
}
for key, level in log_config.items():
logging.addLevelName(level, key)
if key == 'EXCEPTION':
self.__dict__[key.lower()] = self.logger.exception
else:
self.__dict__[key.lower()] = functools.partial(self.__call__,
level)
self.format = logging.Formatter(
fmt='[%(asctime)-15s] [%(levelname)8s] - %(message)s')
self.handler = logging.StreamHandler()
self.handler.setFormatter(self.format)
self.logger.addHandler(self.handler)
self.logger.setLevel(logging.DEBUG)
self.logger.propagate = False
def __call__(self, log_level: str, msg: str):
self.logger.log(log_level, msg)
logger = Logger()
...@@ -16,11 +16,11 @@ from typing import Optional ...@@ -16,11 +16,11 @@ from typing import Optional
import onnxruntime as ort import onnxruntime as ort
from .log import logger from paddlespeech.cli.log import logger
def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None): def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
logger.info(f"ort sessconf: {sess_conf}") logger.debug(f"ort sessconf: {sess_conf}")
sess_options = ort.SessionOptions() sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
if sess_conf.get('graph_optimization_level', 99) == 0: if sess_conf.get('graph_optimization_level', 99) == 0:
...@@ -34,7 +34,7 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None): ...@@ -34,7 +34,7 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
# fastspeech2/mb_melgan can't use trt now! # fastspeech2/mb_melgan can't use trt now!
if sess_conf.get("use_trt", 0): if sess_conf.get("use_trt", 0):
providers = ['TensorrtExecutionProvider'] providers = ['TensorrtExecutionProvider']
logger.info(f"ort providers: {providers}") logger.debug(f"ort providers: {providers}")
if 'cpu_threads' in sess_conf: if 'cpu_threads' in sess_conf:
sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0) sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
import base64 import base64
import math import math
from paddlespeech.cli.log import logger
def wav2base64(wav_file: str): def wav2base64(wav_file: str):
""" """
...@@ -61,7 +63,7 @@ def get_chunks(data, block_size, pad_size, step): ...@@ -61,7 +63,7 @@ def get_chunks(data, block_size, pad_size, step):
elif step == "voc": elif step == "voc":
data_len = data.shape[0] data_len = data.shape[0]
else: else:
print("Please set correct type to get chunks, am or voc") logger.error("Please set correct type to get chunks, am or voc")
chunks = [] chunks = []
n = math.ceil(data_len / block_size) n = math.ceil(data_len / block_size)
...@@ -73,7 +75,7 @@ def get_chunks(data, block_size, pad_size, step): ...@@ -73,7 +75,7 @@ def get_chunks(data, block_size, pad_size, step):
elif step == "voc": elif step == "voc":
chunks.append(data[start:end, :]) chunks.append(data[start:end, :])
else: else:
print("Please set correct type to get chunks, am or voc") logger.error("Please set correct type to get chunks, am or voc")
return chunks return chunks
......
...@@ -50,20 +50,34 @@ class HiFiGANGenerator(nn.Layer): ...@@ -50,20 +50,34 @@ class HiFiGANGenerator(nn.Layer):
init_type: str="xavier_uniform", ): init_type: str="xavier_uniform", ):
"""Initialize HiFiGANGenerator module. """Initialize HiFiGANGenerator module.
Args: Args:
in_channels (int): Number of input channels. in_channels (int):
out_channels (int): Number of output channels. Number of input channels.
channels (int): Number of hidden representation channels. out_channels (int):
global_channels (int): Number of global conditioning channels. Number of output channels.
kernel_size (int): Kernel size of initial and final conv layer. channels (int):
upsample_scales (list): List of upsampling scales. Number of hidden representation channels.
upsample_kernel_sizes (list): List of kernel sizes for upsampling layers. global_channels (int):
resblock_kernel_sizes (list): List of kernel sizes for residual blocks. Number of global conditioning channels.
resblock_dilations (list): List of dilation list for residual blocks. kernel_size (int):
use_additional_convs (bool): Whether to use additional conv layers in residual blocks. Kernel size of initial and final conv layer.
bias (bool): Whether to add bias parameter in convolution layers. upsample_scales (list):
nonlinear_activation (str): Activation function module name. List of upsampling scales.
nonlinear_activation_params (dict): Hyperparameters for activation function. upsample_kernel_sizes (list):
use_weight_norm (bool): Whether to use weight norm. List of kernel sizes for upsampling layers.
resblock_kernel_sizes (list):
List of kernel sizes for residual blocks.
resblock_dilations (list):
List of dilation list for residual blocks.
use_additional_convs (bool):
Whether to use additional conv layers in residual blocks.
bias (bool):
Whether to add bias parameter in convolution layers.
nonlinear_activation (str):
Activation function module name.
nonlinear_activation_params (dict):
Hyperparameters for activation function.
use_weight_norm (bool):
Whether to use weight norm.
If set to true, it will be applied to all of the conv layers. If set to true, it will be applied to all of the conv layers.
""" """
super().__init__() super().__init__()
...@@ -199,9 +213,10 @@ class HiFiGANGenerator(nn.Layer): ...@@ -199,9 +213,10 @@ class HiFiGANGenerator(nn.Layer):
def inference(self, c, g: Optional[paddle.Tensor]=None): def inference(self, c, g: Optional[paddle.Tensor]=None):
"""Perform inference. """Perform inference.
Args: Args:
c (Tensor): Input tensor (T, in_channels). c (Tensor):
normalize_before (bool): Whether to perform normalization. Input tensor (T, in_channels).
g (Optional[Tensor]): Global conditioning tensor (global_channels, 1). g (Optional[Tensor]):
Global conditioning tensor (global_channels, 1).
Returns: Returns:
Tensor: Tensor:
Output tensor (T ** prod(upsample_scales), out_channels). Output tensor (T ** prod(upsample_scales), out_channels).
...@@ -233,20 +248,33 @@ class HiFiGANPeriodDiscriminator(nn.Layer): ...@@ -233,20 +248,33 @@ class HiFiGANPeriodDiscriminator(nn.Layer):
"""Initialize HiFiGANPeriodDiscriminator module. """Initialize HiFiGANPeriodDiscriminator module.
Args: Args:
in_channels (int): Number of input channels. in_channels (int):
out_channels (int): Number of output channels. Number of input channels.
period (int): Period. out_channels (int):
kernel_sizes (list): Kernel sizes of initial conv layers and the final conv layer. Number of output channels.
channels (int): Number of initial channels. period (int):
downsample_scales (list): List of downsampling scales. Period.
max_downsample_channels (int): Number of maximum downsampling channels. kernel_sizes (list):
use_additional_convs (bool): Whether to use additional conv layers in residual blocks. Kernel sizes of initial conv layers and the final conv layer.
bias (bool): Whether to add bias parameter in convolution layers. channels (int):
nonlinear_activation (str): Activation function module name. Number of initial channels.
nonlinear_activation_params (dict): Hyperparameters for activation function. downsample_scales (list):
use_weight_norm (bool): Whether to use weight norm. List of downsampling scales.
max_downsample_channels (int):
Number of maximum downsampling channels.
use_additional_convs (bool):
Whether to use additional conv layers in residual blocks.
bias (bool):
Whether to add bias parameter in convolution layers.
nonlinear_activation (str):
Activation function module name.
nonlinear_activation_params (dict):
Hyperparameters for activation function.
use_weight_norm (bool):
Whether to use weight norm.
If set to true, it will be applied to all of the conv layers. If set to true, it will be applied to all of the conv layers.
use_spectral_norm (bool): Whether to use spectral norm. use_spectral_norm (bool):
Whether to use spectral norm.
If set to true, it will be applied to all of the conv layers. If set to true, it will be applied to all of the conv layers.
""" """
super().__init__() super().__init__()
...@@ -298,7 +326,8 @@ class HiFiGANPeriodDiscriminator(nn.Layer): ...@@ -298,7 +326,8 @@ class HiFiGANPeriodDiscriminator(nn.Layer):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
c (Tensor): Input tensor (B, in_channels, T). c (Tensor):
Input tensor (B, in_channels, T).
Returns: Returns:
list: List of each layer's tensors. list: List of each layer's tensors.
""" """
...@@ -367,8 +396,10 @@ class HiFiGANMultiPeriodDiscriminator(nn.Layer): ...@@ -367,8 +396,10 @@ class HiFiGANMultiPeriodDiscriminator(nn.Layer):
"""Initialize HiFiGANMultiPeriodDiscriminator module. """Initialize HiFiGANMultiPeriodDiscriminator module.
Args: Args:
periods (list): List of periods. periods (list):
discriminator_params (dict): Parameters for hifi-gan period discriminator module. List of periods.
discriminator_params (dict):
Parameters for hifi-gan period discriminator module.
The period parameter will be overwritten. The period parameter will be overwritten.
""" """
super().__init__() super().__init__()
...@@ -385,7 +416,8 @@ class HiFiGANMultiPeriodDiscriminator(nn.Layer): ...@@ -385,7 +416,8 @@ class HiFiGANMultiPeriodDiscriminator(nn.Layer):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
x (Tensor): Input noise signal (B, 1, T). x (Tensor):
Input noise signal (B, 1, T).
Returns: Returns:
List: List of list of each discriminator outputs, which consists of each layer output tensors. List: List of list of each discriminator outputs, which consists of each layer output tensors.
""" """
...@@ -417,16 +449,25 @@ class HiFiGANScaleDiscriminator(nn.Layer): ...@@ -417,16 +449,25 @@ class HiFiGANScaleDiscriminator(nn.Layer):
"""Initilize HiFiGAN scale discriminator module. """Initilize HiFiGAN scale discriminator module.
Args: Args:
in_channels (int): Number of input channels. in_channels (int):
out_channels (int): Number of output channels. Number of input channels.
kernel_sizes (list): List of four kernel sizes. The first will be used for the first conv layer, out_channels (int):
Number of output channels.
kernel_sizes (list):
List of four kernel sizes. The first will be used for the first conv layer,
and the second is for downsampling part, and the remaining two are for output layers. and the second is for downsampling part, and the remaining two are for output layers.
channels (int): Initial number of channels for conv layer. channels (int):
max_downsample_channels (int): Maximum number of channels for downsampling layers. Initial number of channels for conv layer.
bias (bool): Whether to add bias parameter in convolution layers. max_downsample_channels (int):
downsample_scales (list): List of downsampling scales. Maximum number of channels for downsampling layers.
nonlinear_activation (str): Activation function module name. bias (bool):
nonlinear_activation_params (dict): Hyperparameters for activation function. Whether to add bias parameter in convolution layers.
downsample_scales (list):
List of downsampling scales.
nonlinear_activation (str):
Activation function module name.
nonlinear_activation_params (dict):
Hyperparameters for activation function.
use_weight_norm (bool): Whether to use weight norm. use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers. If set to true, it will be applied to all of the conv layers.
use_spectral_norm (bool): Whether to use spectral norm. use_spectral_norm (bool): Whether to use spectral norm.
...@@ -614,7 +655,8 @@ class HiFiGANMultiScaleDiscriminator(nn.Layer): ...@@ -614,7 +655,8 @@ class HiFiGANMultiScaleDiscriminator(nn.Layer):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
x (Tensor): Input noise signal (B, 1, T). x (Tensor):
Input noise signal (B, 1, T).
Returns: Returns:
List: List of list of each discriminator outputs, which consists of each layer output tensors. List: List of list of each discriminator outputs, which consists of each layer output tensors.
""" """
...@@ -675,14 +717,21 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(nn.Layer): ...@@ -675,14 +717,21 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(nn.Layer):
"""Initilize HiFiGAN multi-scale + multi-period discriminator module. """Initilize HiFiGAN multi-scale + multi-period discriminator module.
Args: Args:
scales (int): Number of multi-scales. scales (int):
scale_downsample_pooling (str): Pooling module name for downsampling of the inputs. Number of multi-scales.
scale_downsample_pooling_params (dict): Parameters for the above pooling module. scale_downsample_pooling (str):
scale_discriminator_params (dict): Parameters for hifi-gan scale discriminator module. Pooling module name for downsampling of the inputs.
follow_official_norm (bool): Whether to follow the norm setting of the official implementaion. scale_downsample_pooling_params (dict):
Parameters for the above pooling module.
scale_discriminator_params (dict):
Parameters for hifi-gan scale discriminator module.
follow_official_norm (bool):
Whether to follow the norm setting of the official implementaion.
The first discriminator uses spectral norm and the other discriminators use weight norm. The first discriminator uses spectral norm and the other discriminators use weight norm.
periods (list): List of periods. periods (list):
period_discriminator_params (dict): Parameters for hifi-gan period discriminator module. List of periods.
period_discriminator_params (dict):
Parameters for hifi-gan period discriminator module.
The period parameter will be overwritten. The period parameter will be overwritten.
""" """
super().__init__() super().__init__()
...@@ -704,7 +753,8 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(nn.Layer): ...@@ -704,7 +753,8 @@ class HiFiGANMultiScaleMultiPeriodDiscriminator(nn.Layer):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
x (Tensor): Input noise signal (B, 1, T). x (Tensor):
Input noise signal (B, 1, T).
Returns: Returns:
List: List:
List of list of each discriminator outputs, List of list of each discriminator outputs,
......
...@@ -53,24 +53,38 @@ class MelGANGenerator(nn.Layer): ...@@ -53,24 +53,38 @@ class MelGANGenerator(nn.Layer):
"""Initialize MelGANGenerator module. """Initialize MelGANGenerator module.
Args: Args:
in_channels (int): Number of input channels. in_channels (int):
out_channels (int): Number of output channels, Number of input channels.
out_channels (int):
Number of output channels,
the number of sub-band is out_channels in multi-band melgan. the number of sub-band is out_channels in multi-band melgan.
kernel_size (int): Kernel size of initial and final conv layer. kernel_size (int):
channels (int): Initial number of channels for conv layer. Kernel size of initial and final conv layer.
bias (bool): Whether to add bias parameter in convolution layers. channels (int):
upsample_scales (List[int]): List of upsampling scales. Initial number of channels for conv layer.
stack_kernel_size (int): Kernel size of dilated conv layers in residual stack. bias (bool):
stacks (int): Number of stacks in a single residual stack. Whether to add bias parameter in convolution layers.
nonlinear_activation (Optional[str], optional): Non linear activation in upsample network, by default None upsample_scales (List[int]):
nonlinear_activation_params (Dict[str, Any], optional): Parameters passed to the linear activation in the upsample network, List of upsampling scales.
by default {} stack_kernel_size (int):
pad (str): Padding function module name before dilated convolution layer. Kernel size of dilated conv layers in residual stack.
pad_params (dict): Hyperparameters for padding function. stacks (int):
use_final_nonlinear_activation (nn.Layer): Activation function for the final layer. Number of stacks in a single residual stack.
use_weight_norm (bool): Whether to use weight norm. nonlinear_activation (Optional[str], optional):
Non linear activation in upsample network, by default None
nonlinear_activation_params (Dict[str, Any], optional):
Parameters passed to the linear activation in the upsample network, by default {}
pad (str):
Padding function module name before dilated convolution layer.
pad_params (dict):
Hyperparameters for padding function.
use_final_nonlinear_activation (nn.Layer):
Activation function for the final layer.
use_weight_norm (bool):
Whether to use weight norm.
If set to true, it will be applied to all of the conv layers. If set to true, it will be applied to all of the conv layers.
use_causal_conv (bool): Whether to use causal convolution. use_causal_conv (bool):
Whether to use causal convolution.
""" """
super().__init__() super().__init__()
...@@ -194,7 +208,8 @@ class MelGANGenerator(nn.Layer): ...@@ -194,7 +208,8 @@ class MelGANGenerator(nn.Layer):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
c (Tensor): Input tensor (B, in_channels, T). c (Tensor):
Input tensor (B, in_channels, T).
Returns: Returns:
Tensor: Output tensor (B, out_channels, T ** prod(upsample_scales)). Tensor: Output tensor (B, out_channels, T ** prod(upsample_scales)).
""" """
...@@ -244,7 +259,8 @@ class MelGANGenerator(nn.Layer): ...@@ -244,7 +259,8 @@ class MelGANGenerator(nn.Layer):
"""Perform inference. """Perform inference.
Args: Args:
c (Union[Tensor, ndarray]): Input tensor (T, in_channels). c (Union[Tensor, ndarray]):
Input tensor (T, in_channels).
Returns: Returns:
Tensor: Output tensor (out_channels*T ** prod(upsample_scales), 1). Tensor: Output tensor (out_channels*T ** prod(upsample_scales), 1).
""" """
...@@ -279,20 +295,30 @@ class MelGANDiscriminator(nn.Layer): ...@@ -279,20 +295,30 @@ class MelGANDiscriminator(nn.Layer):
"""Initilize MelGAN discriminator module. """Initilize MelGAN discriminator module.
Args: Args:
in_channels (int): Number of input channels. in_channels (int):
out_channels (int): Number of output channels. Number of input channels.
out_channels (int):
Number of output channels.
kernel_sizes (List[int]): List of two kernel sizes. The prod will be used for the first conv layer, kernel_sizes (List[int]): List of two kernel sizes. The prod will be used for the first conv layer,
and the first and the second kernel sizes will be used for the last two layers. and the first and the second kernel sizes will be used for the last two layers.
For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15, For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15,
the last two layers' kernel size will be 5 and 3, respectively. the last two layers' kernel size will be 5 and 3, respectively.
channels (int): Initial number of channels for conv layer. channels (int):
max_downsample_channels (int): Maximum number of channels for downsampling layers. Initial number of channels for conv layer.
bias (bool): Whether to add bias parameter in convolution layers. max_downsample_channels (int):
downsample_scales (List[int]): List of downsampling scales. Maximum number of channels for downsampling layers.
nonlinear_activation (str): Activation function module name. bias (bool):
nonlinear_activation_params (dict): Hyperparameters for activation function. Whether to add bias parameter in convolution layers.
pad (str): Padding function module name before dilated convolution layer. downsample_scales (List[int]):
pad_params (dict): Hyperparameters for padding function. List of downsampling scales.
nonlinear_activation (str):
Activation function module name.
nonlinear_activation_params (dict):
Hyperparameters for activation function.
pad (str):
Padding function module name before dilated convolution layer.
pad_params (dict):
Hyperparameters for padding function.
""" """
super().__init__() super().__init__()
...@@ -364,7 +390,8 @@ class MelGANDiscriminator(nn.Layer): ...@@ -364,7 +390,8 @@ class MelGANDiscriminator(nn.Layer):
def forward(self, x): def forward(self, x):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
x (Tensor): Input noise signal (B, 1, T). x (Tensor):
Input noise signal (B, 1, T).
Returns: Returns:
List: List of output tensors of each layer (for feat_match_loss). List: List of output tensors of each layer (for feat_match_loss).
""" """
...@@ -406,22 +433,37 @@ class MelGANMultiScaleDiscriminator(nn.Layer): ...@@ -406,22 +433,37 @@ class MelGANMultiScaleDiscriminator(nn.Layer):
"""Initilize MelGAN multi-scale discriminator module. """Initilize MelGAN multi-scale discriminator module.
Args: Args:
in_channels (int): Number of input channels. in_channels (int):
out_channels (int): Number of output channels. Number of input channels.
scales (int): Number of multi-scales. out_channels (int):
downsample_pooling (str): Pooling module name for downsampling of the inputs. Number of output channels.
downsample_pooling_params (dict): Parameters for the above pooling module. scales (int):
kernel_sizes (List[int]): List of two kernel sizes. The sum will be used for the first conv layer, Number of multi-scales.
downsample_pooling (str):
Pooling module name for downsampling of the inputs.
downsample_pooling_params (dict):
Parameters for the above pooling module.
kernel_sizes (List[int]):
List of two kernel sizes. The sum will be used for the first conv layer,
and the first and the second kernel sizes will be used for the last two layers. and the first and the second kernel sizes will be used for the last two layers.
channels (int): Initial number of channels for conv layer. channels (int):
max_downsample_channels (int): Maximum number of channels for downsampling layers. Initial number of channels for conv layer.
bias (bool): Whether to add bias parameter in convolution layers. max_downsample_channels (int):
downsample_scales (List[int]): List of downsampling scales. Maximum number of channels for downsampling layers.
nonlinear_activation (str): Activation function module name. bias (bool):
nonlinear_activation_params (dict): Hyperparameters for activation function. Whether to add bias parameter in convolution layers.
pad (str): Padding function module name before dilated convolution layer. downsample_scales (List[int]):
pad_params (dict): Hyperparameters for padding function. List of downsampling scales.
use_causal_conv (bool): Whether to use causal convolution. nonlinear_activation (str):
Activation function module name.
nonlinear_activation_params (dict):
Hyperparameters for activation function.
pad (str):
Padding function module name before dilated convolution layer.
pad_params (dict):
Hyperparameters for padding function.
use_causal_conv (bool):
Whether to use causal convolution.
""" """
super().__init__() super().__init__()
...@@ -464,7 +506,8 @@ class MelGANMultiScaleDiscriminator(nn.Layer): ...@@ -464,7 +506,8 @@ class MelGANMultiScaleDiscriminator(nn.Layer):
def forward(self, x): def forward(self, x):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
x (Tensor): Input noise signal (B, 1, T). x (Tensor):
Input noise signal (B, 1, T).
Returns: Returns:
List: List of list of each discriminator outputs, which consists of each layer output tensors. List: List of list of each discriminator outputs, which consists of each layer output tensors.
""" """
......
...@@ -54,20 +54,34 @@ class StyleMelGANGenerator(nn.Layer): ...@@ -54,20 +54,34 @@ class StyleMelGANGenerator(nn.Layer):
"""Initilize Style MelGAN generator. """Initilize Style MelGAN generator.
Args: Args:
in_channels (int): Number of input noise channels. in_channels (int):
aux_channels (int): Number of auxiliary input channels. Number of input noise channels.
channels (int): Number of channels for conv layer. aux_channels (int):
out_channels (int): Number of output channels. Number of auxiliary input channels.
kernel_size (int): Kernel size of conv layers. channels (int):
dilation (int): Dilation factor for conv layers. Number of channels for conv layer.
bias (bool): Whether to add bias parameter in convolution layers. out_channels (int):
noise_upsample_scales (list): List of noise upsampling scales. Number of output channels.
noise_upsample_activation (str): Activation function module name for noise upsampling. kernel_size (int):
noise_upsample_activation_params (dict): Hyperparameters for the above activation function. Kernel size of conv layers.
upsample_scales (list): List of upsampling scales. dilation (int):
upsample_mode (str): Upsampling mode in TADE layer. Dilation factor for conv layers.
gated_function (str): Gated function in TADEResBlock ("softmax" or "sigmoid"). bias (bool):
use_weight_norm (bool): Whether to use weight norm. Whether to add bias parameter in convolution layers.
noise_upsample_scales (list):
List of noise upsampling scales.
noise_upsample_activation (str):
Activation function module name for noise upsampling.
noise_upsample_activation_params (dict):
Hyperparameters for the above activation function.
upsample_scales (list):
List of upsampling scales.
upsample_mode (str):
Upsampling mode in TADE layer.
gated_function (str):
Gated function in TADEResBlock ("softmax" or "sigmoid").
use_weight_norm (bool):
Whether to use weight norm.
If set to true, it will be applied to all of the conv layers. If set to true, it will be applied to all of the conv layers.
""" """
super().__init__() super().__init__()
...@@ -194,7 +208,8 @@ class StyleMelGANGenerator(nn.Layer): ...@@ -194,7 +208,8 @@ class StyleMelGANGenerator(nn.Layer):
def inference(self, c): def inference(self, c):
"""Perform inference. """Perform inference.
Args: Args:
c (Tensor): Input tensor (T, in_channels). c (Tensor):
Input tensor (T, in_channels).
Returns: Returns:
Tensor: Output tensor (T ** prod(upsample_scales), out_channels). Tensor: Output tensor (T ** prod(upsample_scales), out_channels).
""" """
...@@ -258,11 +273,16 @@ class StyleMelGANDiscriminator(nn.Layer): ...@@ -258,11 +273,16 @@ class StyleMelGANDiscriminator(nn.Layer):
"""Initilize Style MelGAN discriminator. """Initilize Style MelGAN discriminator.
Args: Args:
repeats (int): Number of repititons to apply RWD. repeats (int):
window_sizes (list): List of random window sizes. Number of repititons to apply RWD.
pqmf_params (list): List of list of Parameters for PQMF modules window_sizes (list):
discriminator_params (dict): Parameters for base discriminator module. List of random window sizes.
use_weight_nom (bool): Whether to apply weight normalization. pqmf_params (list):
List of list of Parameters for PQMF modules
discriminator_params (dict):
Parameters for base discriminator module.
use_weight_nom (bool):
Whether to apply weight normalization.
""" """
super().__init__() super().__init__()
...@@ -299,7 +319,8 @@ class StyleMelGANDiscriminator(nn.Layer): ...@@ -299,7 +319,8 @@ class StyleMelGANDiscriminator(nn.Layer):
def forward(self, x): def forward(self, x):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
x (Tensor): Input tensor (B, 1, T). x (Tensor):
Input tensor (B, 1, T).
Returns: Returns:
List: List of discriminator outputs, #items in the list will be List: List of discriminator outputs, #items in the list will be
equal to repeats * #discriminators. equal to repeats * #discriminators.
......
...@@ -32,29 +32,45 @@ class PWGGenerator(nn.Layer): ...@@ -32,29 +32,45 @@ class PWGGenerator(nn.Layer):
"""Wave Generator for Parallel WaveGAN """Wave Generator for Parallel WaveGAN
Args: Args:
in_channels (int, optional): Number of channels of the input waveform, by default 1 in_channels (int, optional):
out_channels (int, optional): Number of channels of the output waveform, by default 1 Number of channels of the input waveform, by default 1
kernel_size (int, optional): Kernel size of the residual blocks inside, by default 3 out_channels (int, optional):
layers (int, optional): Number of residual blocks inside, by default 30 Number of channels of the output waveform, by default 1
stacks (int, optional): The number of groups to split the residual blocks into, by default 3 kernel_size (int, optional):
Kernel size of the residual blocks inside, by default 3
layers (int, optional):
Number of residual blocks inside, by default 30
stacks (int, optional):
The number of groups to split the residual blocks into, by default 3
Within each group, the dilation of the residual block grows exponentially. Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional): Residual channel of the residual blocks, by default 64 residual_channels (int, optional):
gate_channels (int, optional): Gate channel of the residual blocks, by default 128 Residual channel of the residual blocks, by default 64
skip_channels (int, optional): Skip channel of the residual blocks, by default 64 gate_channels (int, optional):
aux_channels (int, optional): Auxiliary channel of the residual blocks, by default 80 Gate channel of the residual blocks, by default 128
aux_context_window (int, optional): The context window size of the first convolution applied to the skip_channels (int, optional):
auxiliary input, by default 2 Skip channel of the residual blocks, by default 64
dropout (float, optional): Dropout of the residual blocks, by default 0. aux_channels (int, optional):
bias (bool, optional): Whether to use bias in residual blocks, by default True Auxiliary channel of the residual blocks, by default 80
use_weight_norm (bool, optional): Whether to use weight norm in all convolutions, by default True aux_context_window (int, optional):
use_causal_conv (bool, optional): Whether to use causal padding in the upsample network and residual The context window size of the first convolution applied to the auxiliary input, by default 2
blocks, by default False dropout (float, optional):
upsample_scales (List[int], optional): Upsample scales of the upsample network, by default [4, 4, 4, 4] Dropout of the residual blocks, by default 0.
nonlinear_activation (Optional[str], optional): Non linear activation in upsample network, by default None bias (bool, optional):
nonlinear_activation_params (Dict[str, Any], optional): Parameters passed to the linear activation in the upsample network, Whether to use bias in residual blocks, by default True
by default {} use_weight_norm (bool, optional):
interpolate_mode (str, optional): Interpolation mode of the upsample network, by default "nearest" Whether to use weight norm in all convolutions, by default True
freq_axis_kernel_size (int, optional): Kernel size along the frequency axis of the upsample network, by default 1 use_causal_conv (bool, optional):
Whether to use causal padding in the upsample network and residual blocks, by default False
upsample_scales (List[int], optional):
Upsample scales of the upsample network, by default [4, 4, 4, 4]
nonlinear_activation (Optional[str], optional):
Non linear activation in upsample network, by default None
nonlinear_activation_params (Dict[str, Any], optional):
Parameters passed to the linear activation in the upsample network, by default {}
interpolate_mode (str, optional):
Interpolation mode of the upsample network, by default "nearest"
freq_axis_kernel_size (int, optional):
Kernel size along the frequency axis of the upsample network, by default 1
""" """
def __init__( def __init__(
...@@ -147,9 +163,11 @@ class PWGGenerator(nn.Layer): ...@@ -147,9 +163,11 @@ class PWGGenerator(nn.Layer):
"""Generate waveform. """Generate waveform.
Args: Args:
x(Tensor): Shape (N, C_in, T), The input waveform. x(Tensor):
c(Tensor): Shape (N, C_aux, T'). The auxiliary input (e.g. spectrogram). It Shape (N, C_in, T), The input waveform.
is upsampled to match the time resolution of the input. c(Tensor):
Shape (N, C_aux, T'). The auxiliary input (e.g. spectrogram).
It is upsampled to match the time resolution of the input.
Returns: Returns:
Tensor: Shape (N, C_out, T), the generated waveform. Tensor: Shape (N, C_out, T), the generated waveform.
...@@ -195,8 +213,10 @@ class PWGGenerator(nn.Layer): ...@@ -195,8 +213,10 @@ class PWGGenerator(nn.Layer):
"""Waveform generation. This function is used for single instance inference. """Waveform generation. This function is used for single instance inference.
Args: Args:
c(Tensor, optional, optional): Shape (T', C_aux), the auxiliary input, by default None c(Tensor, optional, optional):
x(Tensor, optional): Shape (T, C_in), the noise waveform, by default None Shape (T', C_aux), the auxiliary input, by default None
x(Tensor, optional):
Shape (T, C_in), the noise waveform, by default None
Returns: Returns:
Tensor: Shape (T, C_out), the generated waveform Tensor: Shape (T, C_out), the generated waveform
...@@ -214,20 +234,28 @@ class PWGDiscriminator(nn.Layer): ...@@ -214,20 +234,28 @@ class PWGDiscriminator(nn.Layer):
"""A convolutional discriminator for audio. """A convolutional discriminator for audio.
Args: Args:
in_channels (int, optional): Number of channels of the input audio, by default 1 in_channels (int, optional):
out_channels (int, optional): Output feature size, by default 1 Number of channels of the input audio, by default 1
kernel_size (int, optional): Kernel size of convolutional sublayers, by default 3 out_channels (int, optional):
layers (int, optional): Number of layers, by default 10 Output feature size, by default 1
conv_channels (int, optional): Feature size of the convolutional sublayers, by default 64 kernel_size (int, optional):
dilation_factor (int, optional): The factor with which dilation of each convolutional sublayers grows Kernel size of convolutional sublayers, by default 3
layers (int, optional):
Number of layers, by default 10
conv_channels (int, optional):
Feature size of the convolutional sublayers, by default 64
dilation_factor (int, optional):
The factor with which dilation of each convolutional sublayers grows
exponentially if it is greater than 1, else the dilation of each convolutional sublayers grows linearly, exponentially if it is greater than 1, else the dilation of each convolutional sublayers grows linearly,
by default 1 by default 1
nonlinear_activation (str, optional): The activation after each convolutional sublayer, by default "leakyrelu" nonlinear_activation (str, optional):
nonlinear_activation_params (Dict[str, Any], optional): The parameters passed to the activation's initializer, by default The activation after each convolutional sublayer, by default "leakyrelu"
{"negative_slope": 0.2} nonlinear_activation_params (Dict[str, Any], optional):
bias (bool, optional): Whether to use bias in convolutional sublayers, by default True The parameters passed to the activation's initializer, by default {"negative_slope": 0.2}
use_weight_norm (bool, optional): Whether to use weight normalization at all convolutional sublayers, bias (bool, optional):
by default True Whether to use bias in convolutional sublayers, by default True
use_weight_norm (bool, optional):
Whether to use weight normalization at all convolutional sublayers, by default True
""" """
def __init__( def __init__(
...@@ -290,7 +318,8 @@ class PWGDiscriminator(nn.Layer): ...@@ -290,7 +318,8 @@ class PWGDiscriminator(nn.Layer):
""" """
Args: Args:
x (Tensor): Shape (N, in_channels, num_samples), the input audio. x (Tensor):
Shape (N, in_channels, num_samples), the input audio.
Returns: Returns:
Tensor: Shape (N, out_channels, num_samples), the predicted logits. Tensor: Shape (N, out_channels, num_samples), the predicted logits.
...@@ -318,24 +347,35 @@ class ResidualPWGDiscriminator(nn.Layer): ...@@ -318,24 +347,35 @@ class ResidualPWGDiscriminator(nn.Layer):
"""A wavenet-style discriminator for audio. """A wavenet-style discriminator for audio.
Args: Args:
in_channels (int, optional): Number of channels of the input audio, by default 1 in_channels (int, optional):
out_channels (int, optional): Output feature size, by default 1 Number of channels of the input audio, by default 1
kernel_size (int, optional): Kernel size of residual blocks, by default 3 out_channels (int, optional):
layers (int, optional): Number of residual blocks, by default 30 Output feature size, by default 1
stacks (int, optional): Number of groups of residual blocks, within which the dilation kernel_size (int, optional):
Kernel size of residual blocks, by default 3
layers (int, optional):
Number of residual blocks, by default 30
stacks (int, optional):
Number of groups of residual blocks, within which the dilation
of each residual blocks grows exponentially, by default 3 of each residual blocks grows exponentially, by default 3
residual_channels (int, optional): Residual channels of residual blocks, by default 64 residual_channels (int, optional):
gate_channels (int, optional): Gate channels of residual blocks, by default 128 Residual channels of residual blocks, by default 64
skip_channels (int, optional): Skip channels of residual blocks, by default 64 gate_channels (int, optional):
dropout (float, optional): Dropout probability of residual blocks, by default 0. Gate channels of residual blocks, by default 128
bias (bool, optional): Whether to use bias in residual blocks, by default True skip_channels (int, optional):
use_weight_norm (bool, optional): Whether to use weight normalization in all convolutional layers, Skip channels of residual blocks, by default 64
by default True dropout (float, optional):
use_causal_conv (bool, optional): Whether to use causal convolution in residual blocks, by default False Dropout probability of residual blocks, by default 0.
nonlinear_activation (str, optional): Activation after convolutions other than those in residual blocks, bias (bool, optional):
by default "leakyrelu" Whether to use bias in residual blocks, by default True
nonlinear_activation_params (Dict[str, Any], optional): Parameters to pass to the activation, use_weight_norm (bool, optional):
by default {"negative_slope": 0.2} Whether to use weight normalization in all convolutional layers, by default True
use_causal_conv (bool, optional):
Whether to use causal convolution in residual blocks, by default False
nonlinear_activation (str, optional):
Activation after convolutions other than those in residual blocks, by default "leakyrelu"
nonlinear_activation_params (Dict[str, Any], optional):
Parameters to pass to the activation, by default {"negative_slope": 0.2}
""" """
def __init__( def __init__(
...@@ -405,7 +445,8 @@ class ResidualPWGDiscriminator(nn.Layer): ...@@ -405,7 +445,8 @@ class ResidualPWGDiscriminator(nn.Layer):
def forward(self, x): def forward(self, x):
""" """
Args: Args:
x(Tensor): Shape (N, in_channels, num_samples), the input audio.↩ x(Tensor):
Shape (N, in_channels, num_samples), the input audio.↩
Returns: Returns:
Tensor: Shape (N, out_channels, num_samples), the predicted logits. Tensor: Shape (N, out_channels, num_samples), the predicted logits.
......
...@@ -29,10 +29,14 @@ class ResidualBlock(nn.Layer): ...@@ -29,10 +29,14 @@ class ResidualBlock(nn.Layer):
n: int=2): n: int=2):
"""SpeedySpeech encoder module. """SpeedySpeech encoder module.
Args: Args:
channels (int, optional): Feature size of the residual output(and also the input). channels (int, optional):
kernel_size (int, optional): Kernel size of the 1D convolution. Feature size of the residual output(and also the input).
dilation (int, optional): Dilation of the 1D convolution. kernel_size (int, optional):
n (int): Number of blocks. Kernel size of the 1D convolution.
dilation (int, optional):
Dilation of the 1D convolution.
n (int):
Number of blocks.
""" """
super().__init__() super().__init__()
...@@ -57,7 +61,8 @@ class ResidualBlock(nn.Layer): ...@@ -57,7 +61,8 @@ class ResidualBlock(nn.Layer):
def forward(self, x: paddle.Tensor): def forward(self, x: paddle.Tensor):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
x(Tensor): Batch of input sequences (B, hidden_size, Tmax). x(Tensor):
Batch of input sequences (B, hidden_size, Tmax).
Returns: Returns:
Tensor: The residual output (B, hidden_size, Tmax). Tensor: The residual output (B, hidden_size, Tmax).
""" """
...@@ -89,8 +94,10 @@ class TextEmbedding(nn.Layer): ...@@ -89,8 +94,10 @@ class TextEmbedding(nn.Layer):
def forward(self, text: paddle.Tensor, tone: paddle.Tensor=None): def forward(self, text: paddle.Tensor, tone: paddle.Tensor=None):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
text(Tensor(int64)): Batch of padded token ids (B, Tmax). text(Tensor(int64)):
tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax). Batch of padded token ids (B, Tmax).
tones(Tensor, optional(int64)):
Batch of padded tone ids (B, Tmax).
Returns: Returns:
Tensor: The residual output (B, Tmax, embedding_size). Tensor: The residual output (B, Tmax, embedding_size).
""" """
...@@ -109,12 +116,18 @@ class TextEmbedding(nn.Layer): ...@@ -109,12 +116,18 @@ class TextEmbedding(nn.Layer):
class SpeedySpeechEncoder(nn.Layer): class SpeedySpeechEncoder(nn.Layer):
"""SpeedySpeech encoder module. """SpeedySpeech encoder module.
Args: Args:
vocab_size (int): Dimension of the inputs. vocab_size (int):
tone_size (Optional[int]): Number of tones. Dimension of the inputs.
hidden_size (int): Number of encoder hidden units. tone_size (Optional[int]):
kernel_size (int): Kernel size of encoder. Number of tones.
dilations (List[int]): Dilations of encoder. hidden_size (int):
spk_num (Optional[int]): Number of speakers. Number of encoder hidden units.
kernel_size (int):
Kernel size of encoder.
dilations (List[int]):
Dilations of encoder.
spk_num (Optional[int]):
Number of speakers.
""" """
def __init__(self, def __init__(self,
...@@ -161,9 +174,12 @@ class SpeedySpeechEncoder(nn.Layer): ...@@ -161,9 +174,12 @@ class SpeedySpeechEncoder(nn.Layer):
spk_id: paddle.Tensor=None): spk_id: paddle.Tensor=None):
"""Encoder input sequence. """Encoder input sequence.
Args: Args:
text(Tensor(int64)): Batch of padded token ids (B, Tmax). text(Tensor(int64)):
tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax). Batch of padded token ids (B, Tmax).
spk_id(Tnesor, optional(int64)): Batch of speaker ids (B,) tones(Tensor, optional(int64)):
Batch of padded tone ids (B, Tmax).
spk_id(Tnesor, optional(int64)):
Batch of speaker ids (B,)
Returns: Returns:
Tensor: Output tensor (B, Tmax, hidden_size). Tensor: Output tensor (B, Tmax, hidden_size).
...@@ -192,7 +208,8 @@ class DurationPredictor(nn.Layer): ...@@ -192,7 +208,8 @@ class DurationPredictor(nn.Layer):
def forward(self, x: paddle.Tensor): def forward(self, x: paddle.Tensor):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
x(Tensor): Batch of input sequences (B, Tmax, hidden_size). x(Tensor):
Batch of input sequences (B, Tmax, hidden_size).
Returns: Returns:
Tensor: Batch of predicted durations in log domain (B, Tmax). Tensor: Batch of predicted durations in log domain (B, Tmax).
...@@ -212,10 +229,14 @@ class SpeedySpeechDecoder(nn.Layer): ...@@ -212,10 +229,14 @@ class SpeedySpeechDecoder(nn.Layer):
]): ]):
"""SpeedySpeech decoder module. """SpeedySpeech decoder module.
Args: Args:
hidden_size (int): Number of decoder hidden units. hidden_size (int):
kernel_size (int): Kernel size of decoder. Number of decoder hidden units.
output_size (int): Dimension of the outputs. kernel_size (int):
dilations (List[int]): Dilations of decoder. Kernel size of decoder.
output_size (int):
Dimension of the outputs.
dilations (List[int]):
Dilations of decoder.
""" """
super().__init__() super().__init__()
res_blocks = [ res_blocks = [
...@@ -230,7 +251,8 @@ class SpeedySpeechDecoder(nn.Layer): ...@@ -230,7 +251,8 @@ class SpeedySpeechDecoder(nn.Layer):
def forward(self, x): def forward(self, x):
"""Decoder input sequence. """Decoder input sequence.
Args: Args:
x(Tensor): Input tensor (B, time, hidden_size). x(Tensor):
Input tensor (B, time, hidden_size).
Returns: Returns:
Tensor: Output tensor (B, time, output_size). Tensor: Output tensor (B, time, output_size).
...@@ -261,18 +283,30 @@ class SpeedySpeech(nn.Layer): ...@@ -261,18 +283,30 @@ class SpeedySpeech(nn.Layer):
positional_dropout_rate: int=0.1): positional_dropout_rate: int=0.1):
"""Initialize SpeedySpeech module. """Initialize SpeedySpeech module.
Args: Args:
vocab_size (int): Dimension of the inputs. vocab_size (int):
encoder_hidden_size (int): Number of encoder hidden units. Dimension of the inputs.
encoder_kernel_size (int): Kernel size of encoder. encoder_hidden_size (int):
encoder_dilations (List[int]): Dilations of encoder. Number of encoder hidden units.
duration_predictor_hidden_size (int): Number of duration predictor hidden units. encoder_kernel_size (int):
decoder_hidden_size (int): Number of decoder hidden units. Kernel size of encoder.
decoder_kernel_size (int): Kernel size of decoder. encoder_dilations (List[int]):
decoder_dilations (List[int]): Dilations of decoder. Dilations of encoder.
decoder_output_size (int): Dimension of the outputs. duration_predictor_hidden_size (int):
tone_size (Optional[int]): Number of tones. Number of duration predictor hidden units.
spk_num (Optional[int]): Number of speakers. decoder_hidden_size (int):
init_type (str): How to initialize transformer parameters. Number of decoder hidden units.
decoder_kernel_size (int):
Kernel size of decoder.
decoder_dilations (List[int]):
Dilations of decoder.
decoder_output_size (int):
Dimension of the outputs.
tone_size (Optional[int]):
Number of tones.
spk_num (Optional[int]):
Number of speakers.
init_type (str):
How to initialize transformer parameters.
""" """
super().__init__() super().__init__()
...@@ -304,14 +338,20 @@ class SpeedySpeech(nn.Layer): ...@@ -304,14 +338,20 @@ class SpeedySpeech(nn.Layer):
spk_id: paddle.Tensor=None): spk_id: paddle.Tensor=None):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
text(Tensor(int64)): Batch of padded token ids (B, Tmax). text(Tensor(int64)):
durations(Tensor(int64)): Batch of padded durations (B, Tmax). Batch of padded token ids (B, Tmax).
tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax). durations(Tensor(int64)):
spk_id(Tnesor, optional(int64)): Batch of speaker ids (B,) Batch of padded durations (B, Tmax).
tones(Tensor, optional(int64)):
Batch of padded tone ids (B, Tmax).
spk_id(Tnesor, optional(int64)):
Batch of speaker ids (B,)
Returns: Returns:
Tensor: Output tensor (B, T_frames, decoder_output_size). Tensor:
Tensor: Predicted durations (B, Tmax). Output tensor (B, T_frames, decoder_output_size).
Tensor:
Predicted durations (B, Tmax).
""" """
# input of embedding must be int64 # input of embedding must be int64
text = paddle.cast(text, 'int64') text = paddle.cast(text, 'int64')
...@@ -336,10 +376,14 @@ class SpeedySpeech(nn.Layer): ...@@ -336,10 +376,14 @@ class SpeedySpeech(nn.Layer):
spk_id: paddle.Tensor=None): spk_id: paddle.Tensor=None):
"""Generate the sequence of features given the sequences of characters. """Generate the sequence of features given the sequences of characters.
Args: Args:
text(Tensor(int64)): Input sequence of characters (T,). text(Tensor(int64)):
tones(Tensor, optional(int64)): Batch of padded tone ids (T, ). Input sequence of characters (T,).
durations(Tensor, optional (int64)): Groundtruth of duration (T,). tones(Tensor, optional(int64)):
spk_id(Tensor, optional(int64), optional): spk ids (1,). (Default value = None) Batch of padded tone ids (T, ).
durations(Tensor, optional (int64)):
Groundtruth of duration (T,).
spk_id(Tensor, optional(int64), optional):
spk ids (1,). (Default value = None)
Returns: Returns:
Tensor: logmel (T, decoder_output_size). Tensor: logmel (T, decoder_output_size).
......
...@@ -83,38 +83,67 @@ class Tacotron2(nn.Layer): ...@@ -83,38 +83,67 @@ class Tacotron2(nn.Layer):
init_type: str="xavier_uniform", ): init_type: str="xavier_uniform", ):
"""Initialize Tacotron2 module. """Initialize Tacotron2 module.
Args: Args:
idim (int): Dimension of the inputs. idim (int):
odim (int): Dimension of the outputs. Dimension of the inputs.
embed_dim (int): Dimension of the token embedding. odim (int):
elayers (int): Number of encoder blstm layers. Dimension of the outputs.
eunits (int): Number of encoder blstm units. embed_dim (int):
econv_layers (int): Number of encoder conv layers. Dimension of the token embedding.
econv_filts (int): Number of encoder conv filter size. elayers (int):
econv_chans (int): Number of encoder conv filter channels. Number of encoder blstm layers.
dlayers (int): Number of decoder lstm layers. eunits (int):
dunits (int): Number of decoder lstm units. Number of encoder blstm units.
prenet_layers (int): Number of prenet layers. econv_layers (int):
prenet_units (int): Number of prenet units. Number of encoder conv layers.
postnet_layers (int): Number of postnet layers. econv_filts (int):
postnet_filts (int): Number of postnet filter size. Number of encoder conv filter size.
postnet_chans (int): Number of postnet filter channels. econv_chans (int):
output_activation (str): Name of activation function for outputs. Number of encoder conv filter channels.
adim (int): Number of dimension of mlp in attention. dlayers (int):
aconv_chans (int): Number of attention conv filter channels. Number of decoder lstm layers.
aconv_filts (int): Number of attention conv filter size. dunits (int):
cumulate_att_w (bool): Whether to cumulate previous attention weight. Number of decoder lstm units.
use_batch_norm (bool): Whether to use batch normalization. prenet_layers (int):
use_concate (bool): Whether to concat enc outputs w/ dec lstm outputs. Number of prenet layers.
reduction_factor (int): Reduction factor. prenet_units (int):
spk_num (Optional[int]): Number of speakers. If set to > 1, assume that the Number of prenet units.
postnet_layers (int):
Number of postnet layers.
postnet_filts (int):
Number of postnet filter size.
postnet_chans (int):
Number of postnet filter channels.
output_activation (str):
Name of activation function for outputs.
adim (int):
Number of dimension of mlp in attention.
aconv_chans (int):
Number of attention conv filter channels.
aconv_filts (int):
Number of attention conv filter size.
cumulate_att_w (bool):
Whether to cumulate previous attention weight.
use_batch_norm (bool):
Whether to use batch normalization.
use_concate (bool):
Whether to concat enc outputs w/ dec lstm outputs.
reduction_factor (int):
Reduction factor.
spk_num (Optional[int]):
Number of speakers. If set to > 1, assume that the
sids will be provided as the input and use sid embedding layer. sids will be provided as the input and use sid embedding layer.
lang_num (Optional[int]): Number of languages. If set to > 1, assume that the lang_num (Optional[int]):
Number of languages. If set to > 1, assume that the
lids will be provided as the input and use sid embedding layer. lids will be provided as the input and use sid embedding layer.
spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, spk_embed_dim (Optional[int]):
Speaker embedding dimension. If set to > 0,
assume that spk_emb will be provided as the input. assume that spk_emb will be provided as the input.
spk_embed_integration_type (str): How to integrate speaker embedding. spk_embed_integration_type (str):
dropout_rate (float): Dropout rate. How to integrate speaker embedding.
zoneout_rate (float): Zoneout rate. dropout_rate (float):
Dropout rate.
zoneout_rate (float):
Zoneout rate.
""" """
assert check_argument_types() assert check_argument_types()
super().__init__() super().__init__()
...@@ -230,18 +259,28 @@ class Tacotron2(nn.Layer): ...@@ -230,18 +259,28 @@ class Tacotron2(nn.Layer):
"""Calculate forward propagation. """Calculate forward propagation.
Args: Args:
text (Tensor(int64)): Batch of padded character ids (B, T_text). text (Tensor(int64)):
text_lengths (Tensor(int64)): Batch of lengths of each input batch (B,). Batch of padded character ids (B, T_text).
speech (Tensor): Batch of padded target features (B, T_feats, odim). text_lengths (Tensor(int64)):
speech_lengths (Tensor(int64)): Batch of the lengths of each target (B,). Batch of lengths of each input batch (B,).
spk_emb (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim). speech (Tensor):
spk_id (Optional[Tensor]): Batch of speaker IDs (B, 1). Batch of padded target features (B, T_feats, odim).
lang_id (Optional[Tensor]): Batch of language IDs (B, 1). speech_lengths (Tensor(int64)):
Batch of the lengths of each target (B,).
spk_emb (Optional[Tensor]):
Batch of speaker embeddings (B, spk_embed_dim).
spk_id (Optional[Tensor]):
Batch of speaker IDs (B, 1).
lang_id (Optional[Tensor]):
Batch of language IDs (B, 1).
Returns: Returns:
Tensor: Loss scalar value. Tensor:
Dict: Statistics to be monitored. Loss scalar value.
Tensor: Weight value if not joint training else model outputs. Dict:
Statistics to be monitored.
Tensor:
Weight value if not joint training else model outputs.
""" """
text = text[:, :text_lengths.max()] text = text[:, :text_lengths.max()]
...@@ -329,18 +368,30 @@ class Tacotron2(nn.Layer): ...@@ -329,18 +368,30 @@ class Tacotron2(nn.Layer):
"""Generate the sequence of features given the sequences of characters. """Generate the sequence of features given the sequences of characters.
Args: Args:
text (Tensor(int64)): Input sequence of characters (T_text,). text (Tensor(int64)):
speech (Optional[Tensor]): Feature sequence to extract style (N, idim). Input sequence of characters (T_text,).
spk_emb (ptional[Tensor]): Speaker embedding (spk_embed_dim,). speech (Optional[Tensor]):
spk_id (Optional[Tensor]): Speaker ID (1,). Feature sequence to extract style (N, idim).
lang_id (Optional[Tensor]): Language ID (1,). spk_emb (ptional[Tensor]):
threshold (float): Threshold in inference. Speaker embedding (spk_embed_dim,).
minlenratio (float): Minimum length ratio in inference. spk_id (Optional[Tensor]):
maxlenratio (float): Maximum length ratio in inference. Speaker ID (1,).
use_att_constraint (bool): Whether to apply attention constraint. lang_id (Optional[Tensor]):
backward_window (int): Backward window in attention constraint. Language ID (1,).
forward_window (int): Forward window in attention constraint. threshold (float):
use_teacher_forcing (bool): Whether to use teacher forcing. Threshold in inference.
minlenratio (float):
Minimum length ratio in inference.
maxlenratio (float):
Maximum length ratio in inference.
use_att_constraint (bool):
Whether to apply attention constraint.
backward_window (int):
Backward window in attention constraint.
forward_window (int):
Forward window in attention constraint.
use_teacher_forcing (bool):
Whether to use teacher forcing.
Returns: Returns:
Dict[str, Tensor] Dict[str, Tensor]
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册