Commit d66d6a05 authored by lym0302

Merge branch 'develop' of https://github.com/lym0302/PaddleSpeech into develop

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: zhendong.peng@mobvoi.com (Zhendong Peng)
import argparse

from flask import Flask
from flask import render_template

parser = argparse.ArgumentParser(description='PaddleSpeech serving web demo')
parser.add_argument('--port', default=19999, type=int, help='port id')
args = parser.parse_args()

app = Flask(__name__)


@app.route('/')
def index():
    return render_template('index.html')


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=args.port, debug=True)
# paddlespeech serving web demo
- Thanks to the [wenet](https://github.com/wenet-e2e/wenet) team for the front-end demo code.

![demo](./paddle_web_demo.png)

## Usage
### 1. Start the streaming ASR server
```
cd PaddleSpeech/demos/streaming_asr_server
paddlespeech_server start --config_file conf/ws_conformer_wenetspeech_application_faster.yaml
```
### 2. Serve the web page on your local machine
```
python app.py
```
Then open 127.0.0.1:19999 in a browser to see the demo page; alternatively, open `index.html` under the `web` directory directly in Chrome.
### 3. Run the demo
- Click `连接` (connect) to verify that the WebSocket connection succeeds.
- Click the record button (allow microphone access in the pop-up dialog) and speak; recognition results appear on the page.
/*
* @Author: baipengxia
* @Date: 2021-03-12 11:44:28
* @Last Modified by: baipengxia
* @Last Modified time: 2021-03-12 15:14:24
*/
/** COMMON RESET **/
* {
-webkit-tap-highlight-color: rgba(0, 0, 0, 0);
}
body,
h1,
h2,
h3,
h4,
h5,
h6,
hr,
p,
dl,
dt,
dd,
ul,
ol,
li,
fieldset,
legend,
button,
input,
textarea,
th,
td {
margin: 0;
padding: 0;
color: #000;
}
body {
font-size: 14px;
}
html, body {
min-width: 1200px;
}
button,
input,
select,
textarea {
font-size: 14px;
}
h1 {
font-size: 18px;
}
h2 {
font-size: 14px;
}
h3 {
font-size: 14px;
}
ul,
ol,
li {
list-style: none;
}
a {
text-decoration: none;
}
a:hover {
text-decoration: none;
}
fieldset,
img {
border: none;
}
table {
border-collapse: collapse;
border-spacing: 0;
}
i {
font-style: normal;
}
label {
position: inherit;
}
.clearfix:after {
content: ".";
display: block;
height: 0;
clear: both;
visibility: hidden;
}
.clearfix {
zoom: 1;
display: block;
}
html,
body {
font-family: Tahoma, Arial, 'microsoft yahei', 'Roboto', 'Droid Sans', 'Helvetica Neue', 'Droid Sans Fallback', 'Heiti SC', 'Hiragino Sans GB', 'Simsun', sans-serif;
}
.audio-banner {
width: 100%;
overflow: auto;
padding: 0;
background: url('../image/voice-dictation.svg');
background-size: cover;
}
.weaper {
width: 1200px;
height: 155px;
margin: 72px auto;
}
.text-content {
width: 670px;
height: 100%;
float: left;
}
.text-content .title {
font-size: 34px;
font-family: 'PingFangSC-Medium';
font-weight: 500;
color: rgba(255, 255, 255, 1);
line-height: 48px;
}
.text-content .con {
font-size: 16px;
font-family: PingFangSC-Light;
font-weight: 300;
color: rgba(255, 255, 255, 1);
line-height: 30px;
}
.img-con {
width: 416px;
height: 100%;
float: right;
}
.img-con img {
width: 100%;
height: 100%;
}
.con-container {
margin-top: 34px;
}
.audio-advantage {
background: #f8f9fa;
}
.asr-advantage {
width: 1200px;
margin: 0 auto;
}
.asr-advantage h2 {
text-align: center;
font-size: 22px;
padding: 30px 0 0 0;
}
.asr-advantage > ul > li {
box-sizing: border-box;
padding: 0 16px;
width: 33%;
text-align: center;
margin-bottom: 35px;
}
.asr-advantage > ul > li .icons{
margin-top: 10px;
margin-bottom: 20px;
width: 42px;
height: 42px;
}
.service-item-content {
margin-top: 35px;
display: flex;
justify-content: center;
flex-wrap: wrap;
}
.service-item-content img {
width: 160px;
vertical-align: bottom;
}
.service-item-content > li {
box-sizing: border-box;
padding: 0 16px;
width: 33%;
text-align: center;
margin-bottom: 35px;
}
.service-item-content > li .service-item-content-title {
line-height: 1.5;
font-weight: 700;
margin-top: 10px;
}
.service-item-content > li .service-item-content-desc {
margin-top: 5px;
line-height: 1.8;
color: #657384;
}
.audio-scene-con {
width: 100%;
padding-bottom: 84px;
background: #fff;
}
.audio-scene {
overflow: auto;
width: 1200px;
background: #fff;
text-align: center;
padding: 0;
margin: 0 auto;
}
.audio-scene h2 {
padding: 30px 0 0 0;
font-size: 22px;
text-align: center;
}
.audio-experience {
width: 100%;
height: 538px;
background: #fff;
padding: 0;
margin: 0;
overflow: auto;
}
.asr-box {
width: 1200px;
height: 394px;
margin: 64px auto;
}
.asr-box h2 {
font-size: 22px;
text-align: center;
margin-bottom: 64px;
}
.voice-container {
position: relative;
width: 1200px;
height: 308px;
background: rgba(255, 255, 255, 1);
border-radius: 8px;
border: 1px solid rgba(225, 225, 225, 1);
}
.voice-container .voice {
height: 236px;
width: 100%;
border-radius: 8px;
}
.voice-container .voice textarea {
height: 100%;
width: 100%;
border: none;
outline: none;
border-radius: 8px;
padding: 25px;
font-size: 14px;
box-sizing: border-box;
resize: none;
}
.voice-input {
width: 100%;
height: 72px;
box-sizing: border-box;
padding-left: 35px;
background: rgba(242, 244, 245, 1);
border-radius: 8px;
line-height: 72px;
}
.voice-input .el-select {
width: 492px;
}
.start-voice {
display: inline-block;
margin-left: 10px;
}
.start-voice .time {
margin-right: 25px;
}
.asr-advantage > ul > li {
margin-bottom: 77px;
}
#msg {
width: 100%;
line-height: 40px;
font-size: 14px;
margin-left: 330px;
}
#captcha {
margin-left: 350px !important;
display: inline-block;
position: relative;
}
.black {
position: fixed;
width: 100%;
height: 100%;
z-index: 5;
background: rgba(0, 0, 0, 0.5);
top: 0;
left: 0;
}
.container {
position: fixed;
z-index: 6;
top: 25%;
left: 10%;
}
.audio-scene-con {
width: 100%;
padding-bottom: 84px;
background: #fff;
}
#sound {
color: #fff;
cursor: pointer;
background: #147ede;
padding: 10px;
margin-top: 30px;
margin-left: 135px;
width: 176px;
height: 30px !important;
text-align: center;
line-height: 30px !important;
border-radius: 10px;
}
.con-ten {
position: absolute;
width: 100%;
height: 100%;
z-index: 5;
background: #fff;
opacity: 0.5;
top: 0;
left: 0;
}
.websocket-url {
width: 320px;
height: 20px;
border: 1px solid #dcdfe6;
line-height: 20px;
padding: 10px;
border-radius: 4px;
}
.voice-btn {
color: #fff;
background-color: #409eff;
font-weight: 500;
padding: 12px 20px;
font-size: 14px;
border-radius: 4px;
border: 0;
cursor: pointer;
}
.voice-btn.end {
display: none;
}
.result-text {
background: #fff;
padding: 20px;
}
.voice-footer {
border-top: 1px solid #dddede;
background: #f7f9fa;
text-align: center;
margin-bottom: 8px;
color: #333;
font-size: 12px;
padding: 20px 0;
}
/** line animate **/
.time-box {
display: none;
margin-left: 10px;
width: 300px;
}
.total-time {
font-size: 14px;
color: #545454;
}
.voice-btn.end.show,
.time-box.show {
display: inline;
}
.start-taste-line {
margin-right: 20px;
display: inline-block;
}
.start-taste-line hr {
background-color: #187cff;
width: 3px;
height: 8px;
margin: 0 3px;
display: inline-block;
border: none;
}
.hr {
animation: note 0.2s ease-in-out;
animation-iteration-count: infinite;
animation-direction: alternate;
}
.hr-one {
animation-delay: -0.9s;
}
.hr-two {
animation-delay: -0.8s;
}
.hr-three {
animation-delay: -0.7s;
}
.hr-four {
animation-delay: -0.6s;
}
.hr-five {
animation-delay: -0.5s;
}
.hr-six {
animation-delay: -0.4s;
}
.hr-seven {
animation-delay: -0.3s;
}
.hr-eight {
animation-delay: -0.2s;
}
.hr-nine {
animation-delay: -0.1s;
}
@keyframes note {
from {
transform: scaleY(1);
}
to {
transform: scaleY(4);
}
}
SoundRecognizer = {
rec: null,
wave: null,
SampleRate: 16000,
testBitRate: 16,
isCloseRecorder: false,
SendInterval: 300,
realTimeSendTryType: 'pcm',
realTimeSendTryEncBusy: 0,
realTimeSendTryTime: 0,
realTimeSendTryNumber: 0,
transferUploadNumberMax: 0,
realTimeSendTryChunk: null,
soundType: "pcm",
init: function (config) {
this.soundType = config.soundType || 'pcm';
this.SampleRate = config.sampleRate || 16000;
this.recwaveElm = config.recwaveElm || '';
this.TransferUpload = config.translerCallBack || this.TransferProcess;
this.initRecorder();
},
RealTimeSendTryReset: function (type) {
this.realTimeSendTryType = type;
this.realTimeSendTryTime = 0;
},
RealTimeSendTry: function (rec, isClose) {
var that = this;
var t1 = Date.now(), endT = 0, recImpl = Recorder.prototype;
if (this.realTimeSendTryTime == 0) {
this.realTimeSendTryTime = t1;
this.realTimeSendTryEncBusy = 0;
this.realTimeSendTryNumber = 0;
this.transferUploadNumberMax = 0;
this.realTimeSendTryChunk = null;
}
if (!isClose && t1 - this.realTimeSendTryTime < this.SendInterval) {
return;//wait until the buffered audio spans SendInterval before transmitting
}
this.realTimeSendTryTime = t1;
var number = ++this.realTimeSendTryNumber;
//reuse the SampleData helper for continuous chunk processing; sample-rate conversion happens along the way
var chunk = Recorder.SampleData(rec.buffers, rec.srcSampleRate, this.SampleRate, this.realTimeSendTryChunk, { frameType: isClose ? "" : this.realTimeSendTryType });
//clear buffers that have been processed to free memory for long recordings; stop() must not be called at the very end, because the data has already been cleared
for (var i = this.realTimeSendTryChunk ? this.realTimeSendTryChunk.index : 0; i < chunk.index; i++) {
rec.buffers[i] = null;
}
this.realTimeSendTryChunk = chunk;
//no new data, or the final chunk is too small to mock-transcode
if (chunk.data.length == 0 || isClose && chunk.data.length < 2000) {
this.TransferUpload(number, null, 0, null, isClose);
return;
}
//handle back-pressure in the real-time encoding queue
if (!isClose) {
if (this.realTimeSendTryEncBusy >= 2) {
console.log("编码队列阻塞,已丢弃一帧", 1);
return;
}
}
this.realTimeSendTryEncBusy++;
//transcode to mp3/wav in real time via the mock method
var encStartTime = Date.now();
var recMock = Recorder({
type: this.realTimeSendTryType
, sampleRate: this.SampleRate //sample rate
, bitRate: this.testBitRate //bit rate
});
recMock.mock(chunk.data, chunk.sampleRate);
recMock.stop(function (blob, duration) {
that.realTimeSendTryEncBusy && (that.realTimeSendTryEncBusy--);
blob.encTime = Date.now() - encStartTime;
//push the transcoded chunk into the transfer pipeline
that.TransferUpload(number, blob, duration, recMock, isClose);
}, function (msg) {
that.realTimeSendTryEncBusy && (that.realTimeSendTryEncBusy--);
//transcode error? not expected to happen
console.log("unexpected transcoding error: " + msg, 1);
});
},
recordClose: function () {
var that = this; // capture: `this` inside the callback below is not SoundRecognizer
try {
that.rec.close(function () {
that.isCloseRecorder = true;
});
that.RealTimeSendTry(that.rec, true);//final flush
} catch (ex) {
// recordClose();
}
},
recordEnd: function () {
var that = this; // capture: `this` inside the callbacks below is not SoundRecognizer
try {
that.rec.stop(function (blob, time) {
that.recordClose();
}, function (s) {
that.recordClose();
});
} catch (ex) {
}
},
initRecorder: function () {
var that = this;
var rec = Recorder({
type: that.soundType
, bitRate: that.testBitRate
, sampleRate: that.SampleRate
, onProcess: function (buffers, level, time, sampleRate) {
that.wave.input(buffers[buffers.length - 1], level, sampleRate);
that.RealTimeSendTry(rec, false);//push into real-time processing; since the format is unknown the call is simplified, and buffers/bufferSampleRate are not used because they are identical to rec.buffers
}
});
rec.open(function () {
that.wave = Recorder.FrequencyHistogramView({
elem: that.recwaveElm, lineCount: 90
, position: 0
, minHeight: 1
, stripeEnable: false
});
rec.start();
that.isCloseRecorder = false;
that.RealTimeSendTryReset(that.soundType);//reset
});
this.rec = rec;
},
TransferProcess: function (number, blobOrNull, duration, blobRec, isClose) {
}
}
/*
Recorder (recording library)
https://github.com/xiangyuecn/Recorder
src: engine/pcm.js
*/
!function(){"use strict";Recorder.prototype.enc_pcm={stable:!0,testmsg:"pcm为未封装的原始音频数据,pcm数据文件无法直接播放;支持位数8位、16位(填在比特率里面),采样率取值无限制"},Recorder.prototype.pcm=function(e,t,r){var a=this.set,n=e.length,o=8==a.bitRate?8:16,c=new ArrayBuffer(n*(o/8)),s=new DataView(c),l=0;if(8==o)for(var p=0;p<n;p++,l++){var i=128+(e[p]>>8);s.setInt8(l,i,!0)}else for(p=0;p<n;p++,l+=2)s.setInt16(l,e[p],!0);t(new Blob([s.buffer],{type:"audio/pcm"}))},Recorder.pcm2wav=function(e,a,n){e.slice&&null!=e.type&&(e={blob:e});var o=e.sampleRate||16e3,c=e.bitRate||16;if(e.sampleRate&&e.bitRate||console.warn("pcm2wav必须提供sampleRate和bitRate"),Recorder.prototype.wav){var s=new FileReader;s.onloadend=function(){var e;if(8==c){var t=new Uint8Array(s.result);e=new Int16Array(t.length);for(var r=0;r<t.length;r++)e[r]=t[r]-128<<8}else e=new Int16Array(s.result);Recorder({type:"wav",sampleRate:o,bitRate:c}).mock(e,o).stop(function(e,t){a(e,t)},n)},s.readAsArrayBuffer(e.blob)}else n("pcm2wav必须先加载wav编码器wav.js")}}();
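For reference, `Recorder.pcm2wav` above wraps raw PCM in a WAV container in JavaScript. A minimal Python analogue of the same step, assuming the demo's defaults of 16 kHz, 16-bit, mono little-endian PCM:

```python
# A sketch of pcm2wav using only the standard library.
# Assumes raw little-endian 16-bit mono PCM at 16 kHz (the demo defaults).
import wave

def pcm_to_wav(pcm_path, wav_path, sample_rate=16000, bit_depth=16, channels=1):
    with open(pcm_path, "rb") as f:
        pcm = f.read()
    with wave.open(wav_path, "wb") as w:
        w.setnchannels(channels)
        w.setsampwidth(bit_depth // 8)   # bytes per sample
        w.setframerate(sample_rate)
        w.writeframes(pcm)               # wave writes the RIFF header for us
```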
/*
Recorder (recording library)
https://github.com/xiangyuecn/Recorder
src: engine/wav.js
*/
!function(){"use strict";Recorder.prototype.enc_wav={stable:!0,testmsg:"支持位数8位、16位(填在比特率里面),采样率取值无限制"},Recorder.prototype.wav=function(t,e,n){var r=this.set,a=t.length,o=r.sampleRate,f=8==r.bitRate?8:16,i=a*(f/8),s=new ArrayBuffer(44+i),c=new DataView(s),u=0,v=function(t){for(var e=0;e<t.length;e++,u++)c.setUint8(u,t.charCodeAt(e))},w=function(t){c.setUint16(u,t,!0),u+=2},l=function(t){c.setUint32(u,t,!0),u+=4};if(v("RIFF"),l(36+i),v("WAVE"),v("fmt "),l(16),w(1),w(1),l(o),l(o*(f/8)),w(f/8),w(f),v("data"),l(i),8==f)for(var p=0;p<a;p++,u++){var d=128+(t[p]>>8);c.setInt8(u,d,!0)}else for(p=0;p<a;p++,u+=2)c.setInt16(u,t[p],!0);e(new Blob([c.buffer],{type:"audio/wav"}))}}();
/*
Recorder (recording library)
https://github.com/xiangyuecn/Recorder
src: extensions/frequency.histogram.view.js
*/
!function(){"use strict";var t=function(t){return new e(t)},e=function(t){var e=this,r={scale:2,fps:20,lineCount:30,widthRatio:.6,spaceWidth:0,minHeight:0,position:-1,mirrorEnable:!1,stripeEnable:!0,stripeHeight:3,stripeMargin:6,fallDuration:1e3,stripeFallDuration:3500,linear:[0,"rgba(0,187,17,1)",.5,"rgba(255,215,0,1)",1,"rgba(255,102,0,1)"],stripeLinear:null,shadowBlur:0,shadowColor:"#bbb",stripeShadowBlur:-1,stripeShadowColor:"",onDraw:function(t,e){}};for(var a in t)r[a]=t[a];e.set=t=r;var i=t.elem;i&&("string"==typeof i?i=document.querySelector(i):i.length&&(i=i[0])),i&&(t.width=i.offsetWidth,t.height=i.offsetHeight);var o=t.scale,l=t.width*o,n=t.height*o,h=e.elem=document.createElement("div"),s=["","transform-origin:0 0;","transform:scale("+1/o+");"];h.innerHTML='<div style="width:'+t.width+"px;height:"+t.height+'px;overflow:hidden"><div style="width:'+l+"px;height:"+n+"px;"+s.join("-webkit-")+s.join("-ms-")+s.join("-moz-")+s.join("")+'"><canvas/></div></div>';var f=e.canvas=h.querySelector("canvas");e.ctx=f.getContext("2d");if(f.width=l,f.height=n,i&&(i.innerHTML="",i.appendChild(h)),!Recorder.LibFFT)throw new Error("需要lib.fft.js支持");e.fft=Recorder.LibFFT(1024),e.lastH=[],e.stripesH=[]};e.prototype=t.prototype={genLinear:function(t,e,r,a){for(var i=t.createLinearGradient(0,r,0,a),o=0;o<e.length;)i.addColorStop(e[o++],e[o++]);return i},input:function(t,e,r){var a=this;a.sampleRate=r,a.pcmData=t,a.pcmPos=0,a.inputTime=Date.now(),a.schedule()},schedule:function(){var t=this,e=t.set,r=Math.floor(1e3/e.fps);t.timer||(t.timer=setInterval(function(){t.schedule()},r));var a=Date.now(),i=t.drawTime||0;if(a-t.inputTime>1.3*e.stripeFallDuration)return clearInterval(t.timer),void(t.timer=0);if(!(a-i<r)){t.drawTime=a;for(var o=t.fft.bufferSize,l=t.pcmData,n=t.pcmPos,h=new Int16Array(o),s=0;s<o&&n<l.length;s++,n++)h[s]=l[n];t.pcmPos=n;var f=t.fft.transform(h);t.draw(f,t.sampleRate)}},draw:function(t,e){var r=this,a=r.set,i=r.ctx,o=a.scale,l=a.width*o,n=a.height*o,h=a.lineCount,s=r.fft.bufferSize,f=a.position,d=Math.abs(a.position),c=1==f?0:n,p=n;d<1&&(c=p/=2,p=Math.floor(p*(1+d)),c=Math.floor(0<f?c*(1-d):c*(1+d)));for(var u=r.lastH,v=r.stripesH,w=Math.ceil(p/(a.fallDuration/(1e3/a.fps))),g=Math.ceil(p/(a.stripeFallDuration/(1e3/a.fps))),m=a.stripeMargin*o,M=1<<(Math.round(Math.log(s)/Math.log(2)+3)<<1),b=Math.log(M)/Math.log(10),L=20*Math.log(32767)/Math.log(10),y=s/2,S=Math.min(y,Math.floor(5e3*y/(e/2))),C=S==y,H=C?h:Math.round(.8*h),R=S/H,D=C?0:(y-S)/(h-H),x=0,F=0;F<h;F++){var T=Math.ceil(x);x+=F<H?R:D;for(var B=Math.min(Math.ceil(x),y),E=0,j=T;j<B;j++)E=Math.max(E,Math.abs(t[j]));var I=M<E?Math.floor(17*(Math.log(E)/Math.log(10)-b)):0,q=p*Math.min(I/L,1);u[F]=(u[F]||0)-w,q<u[F]&&(q=u[F]),q<0&&(q=0),u[F]=q;var z=v[F]||0;if(q&&z<q+m)v[F]=q+m;else{var P=z-g;P<0&&(P=0),v[F]=P}}i.clearRect(0,0,l,n);var W=r.genLinear(i,a.linear,c,c-p),k=a.stripeLinear&&r.genLinear(i,a.stripeLinear,c,c-p)||W,A=r.genLinear(i,a.linear,c,c+p),G=a.stripeLinear&&r.genLinear(i,a.stripeLinear,c,c+p)||A;i.shadowBlur=a.shadowBlur*o,i.shadowColor=a.shadowColor;var V=a.mirrorEnable,J=V?2*h-1:h,K=a.widthRatio,N=a.spaceWidth*o;0!=N&&(K=(l-N*(J+1))/l);for(var O=Math.max(1*o,Math.floor(l*K/J)),Q=(l-J*O)/(J+1),U=a.minHeight*o,X=V?l/2-(Q+O/2):0,Y=(F=0,X);F<h;F++)Y+=Q,$=Math.floor(Y),q=Math.max(u[F],U),0!=c&&(_=c-q,i.fillStyle=W,i.fillRect($,_,O,q)),c!=n&&(i.fillStyle=A,i.fillRect($,c,O,q)),Y+=O;if(a.stripeEnable){var Z=a.stripeShadowBlur;i.shadowBlur=(-1==Z?a.shadowBlur:Z)*o,i.shadowColor=a.stripeShadowColor||a.shadowColor;var 
$,_,tt=a.stripeHeight*o;for(F=0,Y=X;F<h;F++)Y+=Q,$=Math.floor(Y),q=v[F],0!=c&&((_=c-q-tt)<0&&(_=0),i.fillStyle=k,i.fillRect($,_,O,tt)),c!=n&&(n<(_=c+q)+tt&&(_=n-tt),i.fillStyle=G,i.fillRect($,_,O,tt)),Y+=O}if(V){var et=Math.floor(l/2);i.save(),i.scale(-1,1),i.drawImage(r.canvas,Math.ceil(l/2),0,et,n,-et,0,et,n),i.restore()}a.onDraw(t,e)}},Recorder.FrequencyHistogramView=t}();
/*
Recorder (recording library)
https://github.com/xiangyuecn/Recorder
src: extensions/lib.fft.js
*/
Recorder.LibFFT=function(r){"use strict";var s,v,d,l,F,b,g,m;return function(r){var o,t,a,f;for(s=Math.round(Math.log(r)/Math.log(2)),d=((v=1<<s)<<2)*Math.sqrt(2),l=[],F=[],b=[0],g=[0],m=[],o=0;o<v;o++){for(a=o,f=t=0;t!=s;t++)f<<=1,f|=1&a,a>>>=1;m[o]=f}var n,u=2*Math.PI/v;for(o=(v>>1)-1;0<o;o--)n=o*u,g[o]=Math.cos(n),b[o]=Math.sin(n)}(r),{transform:function(r){var o,t,a,f,n,u,e,h,M=1,i=s-1;for(o=0;o!=v;o++)l[o]=r[m[o]],F[o]=0;for(o=s;0!=o;o--){for(t=0;t!=M;t++)for(n=g[t<<i],u=b[t<<i],a=t;a<v;a+=M<<1)e=n*l[f=a+M]-u*F[f],h=n*F[f]+u*l[f],l[f]=l[a]-e,F[f]=F[a]-h,l[a]+=e,F[a]+=h;M<<=1,i--}t=v>>1;var c=new Float64Array(t);for(n=-(u=d),o=t;0!=o;o--)e=l[o],h=F[o],c[o-1]=n<e&&e<u&&n<h&&h<u?0:Math.round(e*e+h*h);return c},bufferSize:v}};
/*
Recorder (recording library)
https://github.com/xiangyuecn/Recorder
src: recorder-core.js
*/
!function(y){"use strict";var h=function(){},A=function(e){return new t(e)};A.IsOpen=function(){var e=A.Stream;if(e){var t=e.getTracks&&e.getTracks()||e.audioTracks||[],n=t[0];if(n){var r=n.readyState;return"live"==r||r==n.LIVE}}return!1},A.BufferSize=4096,A.Destroy=function(){for(var e in M("Recorder Destroy"),g(),n)n[e]()};var n={};A.BindDestroy=function(e,t){n[e]=t},A.Support=function(){var e=y.AudioContext;if(e||(e=y.webkitAudioContext),!e)return!1;var t=navigator.mediaDevices||{};return t.getUserMedia||(t=navigator).getUserMedia||(t.getUserMedia=t.webkitGetUserMedia||t.mozGetUserMedia||t.msGetUserMedia),!!t.getUserMedia&&(A.Scope=t,A.Ctx&&"closed"!=A.Ctx.state||(A.Ctx=new e,A.BindDestroy("Ctx",function(){var e=A.Ctx;e&&e.close&&(e.close(),A.Ctx=0)})),!0)};var k="ConnectEnableWorklet";A[k]=!1;var d=function(e){var t=(e=e||A).BufferSize||A.BufferSize,r=A.Ctx,n=e.Stream,a=n._m=r.createMediaStreamSource(n),u=n._call,o=function(e,t){if(!t||h)for(var n in u){for(var r=t||e.inputBuffer.getChannelData(0),a=r.length,o=new Int16Array(a),s=0,i=0;i<a;i++){var c=Math.max(-1,Math.min(1,r[i]));c=c<0?32768*c:32767*c,o[i]=c,s+=Math.abs(c)}for(var f in u)u[f](o,s);return}else M(l+"多余回调",3)},s="ScriptProcessor",l="audioWorklet",i="Recorder",c=i+" "+l,f="RecProc",p=r.createScriptProcessor||r.createJavaScriptNode,v="。由于"+l+"内部1秒375次回调,在移动端可能会有性能问题导致回调丢失录音变短,PC端无影响,暂不建议开启"+l+"",m=function(){h=n.isWorklet=!1,I(n),M("Connect采用老的"+s+""+(A[k]?"但已":"")+"设置"+i+"."+k+"=true尝试启用"+l+v,3);var e=n._p=p.call(r,t,1,1);a.connect(e),e.connect(r.destination),e.onaudioprocess=function(e){o(e)}},h=n.isWorklet=!p||A[k],d=y.AudioWorkletNode;if(h&&r[l]&&d){var g,S=function(){return h&&n._na},_=n._na=function(){""!==g&&(clearTimeout(g),g=setTimeout(function(){g=0,M(l+"未返回任何音频,恢复使用"+s,3),S()&&p&&m()},500))},C=function(){if(S()){var e=n._n=new d(r,f,{processorOptions:{bufferSize:t}});a.connect(e),e.connect(r.destination),e.port.onmessage=function(e){g&&(clearTimeout(g),g=""),o(0,e.data.val)},M("Connect采用"+l+"方式,设置"+i+"."+k+"=false可恢复老式"+s+v,3)}};r.resume()[u&&"finally"](function(){if(S())if(r[f])C();else{var e,t,n=(t="class "+f+" extends AudioWorkletProcessor{",t+="constructor "+(e=function(e){return e.toString().replace(/^function|DEL_/g,"").replace(/\$RA/g,c)})(function(e){DEL_super(e);var t=this,n=e.processorOptions.bufferSize;t.bufferSize=n,t.buffer=new Float32Array(2*n),t.pos=0,t.port.onmessage=function(e){e.data.kill&&(t.kill=!0,console.log("$RA kill call"))},console.log("$RA .ctor call",e)}),t+="process "+e(function(e,t,n){var r=this,a=r.bufferSize,o=r.buffer,s=r.pos;if((e=(e[0]||[])[0]||[]).length){o.set(e,s);var i=~~((s+=e.length)/a)*a;if(i){this.port.postMessage({val:o.slice(0,i)});var c=o.subarray(i,s);(o=new Float32Array(2*a)).set(c),s=c.length,r.buffer=o}r.pos=s}return!r.kill}),t+='}try{registerProcessor("'+f+'", '+f+')}catch(e){console.error("'+c+'注册失败",e)}',"data:text/javascript;base64,"+btoa(unescape(encodeURIComponent(t))));r[l].addModule(n).then(function(e){S()&&(r[f]=1,C(),g&&_())})[u&&"catch"](function(e){M(l+".addModule失败",1,e),S()&&m()})}})}else m()},I=function(e){e._na=null,e._n&&(e._n.port.postMessage({kill:!0}),e._n.disconnect(),e._n=null)},g=function(e){var t=(e=e||A)==A,n=e.Stream;if(n&&(n._m&&(n._m.disconnect(),n._m=null),n._p&&(n._p.disconnect(),n._p.onaudioprocess=n._p=null),I(n),t)){for(var r=n.getTracks&&n.getTracks()||n.audioTracks||[],a=0;a<r.length;a++){var o=r[a];o.stop&&o.stop()}n.stop&&n.stop()}e.Stream=0};A.SampleData=function(e,t,n,r,a){r||(r={});var 
o=r.index||0,s=r.offset||0,i=r.frameNext||[];a||(a={});var c=a.frameSize||1;a.frameType&&(c="mp3"==a.frameType?1152:1);for(var f=0,u=o;u<e.length;u++)f+=e[u].length;f=Math.max(0,f-Math.floor(s));var l=t/n;1<l?f=Math.floor(f/l):(l=1,n=t),f+=i.length;for(var p=new Int16Array(f),v=0,u=0;u<i.length;u++)p[v]=i[u],v++;for(var m=e.length;o<m;o++){for(var h=e[o],u=s,d=h.length;u<d;){var g=Math.floor(u),S=Math.ceil(u),_=u-g,C=h[g],y=S<d?h[S]:(e[o+1]||[C])[0]||0;p[v]=C+(y-C)*_,v++,u+=l}s=u-d}i=null;var k=p.length%c;if(0<k){var I=2*(p.length-k);i=new Int16Array(p.buffer.slice(I)),p=new Int16Array(p.buffer.slice(0,I))}return{index:o,offset:s,frameNext:i,sampleRate:n,data:p}},A.PowerLevel=function(e,t){var n=e/t||0;return n<1251?Math.round(n/1250*10):Math.round(Math.min(100,Math.max(0,100*(1+Math.log(n/1e4)/Math.log(10)))))};var M=function(e,t){var n=new Date,r=("0"+n.getMinutes()).substr(-2)+":"+("0"+n.getSeconds()).substr(-2)+"."+("00"+n.getMilliseconds()).substr(-3),a=this&&this.envIn&&this.envCheck&&this.id,o=["["+r+" Recorder"+(a?":"+a:"")+"]"+e],s=arguments,i=y.console||{},c=2,f=i.log;for("number"==typeof t?f=1==t?i.error:3==t?i.warn:f:c=1;c<s.length;c++)o.push(s[c]);u?f&&f("[IsLoser]"+o[0],1<o.length?o:""):f.apply(i,o)},u=!0;try{u=!console.log.apply}catch(e){}A.CLog=M;var r=0;function t(e){this.id=++r,A.Traffic&&A.Traffic();var t={type:"mp3",bitRate:16,sampleRate:16e3,onProcess:h};for(var n in e)t[n]=e[n];this.set=t,this._S=9,this.Sync={O:9,C:9}}A.Sync={O:9,C:9},A.prototype=t.prototype={CLog:M,_streamStore:function(){return this.set.sourceStream?this:A},open:function(e,n){var r=this,t=r._streamStore();e=e||h;var a=function(e,t){t=!!t,r.CLog("录音open失败:"+e+",isUserNotAllow:"+t,1),n&&n(e,t)},o=function(){r.CLog("open ok id:"+r.id),e(),r._SO=0},s=t.Sync,i=++s.O,c=s.C;r._O=r._O_=i,r._SO=r._S;var f=function(){if(c!=s.C||!r._O){var e="open被取消";return i==s.O?r.close():e="open被中断",a(e),!0}},u=r.envCheck({envName:"H5",canProcess:!0});if(u)a("不能录音:"+u);else if(r.set.sourceStream){if(!A.Support())return void a("不支持此浏览器从流中获取录音");g(t),r.Stream=r.set.sourceStream,r.Stream._call={};try{d(t)}catch(e){return void a("从流中打开录音失败:"+e.message)}o()}else{var l=function(e,t){try{y.top.a}catch(e){return void a('无权录音(跨域,请尝试给iframe添加麦克风访问策略,如allow="camera;microphone")')}/Permission|Allow/i.test(e)?a("用户拒绝了录音权限",!0):!1===y.isSecureContext?a("无权录音(需https)"):/Found/i.test(e)?a(t+",无可用麦克风"):a(t)};if(A.IsOpen())o();else if(A.Support()){var p=function(e){(A.Stream=e)._call={},f()||setTimeout(function(){f()||(A.IsOpen()?(d(),o()):a("录音功能无效:无音频流"))},100)},v=function(e){var t=e.name||e.message||e.code+":"+e;r.CLog("请求录音权限错误",1,e),l(t,"无法录音:"+t)},m=A.Scope.getUserMedia({audio:r.set.audioTrackSet||!0},p,v);m&&m.then&&m.then(p)[e&&"catch"](v)}else l("","此浏览器不支持录音")}},close:function(e){e=e||h;var t=this,n=t._streamStore();t._stop();var r=n.Sync;if(t._O=0,t._O_!=r.O)return t.CLog("close被忽略(因为同时open了多个rec,只有最后一个会真正close)",3),void e();r.C++,g(n),t.CLog("close"),e()},mock:function(e,t){var n=this;return n._stop(),n.isMock=1,n.mockEnvInfo=null,n.buffers=[e],n.recSize=e.length,n.srcSampleRate=t,n},envCheck:function(e){var t,n=this.set;return t||(this[n.type+"_envCheck"]?t=this[n.type+"_envCheck"](e,n):n.takeoffEncodeChunk&&(t=n.type+"类型不支持设置takeoffEncodeChunk")),t||""},envStart:function(e,t){var n=this,r=n.set;if(n.isMock=e?1:0,n.mockEnvInfo=e,n.buffers=[],n.recSize=0,n.envInLast=0,n.envInFirst=0,n.envInFix=0,n.envInFixTs=[],r.sampleRate=Math.min(t,r.sampleRate),n.srcSampleRate=t,n.engineCtx=0,n[r.type+"_start"]){var 
a=n.engineCtx=n[r.type+"_start"](r);a&&(a.pcmDatas=[],a.pcmSize=0)}},envResume:function(){this.envInFixTs=[]},envIn:function(e,t){var a=this,o=a.set,s=a.engineCtx,n=a.srcSampleRate,r=e.length,i=A.PowerLevel(t,r),c=a.buffers,f=c.length;c.push(e);var u=c,l=f,p=Date.now(),v=Math.round(r/n*1e3);a.envInLast=p,1==a.buffers.length&&(a.envInFirst=p-v);var m=a.envInFixTs;m.splice(0,0,{t:p,d:v});for(var h=p,d=0,g=0;g<m.length;g++){var S=m[g];if(3e3<p-S.t){m.length=g;break}h=S.t,d+=S.d}var _=m[1],C=p-h;if(C/3<C-d&&(_&&1e3<C||6<=m.length)){var y=p-_.t-v;if(v/5<y){var k=!o.disableEnvInFix;if(a.CLog("["+p+"]"+(k?"":"")+"补偿"+y+"ms",3),a.envInFix+=y,k){var I=new Int16Array(y*n/1e3);r+=I.length,c.push(I)}}}var M=a.recSize,x=r,b=M+x;if(a.recSize=b,s){var R=A.SampleData(c,n,o.sampleRate,s.chunkInfo);s.chunkInfo=R,b=(M=s.pcmSize)+(x=R.data.length),s.pcmSize=b,c=s.pcmDatas,f=c.length,c.push(R.data),n=R.sampleRate}var L=Math.round(b/n*1e3),w=c.length,T=u.length,z=function(){for(var e=O?0:-x,t=null==c[0],n=f;n<w;n++){var r=c[n];null==r?t=1:(e+=r.length,s&&r.length&&a[o.type+"_encode"](s,r))}if(t&&s)for(n=l,u[0]&&(n=0);n<T;n++)u[n]=null;t&&(e=O?x:0,c[0]=null),s?s.pcmSize+=e:a.recSize+=e},O=o.onProcess(c,i,L,n,f,z);if(!0===O){var D=0;for(g=f;g<w;g++)null==c[g]?D=1:c[g]=new Int16Array(0);D?a.CLog("未进入异步前不能清除buffers",3):s?s.pcmSize-=x:a.recSize-=x}else z()},start:function(){var e=this,t=A.Ctx,n=1;if(e.set.sourceStream?e.Stream||(n=0):A.IsOpen()||(n=0),n)if(e.CLog("开始录音"),e._stop(),e.state=0,e.envStart(null,t.sampleRate),e._SO&&e._SO+1!=e._S)e.CLog("start被中断",3);else{e._SO=0;var r=function(){e.state=1,e.resume()};"suspended"==t.state?(e.CLog("wait ctx resume..."),e.state=3,t.resume().then(function(){e.CLog("ctx resume"),3==e.state&&r()})):r()}else e.CLog("未open",1)},pause:function(){var e=this;e.state&&(e.state=2,e.CLog("pause"),delete e._streamStore().Stream._call[e.id])},resume:function(){var e,n=this;if(n.state){n.state=1,n.CLog("resume"),n.envResume();var t=n._streamStore();t.Stream._call[n.id]=function(e,t){1==n.state&&n.envIn(e,t)},(e=(t||A).Stream)._na&&e._na()}},_stop:function(e){var t=this,n=t.set;t.isMock||t._S++,t.state&&(t.pause(),t.state=0),!e&&t[n.type+"_stop"]&&(t[n.type+"_stop"](t.engineCtx),t.engineCtx=0)},stop:function(n,t,e){var r,a=this,o=a.set;a.CLog("stop "+(a.envInLast?a.envInLast-a.envInFirst+"ms 补"+a.envInFix+"ms":"-"));var s=function(){a._stop(),e&&a.close()},i=function(e){a.CLog("结束录音失败:"+e,1),t&&t(e),s()},c=function(e,t){if(a.CLog("结束录音 编码"+(Date.now()-r)+"ms 音频"+t+"ms/"+e.size+"b"),o.takeoffEncodeChunk)a.CLog("启用takeoffEncodeChunk后stop返回的blob长度为0不提供音频数据",3);else if(e.size<Math.max(100,t/2))return void i("生成的"+o.type+"无效");n&&n(e,t),s()};if(!a.isMock){var f=3==a.state;if(!a.state||f)return void i("未开始录音"+(f?",开始录音前无用户交互导致AudioContext未运行":""));a._stop(!0)}var u=a.recSize;if(u)if(a.buffers[0])if(a[o.type]){if(a.isMock){var l=a.envCheck(a.mockEnvInfo||{envName:"mock",canProcess:!1});if(l)return void i("录音错误:"+l)}var p=a.engineCtx;if(a[o.type+"_complete"]&&p){var v=Math.round(p.pcmSize/o.sampleRate*1e3);return r=Date.now(),void a[o.type+"_complete"](p,function(e){c(e,v)},i)}r=Date.now();var m=A.SampleData(a.buffers,a.srcSampleRate,o.sampleRate);o.sampleRate=m.sampleRate;var h=m.data;v=Math.round(h.length/o.sampleRate*1e3),a.CLog("采样"+u+"->"+h.length+" 花:"+(Date.now()-r)+"ms"),setTimeout(function(){r=Date.now(),a[o.type](h,function(e){c(e,v)},function(e){i(e)})})}else i("未加载"+o.type+"编码器");else i("音频buffers被释放");else i("未采集到录音")}},y.Recorder&&y.Recorder.Destroy(),(y.Recorder=A).LM="2022-03-05 
11:53:19",A.TrafficImgUrl="//ia.51.la/go1?id=20469973&pvFlag=1",A.Traffic=function(){var e=A.TrafficImgUrl;if(e){var t=A.Traffic,n=location.href.replace(/#.*/,"");if(0==e.indexOf("//")&&(e=/^https:/i.test(n)?"https:"+e:"http:"+e),!t[n]){t[n]=1;var r=new Image;r.src=e,M("Traffic Analysis Image: Recorder.TrafficImgUrl="+A.TrafficImgUrl)}}}}(window),"function"==typeof define&&define.amd&&define(function(){return Recorder}),"object"==typeof module&&module.exports&&(module.exports=Recorder);
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>PaddleSpeech Serving - Real-time Speech Transcription</title>
<link rel="shortcut icon" href="./static/paddle.ico">
<script src="../static/js/jquery-3.2.1.min.js"></script>
<script src="../static/js/recorder/recorder-core.js"></script>
<script src="../static/js/recorder/extensions/lib.fft.js"></script>
<script src="../static/js/recorder/extensions/frequency.histogram.view.js"></script>
<script src="../static/js/recorder/engine/pcm.js"></script>
<script src="../static/js/SoundRecognizer.js"></script>
<link rel="stylesheet" href="../static/css/style.css">
<link rel="stylesheet" href="../static/css/font-awesome.min.css">
</head>
<body>
<div class="asr-content">
<div class="audio-banner">
<div class="weaper">
<div class="text-content">
<p><span class="title">PaddleSpeech Serving简介</span></p>
<p class="con-container">
<span class="con">PaddleSpeech 是基于飞桨 PaddlePaddle 的语音方向的开源模型库,用于语音和音频中的各种关键任务的开发。PaddleSpeech Serving是基于python + fastapi 的语音算法模型的C/S类型后端服务,旨在统一paddle speech下的各语音算子来对外提供后端服务。</span>
</p>
</div>
<div class="img-con">
<img src="../static/image/PaddleSpeech_logo.png" alt="" />
</div>
</div>
</div>
<div class="audio-experience">
<div class="asr-box">
<h2>Try It Out</h2>
<div id="client-word-recorder" style="position: relative;">
<div class="pd">
<div style="text-align:center;height:20px;width:100%;
border:0px solid #bcbcbc;color:#000;box-sizing: border-box;display:inline-block"
class="recwave">
</div>
</div>
</div>
<div class="voice-container">
<div class="voice-input">
<span>WebSocket URL:</span>
<input type="text" id="socketUrl" class="websocket-url" value="ws://127.0.0.1:8091/ws/asr"
placeholder="请输入服务器地址,如:ws://127.0.0.1:8091/ws/asr">
<div class="start-voice">
<button type="primary" id="beginBtn" class="voice-btn">
<span class="fa fa-microphone"> 开始识别</span>
</button>
<button type="primary" id="endBtn" class="voice-btn end">
<span class="fa fa-microphone-slash"> 结束识别</span>
</button>
<div id="timeBox" class="time-box flex-display-1">
<span class="total-time">识别中,<i id="timeCount"></i> 秒后自动停止识别</span>
</div>
</div>
</div>
<div class="voice">
<div class="result-text" id="resultPanel">此处显示识别结果</div>
</div>
</div>
</div>
</div>
</div>
<script>
var ws = null
var timeLoop = null
var result = ""
$(document).ready(function () {
$('#beginBtn').on('click', startRecording)
$('#endBtn').on('click', stopRecording)
})
function openWebSocket(url) {
if ("WebSocket" in window) {
ws = new WebSocket(url)
ws.onopen = function () {
console.log("Websocket 连接成功,开始识别")
ws.send(JSON.stringify({
"signal": "start"
}))
}
ws.onmessage = function (_msg) { parseResult(_msg.data) }
ws.onclose = function () {
console.log("WebSocket 连接断开")
}
ws.onerror = function () { console.log("WebSocket connection failed") }
}
}
function parseResult(data) {
var json = JSON.parse(data)
console.log('result json:', json)
var result = json.result
console.log(result)
$("#resultPanel").html(result)
}
function TransferUpload(number, blobOrNull, duration, blobRec, isClose) {
if (blobOrNull) {
var blob = blobOrNull
var encTime = blob.encTime
var reader = new FileReader()
reader.onloadend = function () { ws.send(reader.result) }
reader.readAsArrayBuffer(blob)
}
}
function startRecording() {
// Check socket url
var socketUrl = $('#socketUrl').val()
if (!socketUrl.trim()) {
alert('Please enter the WebSocket server address, e.g. ws://127.0.0.1:8091/ws/asr')
$('#socketUrl').focus()
return
}
// init recorder
SoundRecognizer.init({
soundType: 'pcm',
sampleRate: 16000,
recwaveElm: '.recwave',
translerCallBack: TransferUpload
})
openWebSocket(socketUrl)
// Change button state
$('#beginBtn').hide()
$('#endBtn, #timeBox').addClass('show')
// Start countdown
var seconds = 180
$('#timeCount').text(seconds)
timeLoop = setInterval(function () {
seconds--
$('#timeCount').text(seconds)
if (seconds === 0) {
stopRecording()
}
}, 1000)
}
function stopRecording() {
ws.send(JSON.stringify({ "signal": "end" }))
SoundRecognizer.recordClose()
$('#endBtn').add($('#timeBox')).removeClass('show')
$('#beginBtn').show()
$('#timeCount').text('')
clearInterval(timeLoop)
}
</script>
</body>
</html>
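The page script above spells out the wire protocol: a JSON `{"signal": "start"}` frame, binary PCM chunks, then `{"signal": "end"}`. As a minimal sketch, the same exchange can be driven from Python; the `websockets` package, the 16 kHz/16-bit mono input, and the one-reply-per-chunk assumption are mine, not guaranteed by the server:

```python
# A hedged sketch of the browser's WebSocket exchange, in Python.
# Assumes: `pip install websockets`, a 16 kHz 16-bit mono WAV file, and that
# the server replies with JSON objects carrying a "result" field (as parsed
# by parseResult above).
import asyncio
import json
import wave

import websockets

async def recognize(url, wav_path, chunk_ms=300):
    with wave.open(wav_path, "rb") as w:
        rate, width = w.getframerate(), w.getsampwidth()
        pcm = w.readframes(w.getnframes())
    step = int(rate * width * chunk_ms / 1000)
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"signal": "start"}))
        print(await ws.recv())  # server acknowledgement (assumed)
        for i in range(0, len(pcm), step):
            await ws.send(pcm[i:i + step])  # binary PCM frame
            # one partial result per chunk is an assumption of this sketch
            print(json.loads(await ws.recv()).get("result"))
        await ws.send(json.dumps({"signal": "end"}))
        print(json.loads(await ws.recv()).get("result"))  # final result

asyncio.run(recognize("ws://127.0.0.1:8091/ws/asr", "zh.wav"))
```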
......@@ -119,12 +119,9 @@ The configuration file can be found in `conf/tts_online_application.yaml`.
- `protocol`: Service protocol, choices: [http, websocket], default: http.
- `input`: (required) Input text to generate.
- `spk_id`: Speaker id for multi-speaker text to speech. Default: 0
- `speed`: Audio speed, the value should be set between 0 and 3. Default: 1.0
- `volume`: Audio volume, the value should be set between 0 and 3. Default: 1.0
- `sample_rate`: Sampling rate, choices: [0, 8000, 16000], the default is the same as the model. Default: 0
- `output`: Output wave filepath. Default: None, which means not to save the audio to the local.
- `output`: Client output wave filepath. Default: None, which means not to save the audio to the local.
- `play`: Whether to play audio, play while synthesizing, default value: False, which means not playing. **Playing audio needs to rely on the pyaudio library**.
- `spk_id, speed, volume, sample_rate` do not take effect in streaming speech synthesis service temporarily.
- Currently, only the single-speaker model is supported in the code, so `spk_id` does not take effect. Streaming TTS does not support changing sample rate, variable speed and volume.
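The `Output:` fragments below show only the tail of the client call. For orientation, here is a sketch of a full invocation; the `TTSOnlineClientExecutor` import path is assumed from the PaddleSpeech server package layout and should be checked against your installed version:

```python
# A sketch of the full streaming-TTS client call whose tail appears below.
# The import path is an assumption based on the PaddleSpeech server package.
from paddlespeech.server.bin.paddlespeech_client import TTSOnlineClientExecutor

executor = TTSOnlineClientExecutor()
executor(
    input="您好,欢迎使用语音合成服务。",  # text to synthesize
    server_ip="127.0.0.1",
    port=8092,
    protocol="http",
    spk_id=0,
    output="./output.wav",
    play=False)
```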
Output:
```bash
......@@ -150,9 +147,6 @@ The configuration file can be found in `conf/tts_online_application.yaml`.
port=8092,
protocol="http",
spk_id=0,
speed=1.0,
volume=1.0,
sample_rate=0,
output="./output.wav",
play=False)
......@@ -256,12 +250,10 @@ The configuration file can be found in `conf/tts_online_application.yaml`.
- `protocol`: Service protocol, choices: [http, websocket], default: http.
- `input`: (required) Input text to generate.
- `spk_id`: Speaker id for multi-speaker text to speech. Default: 0
- `speed`: Audio speed, the value should be set between 0 and 3. Default: 1.0
- `volume`: Audio volume, the value should be set between 0 and 3. Default: 1.0
- `sample_rate`: Sampling rate, choices: [0, 8000, 16000], the default is the same as the model. Default: 0
- `output`: Output wave filepath. Default: None, which means not to save the audio to the local.
- `output`: Client output wave filepath. Default: None, which means not to save the audio to the local.
- `play`: Whether to play audio, play while synthesizing, default value: False, which means not playing. **Playing audio needs to rely on the pyaudio library**.
- `spk_id, speed, volume, sample_rate` do not take effect in streaming speech synthesis service temporarily.
- Currently, only the single-speaker model is supported in the code, so `spk_id` does not take effect. Streaming TTS does not support changing sample rate, variable speed and volume.
Output:
......@@ -288,9 +280,6 @@ The configuration file can be found in `conf/tts_online_application.yaml`.
port=8092,
protocol="websocket",
spk_id=0,
speed=1.0,
volume=1.0,
sample_rate=0,
output="./output.wav",
play=False)
......
......@@ -118,12 +118,9 @@
- `protocol`: Service protocol, choices: [http, websocket], default: http.
- `input`: (required) Text to synthesize.
- `spk_id`: Speaker id for multi-speaker text to speech. Default: 0.
- `speed`: Audio speed; the value should be between 0 and 3. Default: 1.0.
- `volume`: Audio volume; the value should be between 0 and 3. Default: 1.0.
- `sample_rate`: Sampling rate, choices: [0, 8000, 16000]; 0 means the same as the model. Default: 0.
- `output`: Output audio path. Default: None, meaning the audio is not saved locally.
- `output`: Client-side output audio path. Default: None, meaning the audio is not saved.
- `play`: Whether to play the audio while synthesizing. Default: False, meaning no playback. **Playing audio depends on the pyaudio library**.
- `spk_id, speed, volume, sample_rate` temporarily do not take effect in the streaming TTS service.
- Currently the code only supports single-speaker models, so the `spk_id` setting does not take effect. Streaming TTS does not support changing the sample rate, speed, or volume.
Output:
......@@ -150,9 +147,6 @@
port=8092,
protocol="http",
spk_id=0,
speed=1.0,
volume=1.0,
sample_rate=0,
output="./output.wav",
play=False)
......@@ -256,12 +250,10 @@
- `protocol`: Service protocol, choices: [http, websocket], default: http.
- `input`: (required) Text to synthesize.
- `spk_id`: Speaker id for multi-speaker text to speech. Default: 0.
- `speed`: Audio speed; the value should be between 0 and 3. Default: 1.0.
- `volume`: Audio volume; the value should be between 0 and 3. Default: 1.0.
- `sample_rate`: Sampling rate, choices: [0, 8000, 16000]; 0 means the same as the model. Default: 0.
- `output`: Output audio path. Default: None, meaning the audio is not saved locally.
- `output`: Client-side output audio path. Default: None, meaning the audio is not saved.
- `play`: Whether to play the audio while synthesizing. Default: False, meaning no playback. **Playing audio depends on the pyaudio library**.
- `spk_id, speed, volume, sample_rate` temporarily do not take effect in the streaming TTS service.
- Currently the code only supports single-speaker models, so the `spk_id` setting does not take effect. Streaming TTS does not support changing the sample rate, speed, or volume.
Output:
......@@ -288,9 +280,6 @@
port=8092,
protocol="websocket",
spk_id=0,
speed=1.0,
volume=1.0,
sample_rate=0,
output="./output.wav",
play=False)
......
......@@ -22,6 +22,7 @@ onnxruntime
pandas
paddlenlp
paddlespeech_feat
Pillow>=9.0.0
praatio==5.0.0
pypinyin
pypinyin-dict
......
......@@ -63,7 +63,7 @@ pip install paddlespeech -i https://pypi.tuna.tsinghua.edu.cn/simple
```
> If you encounter problems downloading **nltk_data** while using paddlespeech, it may be due to a poor network connection; we suggest downloading the [nltk_data](https://paddlespeech.bj.bcebos.com/Parakeet/tools/nltk_data.tar.gz) we provide and extracting it into your `${HOME}` directory.
> If you fail to install paddlespeech-ctcdecoders, it doesn't matter.
> If you fail to install paddlespeech-ctcdecoders, it only means you cannot run inference with the deepspeech2 model; other models are unaffected.
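As a minimal sketch of the **nltk_data** download step mentioned above (standard library only; the URL is the one linked in the note):

```python
# Download the provided nltk_data bundle and unpack it into ${HOME}.
import os
import tarfile
import urllib.request

url = "https://paddlespeech.bj.bcebos.com/Parakeet/tools/nltk_data.tar.gz"
home = os.path.expanduser("~")
archive = os.path.join(home, "nltk_data.tar.gz")
urllib.request.urlretrieve(url, archive)
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall(home)   # ends up in ${HOME}/nltk_data
```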
## Medium: Get the Major Functions (Supports Linux, Mac and Windows; training is not supported)
If you want to get the major functions of `paddlespeech`, you need to complete the following steps:
......
......@@ -60,7 +60,7 @@ pip install paddlespeech -i https://pypi.tuna.tsinghua.edu.cn/simple
```
> If you encounter problems downloading **nltk_data** while using paddlespeech, it may be due to a poor network connection; we suggest downloading the [nltk_data](https://paddlespeech.bj.bcebos.com/Parakeet/tools/nltk_data.tar.gz) we provide and extracting it into your `${HOME}` directory.
> If paddlespeech-ctcdecoders fails to install, don't worry; it does not affect usage.
> If paddlespeech-ctcdecoders fails to install, don't worry; it only affects inference with the deepspeech2 model and does not affect other models.
## Medium: Get the Major Functions (Supports Linux, Mac and Windows; training is not supported)
If you want to use the major functions of `paddlespeech`, you need to complete the following steps:
......
......@@ -10,7 +10,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0464 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |
[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0) | inference/python |
[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0338 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1) | python |
......
......@@ -2,13 +2,13 @@
## Conformer
paddle version: 2.2.2
paddlespeech version: 0.2.0
paddlespeech version: 1.0.1
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0530 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0495 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug| test | ctc_prefix_beam_search | - | 0.0494 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0464 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug| test | ctc_prefix_beam_search | - | 0.0480 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
## Conformer Streaming
......
......@@ -57,7 +57,7 @@ feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 64
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
......@@ -73,10 +73,10 @@ num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 2
n_epoch: 150
accum_grad: 8
global_grad_clip: 5.0
dist_sampler: True
dist_sampler: False
optim: adam
optim_conf:
lr: 0.002
......
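One way to read the training hunk above: the per-step batch shrinks (64 to 32) while gradient accumulation grows (2 to 8), so the effective batch per GPU actually doubles. A quick sanity check:

```python
# Effective batch per GPU = batch_size * accum_grad (values from the diff above).
old = {"batch_size": 64, "accum_grad": 2}
new = {"batch_size": 32, "accum_grad": 8}
print(old["batch_size"] * old["accum_grad"])  # 128
print(new["batch_size"] * new["accum_grad"])  # 256
```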
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
......@@ -43,40 +42,42 @@ model_conf:
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
train_manifest: data/train_l/data.list
dev_manifest: data/dev/data.list
test_manifest: data/test_meeting/data.list
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
use_stream_data: True
unit_type: 'char'
vocab_filepath: data/lang_char/vocab.txt
preprocess_config: conf/preprocess.yaml
cmvn_file: data/mean_std.json
spm_model_prefix: ''
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
dither: 0.1
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 64
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 0
subsampling_factor: 1
batch_size: 32
minlen_in: 10
maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automatically removed
minlen_out: 0
maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed
resample_rate: 16000
shuffle_size: 1500 # read number of 'shuffle_size' data as a chunk, shuffle the data in the chunk
sort_size: 1000 # read number of 'sort_size' data as a chunk, sort the data in the chunk
num_workers: 8
prefetch_factor: 10
dist_sampler: True
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 16
n_epoch: 32
accum_grad: 32
global_grad_clip: 5.0
log_interval: 100
checkpoint:
......
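The `shuffle_size` and `sort_size` keys in the streaming dataloader above describe a chunk-wise strategy: shuffle within a large window for randomness, then sort a smaller window by length so batches contain similarly sized utterances. A minimal sketch of the idea (not the project's actual implementation):

```python
# Chunk-wise shuffle and sort, as described by shuffle_size / sort_size above.
import random

def chunk_shuffle(samples, shuffle_size=1500):
    """Shuffle within windows of `shuffle_size` items."""
    buf = []
    for sample in samples:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf
            buf = []
    random.shuffle(buf)
    yield from buf

def chunk_sort(samples, sort_size=1000, key=len):
    """Length-sort within windows of `sort_size` items.
    `key=len` is a stand-in; the real loader sorts by frame count."""
    buf = []
    for sample in samples:
        buf.append(sample)
        if len(buf) >= sort_size:
            yield from sorted(buf, key=key)
            buf = []
    yield from sorted(buf, key=key)
```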
......@@ -2,6 +2,8 @@
# Copyright 2021 Mobvoi Inc(Author: Di Wu, Binbin Zhang)
# NPU, ASLP Group (Author: Qijie Shao)
#
# Modified from wenet(https://github.com/wenet-e2e/wenet)
stage=-1
stop_stage=100
......@@ -30,7 +32,7 @@ mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data
echo "Please follow https://github.com/wenet-e2e/WenetSpeech to download the data."
exit 0;
......@@ -44,86 +46,57 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
data || exit 1;
fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# generate manifests
python3 ${TARGET_DIR}/aishell/aishell.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/aishell"
if [ $? -ne 0 ]; then
echo "Prepare Aishell failed. Terminated."
exit 1
fi
for dataset in train dev test; do
mv data/manifest.${dataset} data/manifest.${dataset}.raw
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# compute mean and stddev for normalizer
if $cmvn; then
full_size=`cat data/${train_set}/wav.scp | wc -l`
sampling_size=$((full_size / cmvn_sampling_divisor))
shuf -n $sampling_size data/$train_set/wav.scp \
> data/$train_set/wav.scp.sampled
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--spectrum_type="fbank" \
--feat_dim=80 \
--delta_delta=false \
--stride_ms=10 \
--window_ms=25 \
--sample_rate=16000 \
--use_dB_normalization=False \
--num_samples=-1 \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
fi
dict=data/dict/lang_char.txt
dict=data/lang_char/vocab.txt
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# download data, generate manifests
# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type="char" \
--count_threshold=0 \
--vocab_path="data/lang_char/vocab.txt" \
--manifest_paths "data/manifest.train.raw"
if [ $? -ne 0 ]; then
echo "Build vocabulary failed. Terminated."
exit 1
fi
echo "Make a dictionary"
echo "dictionary: ${dict}"
mkdir -p $(dirname $dict)
echo "<blank>" > ${dict} # 0 will be used for "blank" in CTC
echo "<unk>" >> ${dict} # <unk> must be 1
echo "▁" >> ${dict} # ▁ is for space
utils/text2token.py -s 1 -n 1 --space "▁" data/${train_set}/text \
| cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' \
| grep -v "▁" \
| awk '{print $0}' >> ${dict} \
|| exit 1;
echo "<eos>" >> $dict
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--cmvn_path "data/mean_std.json" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
--manifest_path="data/manifest.${dataset}.raw" \
--output_path="data/manifest.${dataset}"
echo "Compute cmvn"
# Here we use all the training data, you can sample some data to save time
# BUG!!! We should use the segmented data for CMVN
if $cmvn; then
full_size=`cat data/${train_set}/wav.scp | wc -l`
sampling_size=$((full_size / cmvn_sampling_divisor))
shuf -n $sampling_size data/$train_set/wav.scp \
> data/$train_set/wav.scp.sampled
python3 utils/compute_cmvn_stats.py \
--num_workers 16 \
--train_config $train_config \
--in_scp data/$train_set/wav.scp.sampled \
--out_cmvn data/$train_set/mean_std.json \
|| exit 1;
fi
fi
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
} &
done
wait
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Making shards, please wait..."
RED='\033[0;31m'
NOCOLOR='\033[0m'
echo -e "It requires ${RED}1.2T ${NOCOLOR}space for $shards_dir, please make sure you have enough space"
echo -e "It takes about ${RED}12 ${NOCOLOR}hours with 32 threads"
for x in $dev_set $test_sets ${train_set}; do
dst=$shards_dir/$x
mkdir -p $dst
utils/make_filted_shard_list.py --num_node 1 --num_gpus_per_node 8 --num_utts_per_shard 1000 \
--do_filter --resample 16000 \
--num_threads 32 --segments data/$x/segments \
data/$x/wav.scp data/$x/text \
$(realpath $dst) data/$x/data.list
done
fi
echo "Aishell data preparation done."
echo "Wenetspeech data preparation done."
exit 0
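The `utils/compute_cmvn_stats.py` call in the script above produces per-dimension mean/variance statistics over the training fbank features. A minimal sketch of the computation it performs (function and return shape here are illustrative, not the script's actual API):

```python
# Per-dimension CMVN statistics over fbank features, as a sketch.
import numpy as np

def compute_cmvn_stats(feature_iter, feat_dim=80):
    """feature_iter yields (T, feat_dim) fbank matrices, one per utterance."""
    total = np.zeros(feat_dim)
    total_sq = np.zeros(feat_dim)
    frames = 0
    for feat in feature_iter:
        total += feat.sum(axis=0)
        total_sq += (feat ** 2).sum(axis=0)
        frames += feat.shape[0]
    mean = total / frames
    std = np.sqrt(np.maximum(total_sq / frames - mean ** 2, 1e-20))
    return mean, std  # apply as (feat - mean) / std at training time
```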
#!/bin/bash
profiler_options=
benchmark_batch_size=0
benchmark_max_step=0
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# -lt 2 ] || [ $# -gt 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
exit -1
fi
config_path=$1
ckpt_name=$2
ips=$3
if [ ! $ips ];then
ips_config=
else
ips_config="--ips="${ips}
fi
echo ${ips_config}
mkdir -p exp
if [ ${ngpu} == 0 ]; then
python3 -u ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
else
NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
fi
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0
......@@ -24,7 +24,7 @@ stage=1
prefix=
train_subset=L
. ./tools/parse_options.sh || exit 1;
. ./utils/parse_options.sh || exit 1;
filter_by_id () {
idlist=$1
......@@ -132,4 +132,4 @@ if [ $stage -le 2 ]; then
done
fi
echo "$0: Done"
\ No newline at end of file
echo "$0: Done"
......@@ -7,6 +7,7 @@ gpus=0,1,2,3,4,5,6,7
stage=0
stop_stage=100
conf_path=conf/conformer.yaml
ips= #xxx.xxx.xxx.xxx,xxx.xxx.xxx.xxx
decode_conf_path=conf/tuning/decode.yaml
average_checkpoint=true
avg_num=10
......@@ -26,7 +27,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa
from .cache import (
cached_tarfile_samples,
cached_tarfile_to_samples,
lru_cleanup,
pipe_cleaner,
)
from .compat import WebDataset, WebLoader, FluidWrapper
from .extradatasets import MockDataset, with_epoch, with_length
from .filters import (
associate,
batched,
decode,
detshuffle,
extract_keys,
getfirst,
info,
map,
map_dict,
map_tuple,
pipelinefilter,
rename,
rename_keys,
audio_resample,
select,
shuffle,
slice,
to_tuple,
transform_with,
unbatched,
xdecode,
audio_data_filter,
audio_tokenize,
audio_compute_fbank,
audio_spec_aug,
sort,
audio_padding,
audio_cmvn,
placeholder,
)
from .handlers import (
ignore_and_continue,
ignore_and_stop,
reraise_exception,
warn_and_continue,
warn_and_stop,
)
from .pipeline import DataPipeline
from .shardlists import (
MultiShardSample,
ResampledShards,
SimpleShardList,
non_empty,
resampled,
shardspec,
single_node_only,
split_by_node,
split_by_worker,
)
from .tariterators import tarfile_samples, tarfile_to_samples
from .utils import PipelineStage, repeatedly
from .writer import ShardWriter, TarWriter, numpy_dumps
from .mix import RandomMix, RoundRobin
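For orientation, the exports above compose into a streaming pipeline in the usual webdataset style. A hedged sketch; the module path, shard pattern, and stage arguments below are assumptions, not the project's documented API:

```python
# A sketch of composing the exported stages into a streaming pipeline.
# The import path is an assumption about where this __init__.py lives.
from paddlespeech.audio.streamdata import (
    DataPipeline, SimpleShardList, split_by_worker,
    tarfile_to_samples, shuffle, sort, batched, warn_and_continue,
)

dataset = DataPipeline(
    SimpleShardList("shards/train-{000000..000999}.tar"),
    split_by_worker,                               # shard-level worker split
    tarfile_to_samples(handler=warn_and_continue),
    shuffle(1500),   # window shuffle, cf. shuffle_size in the streaming config
    sort(1000),      # length-sort window, cf. sort_size (argument form assumed)
    batched(32),
)
for batch in dataset:
    break  # batches stream in without materializing the whole dataset
```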
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Automatically decode webdataset samples."""
import io, json, os, pickle, re, tempfile
from functools import partial
import numpy as np
"""Extensions passed on to the image decoder."""
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()
################################################################
# handle basic datatypes
################################################################
def paddle_loads(data):
"""Load data using paddle.loads, importing paddle only if needed.
:param data: data to be decoded
"""
import io
import paddle
stream = io.BytesIO(data)
return paddle.load(stream)
def tenbin_loads(data):
from . import tenbin
return tenbin.decode_buffer(data)
def msgpack_loads(data):
import msgpack
return msgpack.unpackb(data)
def npy_loads(data):
import numpy.lib.format
stream = io.BytesIO(data)
return numpy.lib.format.read_array(stream)
def cbor_loads(data):
import cbor
return cbor.loads(data)
decoders = {
"txt": lambda data: data.decode("utf-8"),
"text": lambda data: data.decode("utf-8"),
"transcript": lambda data: data.decode("utf-8"),
"cls": lambda data: int(data),
"cls2": lambda data: int(data),
"index": lambda data: int(data),
"inx": lambda data: int(data),
"id": lambda data: int(data),
"json": lambda data: json.loads(data),
"jsn": lambda data: json.loads(data),
"pyd": lambda data: pickle.loads(data),
"pickle": lambda data: pickle.loads(data),
"pdparams": lambda data: paddle_loads(data),
"ten": tenbin_loads,
"tb": tenbin_loads,
"mp": msgpack_loads,
"msg": msgpack_loads,
"npy": npy_loads,
"npz": lambda data: np.load(io.BytesIO(data)),
"cbor": cbor_loads,
}
def basichandlers(key, data):
"""Handle basic file decoding.
This function is usually part of the post= decoders.
This handles the following forms of decoding:
- txt -> unicode string
- cls cls2 class count index inx id -> int
- json jsn -> JSON decoding
- pyd pickle -> pickle decoding
- pdparams -> paddle.loads
- ten tenbin -> fast tensor loading
- mp messagepack msg -> messagepack decoding
- npy -> Python NPY decoding
:param key: file name extension
:param data: binary data to be decoded
"""
extension = re.sub(r".*[.]", "", key)
if extension in decoders:
return decoders[extension](data)
return None
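# Usage note (illustrative examples, added for orientation):
#   basichandlers("utt1.txt", b"hello")            -> "hello"
#   basichandlers("meta.json", b'{"rate": 16000}') -> {"rate": 16000}
#   basichandlers("clip.flac", b"...")             -> None (no basic handler)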
################################################################
# Generic extension handler.
################################################################
def call_extension_handler(key, data, f, extensions):
"""Call the function f with the given data if the key matches the extensions.
:param key: actual key found in the sample
:param data: binary data
:param f: decoder function
:param extensions: list of matching extensions
"""
extension = key.lower().split(".")
for target in extensions:
target = target.split(".")
if len(target) > len(extension):
continue
if extension[-len(target) :] == target:
return f(data)
return None
def handle_extension(extensions, f):
"""Return a decoder function for the list of extensions.
Extensions can be a space separated list of extensions.
Extensions can contain dots, in which case the corresponding number
of extension components must be present in the key given to f.
Comparisons are case insensitive.
Examples:
handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg
handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg
"""
extensions = extensions.lower().split()
return partial(call_extension_handler, f=f, extensions=extensions)
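# Illustrative sketch: a handler built with handle_extension fires only when
# the trailing extension components match; `upper_txt` is a made-up decoder.
def _demo_handle_extension():
    upper_txt = handle_extension("txt info.txt", lambda data: data.decode("utf-8").upper())
    assert upper_txt("sample.info.txt", b"abc") == "ABC"
    assert upper_txt("sample.jpg", b"abc") is None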
################################################################
# handle images
################################################################
imagespecs = {
"l8": ("numpy", "uint8", "l"),
"rgb8": ("numpy", "uint8", "rgb"),
"rgba8": ("numpy", "uint8", "rgba"),
"l": ("numpy", "float", "l"),
"rgb": ("numpy", "float", "rgb"),
"rgba": ("numpy", "float", "rgba"),
"paddlel8": ("paddle", "uint8", "l"),
"paddlergb8": ("paddle", "uint8", "rgb"),
"paddlergba8": ("paddle", "uint8", "rgba"),
"paddlel": ("paddle", "float", "l"),
"paddlergb": ("paddle", "float", "rgb"),
"paddle": ("paddle", "float", "rgb"),
"paddlergba": ("paddle", "float", "rgba"),
"pill": ("pil", None, "l"),
"pil": ("pil", None, "rgb"),
"pilrgb": ("pil", None, "rgb"),
"pilrgba": ("pil", None, "rgba"),
}
class ImageHandler:
"""Decode image data using the given `imagespec`.
The `imagespec` specifies whether the image is decoded
to numpy/paddle/pil, decoded to uint8/float, and decoded
to l/rgb/rgba:
- l8: numpy uint8 l
- rgb8: numpy uint8 rgb
- rgba8: numpy uint8 rgba
- l: numpy float l
- rgb: numpy float rgb
- rgba: numpy float rgba
- paddlel8: paddle uint8 l
- paddlergb8: paddle uint8 rgb
- paddlergba8: paddle uint8 rgba
- paddlel: paddle float l
- paddlergb: paddle float rgb
- paddle: paddle float rgb
- paddlergba: paddle float rgba
- pill: pil None l
- pil: pil None rgb
- pilrgb: pil None rgb
- pilrgba: pil None rgba
"""
def __init__(self, imagespec, extensions=image_extensions):
"""Create an image handler.
:param imagespec: short string indicating the type of decoding
:param extensions: list of extensions the image handler is invoked for
"""
if imagespec not in list(imagespecs.keys()):
raise ValueError("Unknown imagespec: %s" % imagespec)
self.imagespec = imagespec.lower()
self.extensions = extensions
def __call__(self, key, data):
"""Perform image decoding.
:param key: file name extension
:param data: binary data
"""
import PIL.Image
extension = re.sub(r".*[.]", "", key)
if extension.lower() not in self.extensions:
return None
imagespec = self.imagespec
atype, etype, mode = imagespecs[imagespec]
with io.BytesIO(data) as stream:
img = PIL.Image.open(stream)
img.load()
img = img.convert(mode.upper())
if atype == "pil":
return img
elif atype == "numpy":
result = np.asarray(img)
if result.dtype != np.uint8:
raise ValueError("ImageHandler: numpy image must be uint8")
if etype == "uint8":
return result
else:
return result.astype("f") / 255.0
elif atype == "paddle":
import paddle
result = np.asarray(img)
if result.dtype != np.uint8:
raise ValueError("ImageHandler: paddle image must be uint8")
if result.ndim == 3:
# HWC -> CHW
result = np.ascontiguousarray(result.transpose(2, 0, 1))
if etype == "uint8":
return paddle.to_tensor(result)
return paddle.to_tensor(result, dtype="float32") / 255.0
return None
def imagehandler(imagespec, extensions=image_extensions):
"""Create an image handler.
This is just a lower-case alias for ImageHandler.
:param imagespec: textual image spec
:param extensions: list of extensions the handler should be applied for
"""
return ImageHandler(imagespec, extensions)
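# Illustrative sketch: decode an in-memory PNG to a float RGB array in [0, 1];
# assumes Pillow is installed (the handler itself requires it anyway).
def _demo_imagehandler():
    import PIL.Image
    buf = io.BytesIO()
    PIL.Image.new("RGB", (4, 4)).save(buf, format="PNG")
    decode_rgb = imagehandler("rgb")
    array = decode_rgb("sample.png", buf.getvalue())
    assert array.shape == (4, 4, 3) and array.dtype.kind == "f"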
################################################################
# torch video
################################################################
'''
def torch_video(key, data):
"""Decode video using the torchvideo library.
:param key: file name extension
:param data: data to be decoded
"""
extension = re.sub(r".*[.]", "", key)
if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
return None
import torchvision.io
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
with open(fname, "wb") as stream:
stream.write(data)
return torchvision.io.read_video(fname, pts_unit="sec")
'''
################################################################
# paddlespeech.audio
################################################################
def paddle_audio(key, data):
"""Decode audio using the paddlespeech.audio library.
:param key: file name extension
:param data: data to be decoded
"""
extension = re.sub(r".*[.]", "", key)
if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
return None
import paddlespeech.audio
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
with open(fname, "wb") as stream:
stream.write(data)
return paddlespeech.audio.load(fname)
################################################################
# special class for continuing decoding
################################################################
class Continue:
"""Special class for continuing decoding.
This is mostly used for decompression, as in:
def decompressor(key, data):
if key.endswith(".gz"):
return Continue(key[:-3], decompress(data))
return None
"""
def __init__(self, key, data):
"""__init__.
:param key:
:param data:
"""
self.key, self.data = key, data
def gzfilter(key, data):
"""Decode .gz files.
This decompresses the data and then continues decoding.
:param key: file name extension
:param data: binary data
"""
import gzip
if not key.endswith(".gz"):
return None
decompressed = gzip.open(io.BytesIO(data)).read()
return Continue(key[:-3], decompressed)
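# Illustrative sketch: gzfilter decompresses "*.gz" fields and returns a
# Continue, so the Decoder re-dispatches on the inner extension.
def _demo_gzfilter():
    import gzip
    step = gzfilter("sample.txt.gz", gzip.compress(b"hello"))
    assert isinstance(step, Continue)
    assert step.key == "sample.txt" and step.data == b"hello"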
################################################################
# decode entire training samples
################################################################
default_pre_handlers = [gzfilter]
default_post_handlers = [basichandlers]
class Decoder:
"""Decode samples using a list of handlers.
For each key/data item, this iterates through the list of
handlers until some handler returns something other than None.
"""
def __init__(self, handlers, pre=None, post=None, only=None, partial=False):
"""Create a Decoder.
:param handlers: main list of handlers
:param pre: handlers called before the main list (.gz handler by default)
:param post: handlers called after the main list (default handlers by default)
:param only: a list of extensions; when given, only fields with those extensions are decoded
:param partial: allow partial decoding (i.e., don't decode fields that aren't of type bytes)
"""
if isinstance(only, str):
only = only.split()
self.only = only if only is None else set(only)
if pre is None:
pre = default_pre_handlers
if post is None:
post = default_post_handlers
assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
assert all(callable(h) for h in pre), f"one of {pre} not callable"
assert all(callable(h) for h in post), f"one of {post} not callable"
self.handlers = pre + handlers + post
self.partial = partial
def decode1(self, key, data):
"""Decode a single field of a sample.
:param key: file name extension
:param data: binary data
"""
key = "." + key
for f in self.handlers:
result = f(key, data)
if isinstance(result, Continue):
key, data = result.key, result.data
continue
if result is not None:
return result
return data
def decode(self, sample):
"""Decode an entire sample.
:param sample: the sample, a dictionary of key value pairs
"""
result = {}
assert isinstance(sample, dict), sample
for k, v in list(sample.items()):
if k[0] == "_":
if isinstance(v, bytes):
v = v.decode("utf-8")
result[k] = v
continue
if self.only is not None and k not in self.only:
result[k] = v
continue
assert v is not None
if self.partial:
if isinstance(v, bytes):
result[k] = self.decode1(k, v)
else:
result[k] = v
else:
assert isinstance(v, bytes)
result[k] = self.decode1(k, v)
return result
def __call__(self, sample):
"""Decode an entire sample.
:param sample: the sample
"""
assert isinstance(sample, dict), (len(sample), sample)
return self.decode(sample)
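# Illustrative sketch: with no extra handlers, a Decoder still applies the
# default pre (gzfilter) and post (basichandlers) stages to each field.
def _demo_decoder():
    decoder = Decoder([])
    sample = {"__key__": "utt0", "txt": b"hello", "cls": b"7"}
    decoded = decoder(sample)
    assert decoded["txt"] == "hello" and decoded["cls"] == 7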
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import itertools, os, random, re, sys, time
from urllib.parse import urlparse
from . import filters
from . import gopen
from .handlers import reraise_exception
from .tariterators import tar_file_and_group_expander
default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))
def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
"""Performs cleanup of the file cache in cache_dir using an LRU strategy,
keeping the total size of all remaining files below cache_size."""
if not os.path.exists(cache_dir):
return
total_size = 0
for dirpath, dirnames, filenames in os.walk(cache_dir):
for filename in filenames:
total_size += os.path.getsize(os.path.join(dirpath, filename))
if total_size <= cache_size:
return
# sort files by last access time
files = []
for dirpath, dirnames, filenames in os.walk(cache_dir):
for filename in filenames:
files.append(os.path.join(dirpath, filename))
files.sort(key=keyfn, reverse=True)
# delete files until we're under the cache size
while len(files) > 0 and total_size > cache_size:
fname = files.pop()
total_size -= os.path.getsize(fname)
if verbose:
print("# deleting %s" % fname, file=sys.stderr)
os.remove(fname)
def download(url, dest, chunk_size=1024 ** 2, verbose=False):
"""Download a file from `url` to `dest`."""
temp = dest + f".temp{os.getpid()}"
with gopen.gopen(url) as stream:
with open(temp, "wb") as f:
while True:
data = stream.read(chunk_size)
if not data:
break
f.write(data)
os.rename(temp, dest)
def pipe_cleaner(spec):
"""Guess the actual URL from a "pipe:" specification."""
if spec.startswith("pipe:"):
spec = spec[5:]
words = spec.split(" ")
for word in words:
if re.match(r"^(https?|gs|ais|s3)", word):
return word
return spec
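# Illustrative sketch (hypothetical URL): pipe_cleaner recovers the real URL
# from a "pipe:" spec so cached files can be named after their source.
def _demo_pipe_cleaner():
    spec = "pipe:curl -s -L https://example.com/train-000.tar"
    assert pipe_cleaner(spec) == "https://example.com/train-000.tar"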
def get_file_cached(
spec,
cache_size=-1,
cache_dir=None,
url_to_name=pipe_cleaner,
verbose=False,
):
if cache_size == -1:
cache_size = default_cache_size
if cache_dir is None:
cache_dir = default_cache_dir
url = url_to_name(spec)
parsed = urlparse(url)
dirname, filename = os.path.split(parsed.path)
dirname = dirname.lstrip("/")
dirname = re.sub(r"[:/|;]", "_", dirname)
destdir = os.path.join(cache_dir, dirname)
os.makedirs(destdir, exist_ok=True)
dest = os.path.join(cache_dir, dirname, filename)
if not os.path.exists(dest):
if verbose:
print("# downloading %s to %s" % (url, dest), file=sys.stderr)
lru_cleanup(cache_dir, cache_size, verbose=verbose)
download(spec, dest, verbose=verbose)
return dest
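# Illustrative sketch (hypothetical URL and size budget): fetch a shard
# through the local cache, evicting old files beyond the cache size first.
def _demo_get_file_cached():
    local_path = get_file_cached(
        "https://example.com/shards/train-000.tar",
        cache_size=10 * 1024 ** 3,  # keep at most ~10 GB cached
        cache_dir="./_cache",
    )
    return local_path  # a path under ./_cache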
def get_filetype(fname):
with os.popen("file '%s'" % fname) as f:
ftype = f.read()
return ftype
def check_tar_format(fname):
"""Check whether a file is a tar archive."""
ftype = get_filetype(fname)
return "tar archive" in ftype or "gzip compressed" in ftype
verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))
def cached_url_opener(
data,
handler=reraise_exception,
cache_size=-1,
cache_dir=None,
url_to_name=pipe_cleaner,
validator=check_tar_format,
verbose=False,
always=False,
):
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
verbose = verbose or verbose_cache
for sample in data:
assert isinstance(sample, dict), sample
assert "url" in sample
url = sample["url"]
attempts = 5
try:
if not always and os.path.exists(url):
dest = url
else:
dest = get_file_cached(
url,
cache_size=cache_size,
cache_dir=cache_dir,
url_to_name=url_to_name,
verbose=verbose,
)
if verbose:
print("# opening %s" % dest, file=sys.stderr)
assert os.path.exists(dest)
if not validator(dest):
ftype = get_filetype(dest)
with open(dest, "rb") as f:
data = f.read(200)
os.remove(dest)
raise ValueError(
"%s (%s) is not a tar archive, but a %s, contains %s"
% (dest, url, ftype, repr(data))
)
try:
stream = open(dest, "rb")
sample.update(stream=stream)
yield sample
except FileNotFoundError as exn:
# dealing with race conditions in lru_cleanup
attempts -= 1
if attempts > 0:
time.sleep(random.random() * 10)
continue
raise exn
except Exception as exn:
exn.args = exn.args + (url,)
if handler(exn):
continue
else:
break
def cached_tarfile_samples(
src,
handler=reraise_exception,
cache_size=-1,
cache_dir=None,
verbose=False,
url_to_name=pipe_cleaner,
always=False,
):
streams = cached_url_opener(
src,
handler=handler,
cache_size=cache_size,
cache_dir=cache_dir,
verbose=verbose,
url_to_name=url_to_name,
always=always,
)
samples = tar_file_and_group_expander(streams, handler=handler)
return samples
cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
from dataclasses import dataclass
from itertools import islice
from typing import List
import braceexpand, yaml
from . import autodecode
from . import cache, filters, shardlists, tariterators
from .filters import reraise_exception
from .pipeline import DataPipeline
from .paddle_utils import DataLoader, IterableDataset
class FluidInterface:
def batched(self, batchsize):
return self.compose(filters.batched(batchsize))
def dynamic_batched(self, max_frames_in_batch):
return self.compose(filters.dynamic_batched(max_frames_in_batch))
def unbatched(self):
return self.compose(filters.unbatched())
def listed(self, batchsize, partial=True):
return self.compose(filters.batched(batchsize, collation_fn=None, partial=partial))
def unlisted(self):
return self.compose(filters.unlisted())
def log_keys(self, logfile=None):
return self.compose(filters.log_keys(logfile))
def shuffle(self, size, **kw):
if size < 1:
return self
else:
return self.compose(filters.shuffle(size, **kw))
def map(self, f, handler=reraise_exception):
return self.compose(filters.map(f, handler=handler))
def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
return self.map(decoder, handler=handler)
def map_dict(self, handler=reraise_exception, **kw):
return self.compose(filters.map_dict(handler=handler, **kw))
def select(self, predicate, **kw):
return self.compose(filters.select(predicate, **kw))
def to_tuple(self, *args, handler=reraise_exception):
return self.compose(filters.to_tuple(*args, handler=handler))
def map_tuple(self, *args, handler=reraise_exception):
return self.compose(filters.map_tuple(*args, handler=handler))
def slice(self, *args):
return self.compose(filters.slice(*args))
def rename(self, **kw):
return self.compose(filters.rename(**kw))
def rsample(self, p=0.5):
return self.compose(filters.rsample(p))
def rename_keys(self, *args, **kw):
return self.compose(filters.rename_keys(*args, **kw))
def extract_keys(self, *args, **kw):
return self.compose(filters.extract_keys(*args, **kw))
def xdecode(self, *args, **kw):
return self.compose(filters.xdecode(*args, **kw))
def audio_data_filter(self, *args, **kw):
return self.compose(filters.audio_data_filter(*args, **kw))
def audio_tokenize(self, *args, **kw):
return self.compose(filters.audio_tokenize(*args, **kw))
def resample(self, *args, **kw):
return self.compose(filters.resample(*args, **kw))
def audio_compute_fbank(self, *args, **kw):
return self.compose(filters.audio_compute_fbank(*args, **kw))
def audio_spec_aug(self, *args, **kw):
return self.compose(filters.audio_spec_aug(*args, **kw))
def sort(self, size=500):
return self.compose(filters.sort(size))
def audio_padding(self):
return self.compose(filters.audio_padding())
def audio_cmvn(self, cmvn_file):
return self.compose(filters.audio_cmvn(cmvn_file))
class WebDataset(DataPipeline, FluidInterface):
"""Small fluid-interface wrapper for DataPipeline."""
def __init__(
self,
urls,
handler=reraise_exception,
resampled=False,
repeat=False,
shardshuffle=None,
cache_size=0,
cache_dir=None,
detshuffle=False,
nodesplitter=shardlists.single_node_only,
verbose=False,
):
super().__init__()
if isinstance(urls, IterableDataset):
assert not resampled
self.append(urls)
elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
with open(urls) as stream:
spec = yaml.safe_load(stream)
assert "datasets" in spec
self.append(shardlists.MultiShardSample(spec))
elif isinstance(urls, dict):
assert "datasets" in urls
self.append(shardlists.MultiShardSample(urls))
elif resampled:
self.append(shardlists.ResampledShards(urls))
else:
self.append(shardlists.SimpleShardList(urls))
self.append(nodesplitter)
self.append(shardlists.split_by_worker)
if shardshuffle is True:
shardshuffle = 100
if shardshuffle is not None:
if detshuffle:
self.append(filters.detshuffle(shardshuffle))
else:
self.append(filters.shuffle(shardshuffle))
if cache_size == 0:
self.append(tariterators.tarfile_to_samples(handler=handler))
else:
assert cache_size == -1 or cache_size > 0
self.append(
cache.cached_tarfile_to_samples(
handler=handler,
verbose=verbose,
cache_size=cache_size,
cache_dir=cache_dir,
)
)
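# Illustrative sketch (hypothetical shard pattern): a typical fluid-style
# pipeline over tar shards; brace notation expands to ten shard URLs.
def _demo_webdataset():
    dataset = (
        WebDataset("shards/train-{000..009}.tar", shardshuffle=True)
        .shuffle(500)
        .batched(16)
    )
    return dataset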
class FluidWrapper(DataPipeline, FluidInterface):
"""Small fluid-interface wrapper for DataPipeline."""
def __init__(self, initial):
super().__init__()
self.append(initial)
class WebLoader(DataPipeline, FluidInterface):
def __init__(self, *args, **kw):
super().__init__(DataLoader(*args, **kw))
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
import itertools as itt
import os
import random
import sys
import braceexpand
from . import utils
from .paddle_utils import IterableDataset
from .utils import PipelineStage
class MockDataset(IterableDataset):
"""MockDataset.
A mock dataset for performance testing and unit testing.
"""
def __init__(self, sample, length):
"""Create a mock dataset instance.
:param sample: the sample to be returned repeatedly
:param length: the length of the mock dataset
"""
self.sample = sample
self.length = length
def __iter__(self):
"""Return an iterator over this mock dataset."""
for i in range(self.length):
yield self.sample
class repeatedly(IterableDataset, PipelineStage):
"""Repeatedly yield samples from a dataset."""
def __init__(self, source, nepochs=None, nbatches=None, length=None):
"""Create an instance of Repeatedly.
:param nepochs: repeat for a maximum of nepochs
:param nbatches: repeat for a maximum of nbatches
"""
self.source = source
self.length = length
self.nepochs = nepochs
self.nbatches = nbatches
def invoke(self, source):
"""Return an iterator that iterates repeatedly over a source."""
return utils.repeatedly(
source,
nepochs=self.nepochs,
nbatches=self.nbatches,
)
class with_epoch(IterableDataset):
"""Change the actual and nominal length of an IterableDataset.
This will continuously iterate through the original dataset, but
impose new epoch boundaries at the given length/nominal.
This exists mainly as a workaround for the odd logic in DataLoader.
It is also useful for choosing smaller nominal epoch sizes with
very large datasets.
"""
def __init__(self, dataset, length):
"""Chop the dataset to the given length.
:param dataset: IterableDataset
:param length: declared length of the dataset
:param nominal: nominal length of dataset (if different from declared)
"""
super().__init__()
self.length = length
self.source = None
def __getstate__(self):
"""Return the pickled state of the dataset.
This resets the dataset iterator, since that can't be pickled.
"""
result = dict(self.__dict__)
result["source"] = None
return result
def invoke(self, dataset):
"""Return an iterator over the dataset.
This iterator returns as many samples as given by the `length`
parameter.
"""
if self.source is None:
self.source = iter(dataset)
for i in range(self.length):
try:
sample = next(self.source)
except StopIteration:
self.source = iter(dataset)
try:
sample = next(self.source)
except StopIteration:
return
yield sample
self.source = None
class with_length(IterableDataset, PipelineStage):
"""Repeatedly yield samples from a dataset."""
def __init__(self, dataset, length):
"""Create an instance of Repeatedly.
:param dataset: source dataset
:param length: stated length
"""
super().__init__()
self.dataset = dataset
self.length = length
def invoke(self, dataset):
"""Return an iterator that iterates repeatedly over a source."""
return iter(dataset)
def __len__(self):
"""Return the user specified length."""
return self.length
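# Illustrative sketch: with_epoch re-iterates a short source until exactly
# `length` samples have been yielded, regardless of the source's true size.
def _demo_with_epoch():
    stage = with_epoch(None, length=10)
    count = sum(1 for _ in stage.invoke(MockDataset({"x": 1}, length=3)))
    assert count == 10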
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Open URLs by calling subcommands."""
import os, sys, re
from subprocess import PIPE, Popen
from urllib.parse import urlparse
# global used for printing additional node information during verbose output
info = {}
class Pipe:
"""Wrapper class for subprocess.Pipe.
This class looks like a stream from the outside, but it checks
subprocess status and handles timeouts with exceptions.
This way, clients of the class do not need to know that they are
dealing with subprocesses.
:param *args: passed to `subprocess.Popen`
:param **kw: passed to `subprocess.Popen`
:param timeout: timeout for closing/waiting
:param ignore_errors: don't raise exceptions on subprocess errors
:param ignore_status: list of status codes to ignore
"""
def __init__(
self,
*args,
mode=None,
timeout=7200.0,
ignore_errors=False,
ignore_status=[],
**kw,
):
"""Create an IO Pipe."""
self.ignore_errors = ignore_errors
self.ignore_status = [0] + ignore_status
self.timeout = timeout
self.args = (args, kw)
if mode[0] == "r":
self.proc = Popen(*args, stdout=PIPE, **kw)
self.stream = self.proc.stdout
if self.stream is None:
raise ValueError(f"{args}: couldn't open")
elif mode[0] == "w":
self.proc = Popen(*args, stdin=PIPE, **kw)
self.stream = self.proc.stdin
if self.stream is None:
raise ValueError(f"{args}: couldn't open")
self.status = None
def __str__(self):
return f"<Pipe {self.args}>"
def check_status(self):
"""Poll the process and handle any errors."""
status = self.proc.poll()
if status is not None:
self.wait_for_child()
def wait_for_child(self):
"""Check the status variable and raise an exception if necessary."""
verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
if self.status is not None and verbose:
# print(f"(waiting again [{self.status} {os.getpid()}:{self.proc.pid}])", file=sys.stderr)
return
self.status = self.proc.wait()
if verbose:
print(
f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
file=sys.stderr,
)
if self.status not in self.ignore_status and not self.ignore_errors:
raise Exception(f"{self.args}: exit {self.status} (read) {info}")
def read(self, *args, **kw):
"""Wrap stream.read and checks status."""
result = self.stream.read(*args, **kw)
self.check_status()
return result
def write(self, *args, **kw):
"""Wrap stream.write and checks status."""
result = self.stream.write(*args, **kw)
self.check_status()
return result
def readline(self, *args, **kw):
"""Wrap stream.readline and check status."""
result = self.stream.readline(*args, **kw)
self.status = self.proc.poll()
self.check_status()
return result
def close(self):
"""Wrap stream.close, wait for the subprocess, and handle errors."""
self.stream.close()
self.status = self.proc.wait(self.timeout)
self.wait_for_child()
def __enter__(self):
"""Context handler."""
return self
def __exit__(self, etype, value, traceback):
"""Context handler."""
self.close()
def set_options(
obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None
):
"""Set options for Pipes.
This function can be called on any stream. It will set pipe options only
when its argument is a pipe.
:param obj: any kind of stream
:param timeout: desired timeout
:param ignore_errors: desired ignore_errors setting
:param ignore_status: desired ignore_status setting
:param handler: desired error handler
"""
if not isinstance(obj, Pipe):
return False
if timeout is not None:
obj.timeout = timeout
if ignore_errors is not None:
obj.ignore_errors = ignore_errors
if ignore_status is not None:
obj.ignore_status = ignore_status
if handler is not None:
obj.handler = handler
return True
def gopen_file(url, mode="rb", bufsize=8192):
"""Open a file.
This works for local files, files over HTTP, and pipe: files.
:param url: URL to be opened
:param mode: mode to open it with
:param bufsize: requested buffer size
"""
return open(url, mode)
def gopen_pipe(url, mode="rb", bufsize=8192):
"""Use gopen to open a pipe.
:param url: a pipe: URL
:param mode: desired mode
:param bufsize: desired buffer size
"""
assert url.startswith("pipe:")
cmd = url[5:]
if mode[0] == "r":
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141],
) # skipcq: BAN-B604
elif mode[0] == "w":
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_curl(url, mode="rb", bufsize=8192):
"""Open a URL with `curl`.
:param url: url (usually, http:// etc.)
:param mode: file mode
:param bufsize: buffer size
"""
if mode[0] == "r":
cmd = f"curl -s -L '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
cmd = f"curl -s -L -T - '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 26],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_htgs(url, mode="rb", bufsize=8192):
"""Open a URL with `curl`.
:param url: url (usually, http:// etc.)
:param mode: file mode
:param bufsize: buffer size
"""
if mode[0] == "r":
url = re.sub(r"(?i)^htgs://", "gs://", url)
cmd = f"curl -s -L '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
raise ValueError(f"{mode}: cannot write")
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_gsutil(url, mode="rb", bufsize=8192):
"""Open a URL with `curl`.
:param url: url (usually, http:// etc.)
:param mode: file mode
:param bufsize: buffer size
"""
if mode[0] == "r":
cmd = f"gsutil cat '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
cmd = f"gsutil cp - '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 26],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_error(url, *args, **kw):
"""Raise a value error.
:param url: url
:param args: other arguments
:param kw: other keywords
"""
raise ValueError(f"{url}: no gopen handler defined")
"""A dispatch table mapping URL schemes to handlers."""
gopen_schemes = dict(
__default__=gopen_error,
pipe=gopen_pipe,
http=gopen_curl,
https=gopen_curl,
sftp=gopen_curl,
ftps=gopen_curl,
scp=gopen_curl,
gs=gopen_gsutil,
htgs=gopen_htgs,
)
def gopen(url, mode="rb", bufsize=8192, **kw):
"""Open the URL.
This uses the `gopen_schemes` dispatch table to dispatch based
on scheme.
Support for the following schemes is built-in: pipe, file,
http, https, sftp, ftps, scp, gs, htgs.
When no scheme is given the url is treated as a file.
You can set the GOPEN_VERBOSE environment variable to get info about
files being opened.
:param url: the source URL
:param mode: the mode ("rb", "r")
:param bufsize: the buffer size
"""
global fallback_gopen
verbose = int(os.environ.get("GOPEN_VERBOSE", 0))
if verbose:
print("GOPEN", url, info, file=sys.stderr)
assert mode in ["rb", "wb"], mode
if url == "-":
if mode == "rb":
return sys.stdin.buffer
elif mode == "wb":
return sys.stdout.buffer
else:
raise ValueError(f"unknown mode {mode}")
pr = urlparse(url)
if pr.scheme == "":
bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
return open(url, mode, buffering=bufsize)
if pr.scheme == "file":
bufsize = int(os.environ.get("GOPEN_BUFFER", -1))
return open(pr.path, mode, buffering=bufsize)
handler = gopen_schemes["__default__"]
handler = gopen_schemes.get(pr.scheme, handler)
return handler(url, mode, bufsize, **kw)
def reader(url, **kw):
"""Open url with gopen and mode "rb".
:param url: source URL
:param kw: other keywords forwarded to gopen
"""
return gopen(url, "rb", **kw)
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Pluggable exception handlers.
These are functions that take an exception as an argument and then return...
- the exception (in order to re-raise it)
- True (in order to continue and ignore the exception)
- False (in order to ignore the exception and stop processing)
They are used as handler= arguments in much of the library.
"""
import time, warnings
def reraise_exception(exn):
"""Call in an exception handler to re-raise the exception."""
raise exn
def ignore_and_continue(exn):
"""Call in an exception handler to ignore any exception and continue."""
return True
def warn_and_continue(exn):
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
warnings.warn(repr(exn))
time.sleep(0.5)
return True
def ignore_and_stop(exn):
"""Call in an exception handler to ignore any exception and stop further processing."""
return False
def warn_and_stop(exn):
"""Call in an exception handler to ignore any exception and stop further processing."""
warnings.warn(repr(exn))
time.sleep(0.5)
return False
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes for mixing samples from multiple sources."""
import itertools, os, random, time, sys
from functools import reduce, wraps
import numpy as np
from . import autodecode, utils
from .paddle_utils import PaddleTensor, IterableDataset
from .utils import PipelineStage
def round_robin_shortest(*sources):
i = 0
while True:
try:
sample = next(sources[i % len(sources)])
yield sample
except StopIteration:
break
i += 1
def round_robin_longest(*sources):
sources = list(sources)
i = 0
while len(sources) > 0:
try:
sample = next(sources[i % len(sources)])
i += 1
yield sample
except StopIteration:
del sources[i % len(sources)]
class RoundRobin(IterableDataset):
def __init__(self, datasets, longest=False):
self.datasets = datasets
self.longest = longest
def __iter__(self):
"""Return an iterator over the sources."""
sources = [iter(d) for d in self.datasets]
if self.longest:
return round_robin_longest(*sources)
else:
return round_robin_shortest(*sources)
def random_samples(sources, probs=None, longest=False):
if probs is None:
probs = [1] * len(sources)
else:
probs = list(probs)
while len(sources) > 0:
cum = (np.array(probs) / np.sum(probs)).cumsum()
r = random.random()
i = np.searchsorted(cum, r)
try:
yield next(sources[i])
except StopIteration:
if longest:
del sources[i]
del probs[i]
else:
break
class RandomMix(IterableDataset):
def __init__(self, datasets, probs=None, longest=False):
self.datasets = datasets
self.probs = probs
self.longest = longest
def __iter__(self):
"""Return an iterator over the sources."""
sources = [iter(d) for d in self.datasets]
return random_samples(sources, self.probs, longest=self.longest)
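# Illustrative sketch: mix two in-memory sources roughly 80/20; iteration
# stops as soon as one source runs out (longest=False).
def _demo_random_mix():
    clean = [{"txt": "clean"}] * 80
    noisy = [{"txt": "noisy"}] * 20
    mixed = RandomMix([clean, noisy], probs=[0.8, 0.2])
    return sum(1 for _ in mixed)  # at most 100, usually fewer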
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Mock implementations of paddle interfaces when paddle is not available."""
try:
from paddle.io import DataLoader, IterableDataset
except ModuleNotFoundError:
class IterableDataset:
"""Empty implementation of IterableDataset when paddle is not available."""
pass
class DataLoader:
"""Empty implementation of DataLoader when paddle is not available."""
pass
try:
from paddle import Tensor as PaddleTensor
except ModuleNotFoundError:
class PaddleTensor:
"""Empty implementation of PaddleTensor when paddle is not available."""
pass
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import copy, os, random, sys, time
from dataclasses import dataclass
from itertools import islice
from typing import List
import braceexpand, yaml
from .handlers import reraise_exception
from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage
def add_length_method(obj):
def length(self):
return self.size
Combined = type(
obj.__class__.__name__ + "_Length",
(obj.__class__, IterableDataset),
{"__len__": length},
)
obj.__class__ = Combined
return obj
class DataPipeline(IterableDataset, PipelineStage):
"""A pipeline starting with an IterableDataset and a series of filters."""
def __init__(self, *args, **kwargs):
super().__init__()
self.pipeline = []
self.length = -1
self.repetitions = 1
self.nsamples = -1
for arg in args:
if arg is None:
continue
if isinstance(arg, list):
self.pipeline.extend(arg)
else:
self.pipeline.append(arg)
def invoke(self, f, *args, **kwargs):
"""Apply a pipeline stage, possibly to the output of a previous stage."""
if isinstance(f, PipelineStage):
return f.run(*args, **kwargs)
if isinstance(f, (IterableDataset, DataLoader)) and len(args) == 0:
return iter(f)
if isinstance(f, list):
return iter(f)
if callable(f):
result = f(*args, **kwargs)
return result
raise ValueError(f"{f}: not a valid pipeline stage")
def iterator1(self):
"""Create an iterator through one epoch in the pipeline."""
source = self.invoke(self.pipeline[0])
for step in self.pipeline[1:]:
source = self.invoke(step, source)
return source
def iterator(self):
"""Create an iterator through the entire dataset, using the given number of repetitions."""
for i in range(self.repetitions):
for sample in self.iterator1():
yield sample
def __iter__(self):
"""Create an iterator through the pipeline, repeating and slicing as requested."""
if self.repetitions != 1:
if self.nsamples > 0:
return islice(self.iterator(), self.nsamples)
else:
return self.iterator()
else:
return self.iterator()
def stage(self, i):
"""Return pipeline stage i."""
return self.pipeline[i]
def append(self, f):
"""Append a pipeline stage (modifies the object)."""
self.pipeline.append(f)
return self
def append_list(self, *args):
for arg in args:
self.pipeline.append(arg)
return self
def compose(self, *args):
"""Append a pipeline stage to a copy of the pipeline and returns the copy."""
result = copy.copy(self)
for arg in args:
result.append(arg)
return result
def with_length(self, n):
"""Add a __len__ method returning the desired value.
This does not change the actual number of samples in an epoch.
An IterableDataset should not have a __len__ method.
This is provided only as a workaround for some broken training environments
that require a __len__ method.
"""
self.size = n
return add_length_method(self)
def with_epoch(self, nsamples=-1, nbatches=-1):
"""Change the epoch to return the given number of samples/batches.
The two arguments mean the same thing."""
self.repetitions = sys.maxsize
self.nsamples = max(nsamples, nbatches)
return self
def repeat(self, nepochs=-1, nbatches=-1):
"""Repeat iterating through the dataset for the given #epochs up to the given #samples."""
if nepochs > 0:
self.repetitions = nepochs
self.nsamples = nbatches
else:
self.repetitions = sys.maxsize
self.nsamples = nbatches
return self
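# Illustrative sketch: a DataPipeline chains a source stage with callable
# filter stages; each stage receives the previous stage's iterator.
def _demo_datapipeline():
    def source():
        yield from [1, 2, 3]
    def double(gen):
        for x in gen:
            yield 2 * x
    pipeline = DataPipeline(source, double)
    assert list(pipeline) == [2, 4, 6]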
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
import os, random, sys, time
from dataclasses import dataclass, field
from itertools import islice
from typing import List
import braceexpand, yaml
from . import utils
from .filters import pipelinefilter
from .paddle_utils import IterableDataset
from ..utils.log import Logger
logger = Logger(__name__)
def expand_urls(urls):
if isinstance(urls, str):
urllist = urls.split("::")
result = []
for url in urllist:
result.extend(braceexpand.braceexpand(url))
return result
else:
return list(urls)
class SimpleShardList(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(self, urls, seed=None):
"""Iterate through the list of shards.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls = expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.seed = seed
def __len__(self):
return len(self.urls)
def __iter__(self):
"""Return an iterator over the shards."""
urls = self.urls.copy()
if self.seed is not None:
random.Random(self.seed).shuffle(urls)
for url in urls:
yield dict(url=url)
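# Illustrative sketch: brace notation expands into one dict(url=...) per
# shard, in order when no seed is given.
def _demo_simple_shardlist():
    shards = SimpleShardList("shard-{000..002}.tar")
    assert [s["url"] for s in shards] == [
        "shard-000.tar", "shard-001.tar", "shard-002.tar"
    ]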
def split_by_node(src, group=None):
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
logger.info(f"world_size:{world_size}, rank:{rank}")
if world_size > 1:
for s in islice(src, rank, None, world_size):
yield s
else:
for s in src:
yield s
def single_node_only(src, group=None):
rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
if world_size > 1:
raise ValueError("input pipeline needs to be reconfigured for multinode training")
for s in src:
yield s
def split_by_worker(src):
rank, world_size, worker, num_workers = utils.paddle_worker_info()
logger.info(f"num_workers:{num_workers}, worker:{worker}")
if num_workers > 1:
for s in islice(src, worker, None, num_workers):
yield s
else:
for s in src:
yield s
def resampled_(src, n=sys.maxsize):
import random
seed = time.time()
try:
seed = open("/dev/random", "rb").read(20)
except Exception as exn:
print(repr(exn)[:50], file=sys.stderr)
rng = random.Random(seed)
print("# resampled loading", file=sys.stderr)
items = list(src)
print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
for i in range(n):
yield rng.choice(items)
resampled = pipelinefilter(resampled_)
def non_empty(src):
count = 0
for s in src:
yield s
count += 1
if count == 0:
raise ValueError("pipeline stage received no data at all and this was declared as an error")
@dataclass
class MSSource:
"""Class representing a data source."""
name: str = ""
perepoch: int = -1
resample: bool = False
urls: List[str] = field(default_factory=list)
default_rng = random.Random()
def expand(s):
return os.path.expanduser(os.path.expandvars(s))
class MultiShardSample(IterableDataset):
def __init__(self, fname):
"""Construct a shardlist from multiple sources using a YAML spec."""
self.epoch = -1
self.parse_spec(fname)
def parse_spec(self, fname):
self.rng = default_rng # capture default_rng if we fork
if isinstance(fname, dict):
spec = fname
fname = "{dict}"
else:
with open(fname) as stream:
spec = yaml.safe_load(stream)
assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
prefix = expand(spec.get("prefix", ""))
self.sources = []
for ds in spec["datasets"]:
assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list(
ds.keys()
)
buckets = ds.get("buckets", spec.get("buckets", []))
if isinstance(buckets, str):
buckets = [buckets]
buckets = [expand(s) for s in buckets]
if buckets == []:
buckets = [""]
assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
bucket = buckets[0]
name = ds.get("name", "@" + bucket)
urls = ds["shards"]
if isinstance(urls, str):
urls = [urls]
# urls = [u for url in urls for u in braceexpand.braceexpand(url)]
urls = [
prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url))
]
resample = ds.get("resample", -1)
nsample = ds.get("choose", -1)
if nsample > len(urls):
raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
if (nsample > 0) and (resample > 0):
raise ValueError("specify only one of perepoch or choose")
entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
self.sources.append(entry)
print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)
def set_epoch(self, seed):
"""Set the current epoch (for consistent shard selection among nodes)."""
self.rng = random.Random(seed)
def get_shards_for_epoch(self):
result = []
for source in self.sources:
if source.resample > 0:
# sample with replacement
l = self.rng.choices(source.urls, k=source.resample)
elif source.perepoch > 0:
# sample without replacement
l = list(source.urls)
self.rng.shuffle(l)
l = l[: source.perepoch]
else:
l = list(source.urls)
result += l
self.rng.shuffle(result)
return result
def __iter__(self):
shards = self.get_shards_for_epoch()
for shard in shards:
yield dict(url=shard)
def shardspec(spec):
if spec.endswith(".yaml"):
return MultiShardSample(spec)
else:
return SimpleShardList(spec)
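# Illustrative sketch of a MultiShardSample YAML spec (hypothetical paths);
# `choose` samples shards without replacement, `resample` with replacement:
#
#   prefix: /data/
#   datasets:
#     - name: aishell
#       shards: aishell/shard-{000..099}.tar
#       choose: 50
#     - name: librispeech
#       shards: libri/shard-{000..199}.tar
#       resample: 100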
class ResampledShards(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(
self,
urls,
nshards=sys.maxsize,
worker_seed=None,
deterministic=False,
):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls = expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.nshards = nshards
self.worker_seed = utils.paddle_worker_seed if worker_seed is None else worker_seed
self.deterministic = deterministic
self.epoch = -1
def __iter__(self):
"""Return an iterator over the shards."""
self.epoch += 1
if self.deterministic:
seed = utils.make_seed(self.worker_seed(), self.epoch)
else:
seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
if os.environ.get("WDS_SHOW_SEED", "0") == "1":
print(f"# ResampledShards seed {seed}")
self.rng = random.Random(seed)
for _ in range(self.nshards):
index = self.rng.randint(0, len(self.urls) - 1)
yield dict(url=self.urls[index])
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Low level iteration functions for tar archives."""
import random, re, tarfile
import braceexpand
from . import filters
from . import gopen
from .handlers import reraise_exception
trace = False
meta_prefix = "__"
meta_suffix = "__"
import paddlespeech
import paddle
import numpy as np
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
def base_plus_ext(path):
"""Split off all file extensions.
Returns base, allext.
:param path: path with extensions
:param returns: path with all extensions removed
"""
match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
if not match:
return None, None
return match.group(1), match.group(2)
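# Illustrative sketch: base_plus_ext keeps everything up to the first dot of
# the basename as the sample key and returns the rest as the extension.
def _demo_base_plus_ext():
    assert base_plus_ext("data/utt1.seg.jpg") == ("data/utt1", "seg.jpg")
    assert base_plus_ext("no_extension") == (None, None)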
def valid_sample(sample):
"""Check whether a sample is valid.
:param sample: sample to be checked
"""
return (
sample is not None
and isinstance(sample, dict)
and len(list(sample.keys())) > 0
and not sample.get("__bad__", False)
)
# FIXME: UNUSED
def shardlist(urls, *, shuffle=False):
"""Given a list of URLs, yields that list, possibly shuffled."""
if isinstance(urls, str):
urls = braceexpand.braceexpand(urls)
else:
urls = list(urls)
if shuffle:
random.shuffle(urls)
for url in urls:
yield dict(url=url)
def url_opener(data, handler=reraise_exception, **kw):
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
for sample in data:
assert isinstance(sample, dict), sample
assert "url" in sample
url = sample["url"]
try:
stream = gopen.gopen(url, **kw)
sample.update(stream=stream)
yield sample
except Exception as exn:
exn.args = exn.args + (url,)
if handler(exn):
continue
else:
break
def tar_file_iterator(
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
:param fileobj: byte stream suitable for tarfile
:param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
"""
stream = tarfile.open(fileobj=fileobj, mode="r:*")
for tarinfo in stream:
fname = tarinfo.name
try:
if not tarinfo.isreg():
continue
if fname is None:
continue
if (
"/" not in fname
and fname.startswith(meta_prefix)
and fname.endswith(meta_suffix)
):
# skipping metadata for now
continue
if skip_meta is not None and re.match(skip_meta, fname):
continue
name = tarinfo.name
pos = name.rfind('.')
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if postfix == 'wav':
waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False)
result = dict(fname=prefix, wav=waveform, sample_rate=sample_rate)
else:
txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
result = dict(fname=prefix, txt=txt)
#result = dict(fname=fname, data=data)
yield result
stream.members = []
except Exception as exn:
if hasattr(exn, "args") and len(exn.args) > 0:
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
if handler(exn):
continue
else:
break
del stream
def tar_file_and_group_iterator(
fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
):
""" Expand a stream of open tar files into a stream of tar file contents.
And groups the file with same prefix
Args:
data: Iterable[{src, stream}]
Returns:
Iterable[{key, wav, txt, sample_rate}]
"""
stream = tarfile.open(fileobj=fileobj, mode="r:*")
prev_prefix = None
example = {}
valid = True
for tarinfo in stream:
name = tarinfo.name
pos = name.rfind('.')
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if prev_prefix is not None and prefix != prev_prefix:
example['fname'] = prev_prefix
if valid:
yield example
example = {}
valid = True
with stream.extractfile(tarinfo) as file_obj:
try:
if postfix == 'txt':
example['txt'] = file_obj.read().decode('utf8').strip()
elif postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False)
waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32)
example['wav'] = waveform
example['sample_rate'] = sample_rate
else:
example[postfix] = file_obj.read()
except Exception as exn:
if hasattr(exn, "args") and len(exn.args) > 0:
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
if handler(exn):
continue
else:
break
valid = False
# logging.warning('error to parse {}'.format(name))
prev_prefix = prefix
if prev_prefix is not None:
example['fname'] = prev_prefix
yield example
stream.close()
def tar_file_expander(data, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of tar file contents.
This returns an iterator over sample dicts.
"""
for source in data:
url = source["url"]
try:
assert isinstance(source, dict)
assert "stream" in source
for sample in tar_file_iterator(source["stream"]):
assert (
isinstance(sample, dict) and "fname" in sample
)
sample["__url__"] = url
yield sample
except Exception as exn:
exn.args = exn.args + (source.get("stream"), source.get("url"))
if handler(exn):
continue
else:
break
def tar_file_and_group_expander(data, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of tar file contents.
This returns an iterator over grouped sample dicts.
"""
for source in data:
url = source["url"]
try:
assert isinstance(source, dict)
assert "stream" in source
for sample in tar_file_and_group_iterator(source["stream"]):
assert (
isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample
)
sample["__url__"] = url
yield sample
except Exception as exn:
exn.args = exn.args + (source.get("stream"), source.get("url"))
if handler(exn):
continue
else:
break
def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
:param lcase: convert suffixes to lower case (Default value = True)
"""
current_sample = None
for filesample in data:
assert isinstance(filesample, dict)
fname, value = filesample["fname"], filesample["data"]
prefix, suffix = keys(fname)
if trace:
print(
prefix,
suffix,
current_sample.keys() if isinstance(current_sample, dict) else None,
)
if prefix is None:
continue
if lcase:
suffix = suffix.lower()
if current_sample is None or prefix != current_sample["__key__"]:
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffix in current_sample:
raise ValueError(
f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}"
)
if suffixes is None or suffix in suffixes:
current_sample[suffix] = value
if valid_sample(current_sample):
yield current_sample
def tarfile_samples(src, handler=reraise_exception):
streams = url_opener(src, handler=handler)
samples = tar_file_and_group_expander(streams, handler=handler)
return samples
tarfile_to_samples = filters.pipelinefilter(tarfile_samples)
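# Illustrative sketch (hypothetical shard path): expand shard URL dicts into
# grouped audio/text samples ready for a training pipeline.
def _demo_tarfile_samples():
    urls = iter([dict(url="shards/train-000.tar")])
    for sample in tarfile_samples(urls):
        print(sample["fname"], sample["txt"], sample["sample_rate"])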
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions."""
import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Optional, Union
from ..utils.log import Logger
logger = Logger(__name__)
def make_seed(*args):
seed = 0
for arg in args:
seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
return seed
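# Illustrative sketch: make_seed folds its arguments into a 31-bit seed;
# within one process, equal inputs always give equal seeds.
def _demo_make_seed():
    s1 = make_seed(0, 3, "epoch-5")
    s2 = make_seed(0, 3, "epoch-5")
    assert s1 == s2 and 0 <= s1 < 2 ** 31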
class PipelineStage:
def invoke(self, *args, **kw):
raise NotImplementedError
def identity(x: Any) -> Any:
"""Return the argument as is."""
return x
def safe_eval(s: str, expr: str = "{}"):
"""Evaluate the given expression more safely."""
if re.sub("[^A-Za-z0-9_]", "", s) != s:
raise ValueError(f"safe_eval: illegal characters in: '{s}'")
return eval(expr.format(s))
def lookup_sym(sym: str, modules: list):
"""Look up a symbol in a list of modules."""
for mname in modules:
module = importlib.import_module(mname, package="webdataset")
result = getattr(module, sym, None)
if result is not None:
return result
return None
def repeatedly0(
loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
"""Repeatedly returns batches from a DataLoader."""
for epoch in range(nepochs):
for sample in itt.islice(loader, nbatches):
yield sample
def guess_batchsize(batch: Union[tuple, list]):
"""Guess the batch size by looking at the length of the first element in a tuple."""
return len(batch[0])
def repeatedly(
source: Iterator,
nepochs: int = None,
nbatches: int = None,
nsamples: int = None,
batchsize: Callable[..., int] = guess_batchsize,
):
"""Repeatedly yield samples from an iterator."""
epoch = 0
batch = 0
total = 0
while True:
for sample in source:
yield sample
batch += 1
if nbatches is not None and batch >= nbatches:
return
if nsamples is not None:
total += guess_batchsize(sample)
if total >= nsamples:
return
epoch += 1
if nepochs is not None and epoch >= nepochs:
return
def paddle_worker_info(group=None):
"""Return node and worker info for PyTorch and some distributed environments."""
rank = 0
world_size = 1
worker = 0
num_workers = 1
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
else:
try:
import paddle.distributed
group = group or paddle.distributed.get_group()
rank = paddle.distributed.get_rank()
world_size = paddle.distributed.get_world_size()
except ModuleNotFoundError:
pass
if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
worker = int(os.environ["WORKER"])
num_workers = int(os.environ["NUM_WORKERS"])
else:
try:
from paddle.io import get_worker_info
worker_info = get_worker_info()
if worker_info is not None:
worker = worker_info.id
num_workers = worker_info.num_workers
except ModuleNotFoundError as E:
logger.info(f"not found {E}")
exit(-1)
return rank, world_size, worker, num_workers
def paddle_worker_seed(group=None):
"""Compute a distinct, deterministic RNG seed for each worker and node."""
rank, world_size, worker, num_workers = paddle_worker_info(group=group)
return rank * 1000 + worker
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from pprint import pformat
from typing import Union
import sentencepiece as spm
from .utility import BLANK
from .utility import EOS
from .utility import load_dict
from .utility import MASKCTC
from .utility import SOS
from .utility import SPACE
from .utility import UNK
from ..utils.log import Logger
logger = Logger(__name__)
__all__ = ["TextFeaturizer"]
class TextFeaturizer():
def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
a list of token indices. Note that the token indexing order follows the
given vocabulary file.
Args:
unit_type (str): unit type, e.g. char, word, spm
vocab (Union[str, list]): Filepath to load vocabulary for token indices conversion, or the vocab list itself.
spm_model_prefix (str, optional): spm model prefix. Defaults to None.
"""
assert unit_type in ('char', 'spm', 'word')
self.unit_type = unit_type
self.unk = UNK
self.maskctc = maskctc
if vocab:
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
vocab, maskctc)
self.vocab_size = len(self.vocab_list)
else:
logger.warning("TextFeaturizer: not have vocab file or vocab list.")
if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
self.sp = spm.SentencePieceProcessor()
self.sp.Load(spm_model)
def tokenize(self, text, replace_space=True):
if self.unit_type == 'char':
tokens = self.char_tokenize(text, replace_space)
elif self.unit_type == 'word':
tokens = self.word_tokenize(text)
else: # spm
tokens = self.spm_tokenize(text)
return tokens
def detokenize(self, tokens):
if self.unit_type == 'char':
text = self.char_detokenize(tokens)
elif self.unit_type == 'word':
text = self.word_detokenize(tokens)
else: # spm
text = self.spm_detokenize(tokens)
return text
def featurize(self, text):
"""Convert text string to a list of token indices.
Args:
text (str): Text to process.
Returns:
List[int]: List of token indices.
"""
tokens = self.tokenize(text)
ids = []
for token in tokens:
if token not in self.vocab_dict:
logger.debug(f"Text Token: {token} -> {self.unk}")
token = self.unk
ids.append(self.vocab_dict[token])
return ids
def defeaturize(self, idxs):
"""Convert a list of token indices to text string,
ignore index after eos_id.
Args:
idxs (List[int]): List of token indices.
Returns:
str: Text.
"""
tokens = []
for idx in idxs:
if idx == self.eos_id:
break
tokens.append(self._id2token[idx])
text = self.detokenize(tokens)
return text
def char_tokenize(self, text, replace_space=True):
"""Character tokenizer.
Args:
text (str): text string.
replace_space (bool): whether to replace spaces with the <space> token; set to False only by build_vocab.py.
Returns:
List[str]: tokens.
"""
text = text.strip()
if replace_space:
text_list = [SPACE if item == " " else item for item in list(text)]
else:
text_list = list(text)
return text_list
def char_detokenize(self, tokens):
"""Character detokenizer.
Args:
tokens (List[str]): tokens.
Returns:
str: text string.
"""
tokens = [t.replace(SPACE, " ") for t in tokens]
return "".join(tokens)
def word_tokenize(self, text):
"""Word tokenizer, separate by <space>."""
return text.strip().split()
def word_detokenize(self, tokens):
"""Word detokenizer, separate by <space>."""
return " ".join(tokens)
    def spm_tokenize(self, text):
        """Tokenize text into sentence pieces with the loaded spm model.

        Args:
            text (str): text string.

        Returns:
            List[str]: sentence piece tokens.
        """
        stats = {"num_empty": 0, "num_filtered": 0}

        def valid(line):
            # Placeholder filter; accepts every encoded line.
            return True

        def encode(line):
            return self.sp.EncodeAsPieces(line)

        def encode_line(line):
            line = line.strip()
            if len(line) > 0:
                line = encode(line)
                if valid(line):
                    return line
                else:
                    stats["num_filtered"] += 1
            else:
                stats["num_empty"] += 1
            return None

        enc_line = encode_line(text)
        return enc_line
    def spm_detokenize(self, tokens, input_format='piece'):
        """Decode sentence pieces back into a text string.

        Args:
            tokens (List[str]): sentence piece tokens (or ids when input_format is 'id').
            input_format (str): 'piece' or 'id'. Defaults to 'piece'.

        Returns:
            str: text string.
        """
        if input_format == "piece":

            def decode(l):
                return "".join(self.sp.DecodePieces(l))
        elif input_format == "id":

            def decode(l):
                return "".join(self.sp.DecodeIds(l))
        else:
            raise ValueError(f"Unsupported input_format: {input_format}")

        return decode(tokens)
def _load_vocabulary_from_file(self, vocab: Union[str, list],
maskctc: bool):
"""Load vocabulary from file."""
if isinstance(vocab, list):
vocab_list = vocab
else:
vocab_list = load_dict(vocab, maskctc)
assert vocab_list is not None
logger.debug(f"Vocab: {pformat(vocab_list)}")
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
token2id = dict(
[(token, idx) for (idx, token) in enumerate(vocab_list)])
blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
logger.info(f"BLANK id: {blank_id}")
logger.info(f"UNK id: {unk_id}")
logger.info(f"EOS id: {eos_id}")
logger.info(f"SOS id: {sos_id}")
logger.info(f"SPACE id: {space_id}")
logger.info(f"MASKCTC id: {maskctc_id}")
return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
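
For reference, a minimal usage sketch of the featurizer above. It assumes char-level units, an in-memory vocab list, and that the special-symbol strings below match the BLANK/UNK/SPACE/EOS constants imported from `.utility`; the import path in the comment is likewise an assumption, not part of this diff.

```python
# A minimal sketch, assuming char-level units and that these symbol strings
# match the BLANK/UNK/SPACE/EOS constants from .utility.
# Assumed import path:
# from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer

featurizer = TextFeaturizer(
    unit_type='char',
    vocab=['<blank>', '<unk>', 'h', 'i', '<space>', '<eos>'])

ids = featurizer.featurize("hi hi")  # -> [2, 3, 4, 2, 3]
text = featurizer.defeaturize(ids)   # -> "hi hi"
```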
@@ -14,8 +14,8 @@
 # Modified from espnet(https://github.com/espnet/espnet)
 import inspect

-from paddlespeech.s2t.transform.transform_interface import TransformInterface
-from paddlespeech.s2t.utils.check_kwargs import check_kwargs
+from paddlespeech.audio.transform.transform_interface import TransformInterface
+from paddlespeech.audio.utils.check_kwargs import check_kwargs


 class FuncTrans(TransformInterface):
@@ -17,8 +17,97 @@ import numpy
 import scipy
 import soundfile
-from paddlespeech.s2t.io.reader import SoundHDF5File
+import io
+import os
+
+import h5py
+import numpy as np
+
+
+class SoundHDF5File():
+    """Collects sound files into a single HDF5 file.
+
+    >>> f = SoundHDF5File('a.flac.h5', mode='a')
+    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
+    >>> f['id'] = (array, 16000)
+    >>> array, rate = f['id']
+
+    :param str filepath: Path of the HDF5 file.
+    :param str mode: h5py file mode, e.g. 'r', 'r+', 'a'.
+    :param str format: Audio format used when saving: flac, wav, nist, htk, etc.
+    :param str dtype: Sample dtype used when reading, e.g. 'int16'.
+    """
+
+    def __init__(self,
+                 filepath,
+                 mode="r+",
+                 format=None,
+                 dtype="int16",
+                 **kwargs):
+        self.filepath = filepath
+        self.mode = mode
+        self.dtype = dtype
+        self.file = h5py.File(filepath, mode, **kwargs)
+        if format is None:
+            # filepath = a.flac.h5 -> format = flac
+            second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
+            format = second_ext[1:]
+        if format.upper() not in soundfile.available_formats():
+            # If the inferred format is not supported, fall back to flac.
+            format = "flac"
+        # The format only affects saving.
+        self.format = format
+
+    def __repr__(self):
+        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
+            self.filepath, self.mode, self.format, self.dtype)
+
+    def create_dataset(self, name, shape=None, data=None, **kwds):
+        f = io.BytesIO()
+        array, rate = data
+        soundfile.write(f, array, rate, format=self.format)
+        self.file.create_dataset(
+            name, shape=shape, data=np.void(f.getvalue()), **kwds)
+
+    def __setitem__(self, name, data):
+        self.create_dataset(name, data=data)
+
+    def __getitem__(self, key):
+        data = self.file[key][()]
+        f = io.BytesIO(data.tobytes())
+        array, rate = soundfile.read(f, dtype=self.dtype)
+        return array, rate
+
+    def keys(self):
+        return self.file.keys()
+
+    def values(self):
+        for k in self.file:
+            yield self[k]
+
+    def items(self):
+        for k in self.file:
+            yield k, self[k]
+
+    def __iter__(self):
+        return iter(self.file)
+
+    def __contains__(self, item):
+        return item in self.file
+
+    def __len__(self):
+        return len(self.file)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.file.close()
+
+    def close(self):
+        self.file.close()
+
+
 class SpeedPerturbation():
     """SpeedPerturbation
@@ -469,3 +558,4 @@ class RIRConvolve():
                 [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
         else:
             return scipy.convolve(x, rir, mode="same")
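
The tail above applies a plain same-length convolution per channel. A toy check with made-up inputs; `scipy.signal.convolve` stands in for the `scipy.convolve` call shown.

```python
# A toy check of the same-length convolution used by RIRConvolve; the signal
# and impulse response are made up.
import numpy as np
from scipy.signal import convolve

x = np.random.randn(16000)             # 1 s of noise at 16 kHz
rir = np.array([1.0, 0.0, 0.0, 0.3])   # direct path plus one echo tap
y = convolve(x, rir, mode="same")      # reverberated signal, same length
assert y.shape == x.shape
```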
@@ -14,12 +14,10 @@
 # Modified from espnet(https://github.com/espnet/espnet)
 """Spec Augment module for preprocessing i.e., data augmentation"""
 import random

 import numpy
 from PIL import Image
-from PIL.Image import BICUBIC

-from paddlespeech.s2t.transform.functional import FuncTrans
+from .functional import FuncTrans


 def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
@@ -46,9 +44,10 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
         warped = random.randrange(center - window, center +
                                   window) + 1  # 1 ... t - 1

-        left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
+        left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
+                                                  Image.BICUBIC)
         right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
-                                                   BICUBIC)
+                                                   Image.BICUBIC)
         if inplace:
             x[:warped] = left
             x[warped:] = right
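
The second hunk only renames the resampling constant; the underlying PIL resize trick is unchanged. A toy run with made-up shapes:

```python
# A toy run of the PIL resize trick in time_warp: stretch the first `center`
# frames to `warped` frames while keeping the feature dimension.
import numpy as np
from PIL import Image

x = np.random.randn(100, 80).astype(np.float32)  # (time, freq) fbank-like
center, warped = 50, 60
left = np.array(
    Image.fromarray(x[:center]).resize((x.shape[1], warped), Image.BICUBIC))
assert left.shape == (warped, x.shape[1])  # time axis warped, freq kept
```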
@@ -17,7 +17,7 @@ import numpy as np
 import paddle
 from python_speech_features import logfbank

-import paddlespeech.audio.compliance.kaldi as kaldi
+from ..compliance import kaldi


 def stft(x,
@@ -65,6 +65,7 @@ class Logger(object):
     def __init__(self, name: str=None):
         name = 'PaddleAudio' if not name else name
+        self.name = name
         self.logger = logging.getLogger(name)

         for key, conf in log_config.items():
@@ -101,7 +102,7 @@ class Logger(object):
         if not self.is_enable:
             return

-        self.logger.log(log_level, msg)
+        self.logger.log(log_level, self.name + " | " + msg)

     @contextlib.contextmanager
     def use_terminator(self, terminator: str):
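
With `self.name` stored, every record is now prefixed with the logger's name; a minimal illustration with a made-up message:

```python
# A minimal illustration of the logging change above; the message is made up.
import logging

logging.basicConfig(level=logging.INFO)
name = 'PaddleAudio'
logger = logging.getLogger(name)
logger.log(logging.INFO, name + " | " + "asr model loaded")
# -> INFO:PaddleAudio:PaddleAudio | asr model loaded
```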
@@ -217,7 +217,7 @@ class BaseExecutor(ABC):
             logging.getLogger(name) for name in logging.root.manager.loggerDict
         ]
         for l in loggers:
-            l.disabled = True
+            l.setLevel(logging.ERROR)

     def show_rtf(self, info: Dict[str, List[float]]):
         """
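
Switching from `disabled = True` to `setLevel(logging.ERROR)` keeps third-party errors visible instead of silencing those loggers outright; a minimal contrast:

```python
# A minimal contrast of the two behaviors changed above; the logger name and
# messages are made up.
import logging

logging.basicConfig()
lg = logging.getLogger("some.third_party")

lg.disabled = True           # old behavior: even errors are dropped
lg.error("lost")             # prints nothing

lg.disabled = False
lg.setLevel(logging.ERROR)   # new behavior: only ERROR and above survive
lg.info("suppressed")        # filtered out
lg.error("still visible")    # printed
```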