Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
795eb7bd
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
795eb7bd
编写于
9月 01, 2022
作者:
小湉湉
提交者:
GitHub
9月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format paddlespeech with pre-commit (#2331)
上级
5d5888af
变更
61
隐藏空白更改
内联
并排
Showing
61 changed file
with
1052 addition
and
940 deletion
+1052
-940
demos/audio_searching/src/operations/load.py
demos/audio_searching/src/operations/load.py
+3
-2
demos/speech_web/API.md
demos/speech_web/API.md
+1
-1
demos/speech_web/speech_server/main.py
demos/speech_web/speech_server/main.py
+80
-80
demos/speech_web/speech_server/requirements.txt
demos/speech_web/speech_server/requirements.txt
+5
-6
demos/speech_web/speech_server/src/AudioManeger.py
demos/speech_web/speech_server/src/AudioManeger.py
+45
-42
demos/speech_web/speech_server/src/SpeechBase/asr.py
demos/speech_web/speech_server/src/SpeechBase/asr.py
+8
-10
demos/speech_web/speech_server/src/SpeechBase/nlp.py
demos/speech_web/speech_server/src/SpeechBase/nlp.py
+9
-9
demos/speech_web/speech_server/src/SpeechBase/sql_helper.py
demos/speech_web/speech_server/src/SpeechBase/sql_helper.py
+31
-25
demos/speech_web/speech_server/src/SpeechBase/tts.py
demos/speech_web/speech_server/src/SpeechBase/tts.py
+43
-49
demos/speech_web/speech_server/src/SpeechBase/vpr.py
demos/speech_web/speech_server/src/SpeechBase/vpr.py
+28
-26
demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py
demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py
+5
-4
demos/speech_web/speech_server/src/WebsocketManeger.py
demos/speech_web/speech_server/src/WebsocketManeger.py
+2
-1
demos/speech_web/speech_server/src/robot.py
demos/speech_web/speech_server/src/robot.py
+23
-21
demos/speech_web/speech_server/src/util.py
demos/speech_web/speech_server/src/util.py
+6
-11
demos/streaming_asr_server/local/rtf_from_log.py
demos/streaming_asr_server/local/rtf_from_log.py
+1
-1
docs/requirements.txt
docs/requirements.txt
+17
-18
docs/source/conf.py
docs/source/conf.py
+3
-2
examples/iwslt2012/punc0/local/preprocess.py
examples/iwslt2012/punc0/local/preprocess.py
+24
-22
examples/other/tts_finetune/tts3/finetune.py
examples/other/tts_finetune/tts3/finetune.py
+4
-5
paddlespeech/__init__.py
paddlespeech/__init__.py
+0
-2
paddlespeech/audio/__init__.py
paddlespeech/audio/__init__.py
+3
-3
paddlespeech/audio/streamdata/__init__.py
paddlespeech/audio/streamdata/__init__.py
+62
-63
paddlespeech/audio/streamdata/autodecode.py
paddlespeech/audio/streamdata/autodecode.py
+9
-10
paddlespeech/audio/streamdata/cache.py
paddlespeech/audio/streamdata/cache.py
+30
-33
paddlespeech/audio/streamdata/compat.py
paddlespeech/audio/streamdata/compat.py
+39
-29
paddlespeech/audio/streamdata/extradatasets.py
paddlespeech/audio/streamdata/extradatasets.py
+1
-12
paddlespeech/audio/streamdata/filters.py
paddlespeech/audio/streamdata/filters.py
+164
-92
paddlespeech/audio/streamdata/gopen.py
paddlespeech/audio/streamdata/gopen.py
+26
-36
paddlespeech/audio/streamdata/handlers.py
paddlespeech/audio/streamdata/handlers.py
+2
-3
paddlespeech/audio/streamdata/mix.py
paddlespeech/audio/streamdata/mix.py
+2
-7
paddlespeech/audio/streamdata/paddle_utils.py
paddlespeech/audio/streamdata/paddle_utils.py
+2
-12
paddlespeech/audio/streamdata/pipeline.py
paddlespeech/audio/streamdata/pipeline.py
+5
-9
paddlespeech/audio/streamdata/shardlists.py
paddlespeech/audio/streamdata/shardlists.py
+44
-33
paddlespeech/audio/streamdata/tariterators.py
paddlespeech/audio/streamdata/tariterators.py
+41
-40
paddlespeech/audio/streamdata/utils.py
paddlespeech/audio/streamdata/utils.py
+17
-15
paddlespeech/audio/streamdata/writer.py
paddlespeech/audio/streamdata/writer.py
+40
-37
paddlespeech/audio/text/text_featurizer.py
paddlespeech/audio/text/text_featurizer.py
+1
-1
paddlespeech/audio/transform/perturb.py
paddlespeech/audio/transform/perturb.py
+6
-5
paddlespeech/audio/transform/spec_augment.py
paddlespeech/audio/transform/spec_augment.py
+1
-0
paddlespeech/cli/executor.py
paddlespeech/cli/executor.py
+1
-1
paddlespeech/s2t/__init__.py
paddlespeech/s2t/__init__.py
+1
-0
paddlespeech/s2t/exps/u2/model.py
paddlespeech/s2t/exps/u2/model.py
+14
-9
paddlespeech/s2t/exps/u2_kaldi/model.py
paddlespeech/s2t/exps/u2_kaldi/model.py
+16
-10
paddlespeech/s2t/exps/u2_st/model.py
paddlespeech/s2t/exps/u2_st/model.py
+12
-7
paddlespeech/s2t/io/dataloader.py
paddlespeech/s2t/io/dataloader.py
+80
-65
paddlespeech/s2t/models/u2_st/u2_st.py
paddlespeech/s2t/models/u2_st/u2_st.py
+6
-7
paddlespeech/s2t/modules/align.py
paddlespeech/s2t/modules/align.py
+32
-7
paddlespeech/s2t/modules/initializer.py
paddlespeech/s2t/modules/initializer.py
+1
-1
paddlespeech/server/engine/asr/online/ctc_endpoint.py
paddlespeech/server/engine/asr/online/ctc_endpoint.py
+4
-2
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
+1
-1
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py
...ch/server/engine/asr/online/paddleinference/asr_engine.py
+1
-1
paddlespeech/server/engine/asr/python/asr_engine.py
paddlespeech/server/engine/asr/python/asr_engine.py
+7
-5
paddlespeech/t2s/datasets/sampler.py
paddlespeech/t2s/datasets/sampler.py
+4
-3
paddlespeech/t2s/exps/ernie_sat/train.py
paddlespeech/t2s/exps/ernie_sat/train.py
+0
-1
paddlespeech/t2s/exps/ernie_sat/utils.py
paddlespeech/t2s/exps/ernie_sat/utils.py
+7
-4
paddlespeech/t2s/exps/syn_utils.py
paddlespeech/t2s/exps/syn_utils.py
+4
-4
paddlespeech/t2s/frontend/g2pw/__init__.py
paddlespeech/t2s/frontend/g2pw/__init__.py
+0
-1
paddlespeech/t2s/frontend/mix_frontend.py
paddlespeech/t2s/frontend/mix_frontend.py
+5
-2
paddlespeech/t2s/training/updaters/standard_updater.py
paddlespeech/t2s/training/updaters/standard_updater.py
+2
-1
setup.py
setup.py
+2
-7
speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py
speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py
+19
-24
未找到文件。
demos/audio_searching/src/operations/load.py
浏览文件 @
795eb7bd
...
...
@@ -26,8 +26,9 @@ def get_audios(path):
"""
supported_formats
=
[
".wav"
,
".mp3"
,
".ogg"
,
".flac"
,
".m4a"
]
return
[
item
for
sublist
in
[[
os
.
path
.
join
(
dir
,
file
)
for
file
in
files
]
for
dir
,
_
,
files
in
list
(
os
.
walk
(
path
))]
item
for
sublist
in
[[
os
.
path
.
join
(
dir
,
file
)
for
file
in
files
]
for
dir
,
_
,
files
in
list
(
os
.
walk
(
path
))]
for
item
in
sublist
if
os
.
path
.
splitext
(
item
)[
1
]
in
supported_formats
]
...
...
demos/speech_web/API.md
浏览文件 @
795eb7bd
...
...
@@ -401,4 +401,4 @@ curl -X 'GET' \
"code"
:
0
,
"result"
:
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
,
"message"
:
"ok"
```
\ No newline at end of file
```
demos/speech_web/speech_server/main.py
浏览文件 @
795eb7bd
...
...
@@ -3,48 +3,48 @@
# 2. 接收录音音频,返回识别结果
# 3. 接收ASR识别结果,返回NLP对话结果
# 4. 接收NLP对话结果,返回TTS音频
import
argparse
import
base64
import
yaml
import
os
import
json
import
datetime
import
json
import
os
from
typing
import
List
import
aiofiles
import
librosa
import
soundfile
as
sf
import
numpy
as
np
import
argparse
import
uvicorn
import
aiofiles
from
typing
import
Optional
,
List
from
pydantic
import
BaseModel
from
fastapi
import
FastAPI
,
Header
,
File
,
UploadFile
,
Form
,
Cookie
,
WebSocket
,
WebSocketDisconnect
from
fastapi
import
FastAPI
from
fastapi
import
File
from
fastapi
import
Form
from
fastapi
import
UploadFile
from
fastapi
import
WebSocket
from
fastapi
import
WebSocketDisconnect
from
fastapi.responses
import
StreamingResponse
from
starlette.responses
import
FileResponse
from
starlette.middleware.cors
import
CORSMiddleware
from
starlette.requests
import
Request
from
starlette.websockets
import
WebSocketState
as
WebSocketState
from
pydantic
import
BaseModel
from
src.AudioManeger
import
AudioMannger
from
src.util
import
*
from
src.robot
import
Robot
from
src.WebsocketManeger
import
ConnectionManager
from
src.SpeechBase.vpr
import
VPR
from
src.util
import
*
from
src.WebsocketManeger
import
ConnectionManager
from
starlette.middleware.cors
import
CORSMiddleware
from
starlette.requests
import
Request
from
starlette.responses
import
FileResponse
from
starlette.websockets
import
WebSocketState
as
WebSocketState
from
paddlespeech.server.engine.asr.online.python.asr_engine
import
PaddleASRConnectionHanddler
from
paddlespeech.server.utils.audio_process
import
float2pcm
# 解析配置
parser
=
argparse
.
ArgumentParser
(
prog
=
'PaddleSpeechDemo'
,
add_help
=
True
)
parser
=
argparse
.
ArgumentParser
(
prog
=
'PaddleSpeechDemo'
,
add_help
=
True
)
parser
.
add_argument
(
"--port"
,
action
=
"store"
,
type
=
int
,
help
=
"port of the app"
,
default
=
8010
,
required
=
False
)
"--port"
,
action
=
"store"
,
type
=
int
,
help
=
"port of the app"
,
default
=
8010
,
required
=
False
)
args
=
parser
.
parse_args
()
port
=
args
.
port
...
...
@@ -60,39 +60,41 @@ ie_model_path = "source/model"
UPLOAD_PATH
=
"source/vpr"
WAV_PATH
=
"source/wav"
base_sources
=
[
UPLOAD_PATH
,
WAV_PATH
]
base_sources
=
[
UPLOAD_PATH
,
WAV_PATH
]
for
path
in
base_sources
:
os
.
makedirs
(
path
,
exist_ok
=
True
)
# 初始化
app
=
FastAPI
()
chatbot
=
Robot
(
asr_config
,
tts_config
,
asr_init_path
,
ie_model_path
=
ie_model_path
)
chatbot
=
Robot
(
asr_config
,
tts_config
,
asr_init_path
,
ie_model_path
=
ie_model_path
)
manager
=
ConnectionManager
()
aumanager
=
AudioMannger
(
chatbot
)
aumanager
.
init
()
vpr
=
VPR
(
db_path
,
dim
=
192
,
top_k
=
5
)
vpr
=
VPR
(
db_path
,
dim
=
192
,
top_k
=
5
)
# 服务配置
class
NlpBase
(
BaseModel
):
chat
:
str
class
TtsBase
(
BaseModel
):
text
:
str
text
:
str
class
Audios
:
def
__init__
(
self
)
->
None
:
self
.
audios
=
b
""
audios
=
Audios
()
######################################################################
########################### ASR 服务 #################################
#####################################################################
# 接收文件,返回ASR结果
# 上传文件
@
app
.
post
(
"/asr/offline"
)
...
...
@@ -101,7 +103,8 @@ async def speech2textOffline(files: List[UploadFile]):
asr_res
=
""
for
file
in
files
[:
1
]:
# 生成时间戳
now_name
=
"asr_offline_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
now_name
=
"asr_offline_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
out_file_path
=
os
.
path
.
join
(
WAV_PATH
,
now_name
)
async
with
aiofiles
.
open
(
out_file_path
,
'wb'
)
as
out_file
:
content
=
await
file
.
read
()
# async read
...
...
@@ -110,10 +113,9 @@ async def speech2textOffline(files: List[UploadFile]):
# 返回ASR识别结果
asr_res
=
chatbot
.
speech2text
(
out_file_path
)
return
SuccessRequest
(
result
=
asr_res
)
# else:
# return ErrorRequest(message="文件不是.wav格式")
return
ErrorRequest
(
message
=
"上传文件为空"
)
# 接收文件,同时将wav强制转成16k, int16类型
@
app
.
post
(
"/asr/offlinefile"
)
async
def
speech2textOfflineFile
(
files
:
List
[
UploadFile
]):
...
...
@@ -121,7 +123,8 @@ async def speech2textOfflineFile(files: List[UploadFile]):
asr_res
=
""
for
file
in
files
[:
1
]:
# 生成时间戳
now_name
=
"asr_offline_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
now_name
=
"asr_offline_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
out_file_path
=
os
.
path
.
join
(
WAV_PATH
,
now_name
)
async
with
aiofiles
.
open
(
out_file_path
,
'wb'
)
as
out_file
:
content
=
await
file
.
read
()
# async read
...
...
@@ -132,22 +135,18 @@ async def speech2textOfflineFile(files: List[UploadFile]):
wav
=
float2pcm
(
wav
)
# float32 to int16
wav_bytes
=
wav
.
tobytes
()
# to bytes
wav_base64
=
base64
.
b64encode
(
wav_bytes
).
decode
(
'utf8'
)
# 将文件重新写入
now_name
=
now_name
[:
-
4
]
+
"_16k"
+
".wav"
out_file_path
=
os
.
path
.
join
(
WAV_PATH
,
now_name
)
sf
.
write
(
out_file_path
,
wav
,
16000
)
sf
.
write
(
out_file_path
,
wav
,
16000
)
# 返回ASR识别结果
asr_res
=
chatbot
.
speech2text
(
out_file_path
)
response_res
=
{
"asr_result"
:
asr_res
,
"wav_base64"
:
wav_base64
}
response_res
=
{
"asr_result"
:
asr_res
,
"wav_base64"
:
wav_base64
}
return
SuccessRequest
(
result
=
response_res
)
return
ErrorRequest
(
message
=
"上传文件为空"
)
return
ErrorRequest
(
message
=
"上传文件为空"
)
# 流式接收测试
...
...
@@ -161,15 +160,17 @@ async def speech2textOnlineRecive(files: List[UploadFile]):
print
(
f
"audios长度变化:
{
len
(
audios
.
audios
)
}
"
)
return
SuccessRequest
(
message
=
"接收成功"
)
# 采集环境噪音大小
@
app
.
post
(
"/asr/collectEnv"
)
async
def
collectEnv
(
files
:
List
[
UploadFile
]):
for
file
in
files
[:
1
]:
for
file
in
files
[:
1
]:
content
=
await
file
.
read
()
# async read
# 初始化, wav 前44字节是头部信息
aumanager
.
compute_env_volume
(
content
[
44
:])
vad_
=
aumanager
.
vad_threshold
return
SuccessRequest
(
result
=
vad_
,
message
=
"采集环境噪音成功"
)
return
SuccessRequest
(
result
=
vad_
,
message
=
"采集环境噪音成功"
)
# 停止录音
@
app
.
get
(
"/asr/stopRecord"
)
...
...
@@ -179,6 +180,7 @@ async def stopRecord():
print
(
"Online录音暂停"
)
return
SuccessRequest
(
message
=
"停止成功"
)
# 恢复录音
@
app
.
get
(
"/asr/resumeRecord"
)
async
def
resumeRecord
():
...
...
@@ -187,7 +189,7 @@ async def resumeRecord():
return
SuccessRequest
(
message
=
"Online录音恢复"
)
# 聊天用的ASR
# 聊天用的
ASR
@
app
.
websocket
(
"/ws/asr/offlineStream"
)
async
def
websocket_endpoint
(
websocket
:
WebSocket
):
await
manager
.
connect
(
websocket
)
...
...
@@ -210,9 +212,9 @@ async def websocket_endpoint(websocket: WebSocket):
# print(f"用户-{user}-离开")
# Online识别的
ASR
# 流式识别的
ASR
@
app
.
websocket
(
'/ws/asr/onlineStream'
)
async
def
websocket_endpoint
(
websocket
:
WebSocket
):
async
def
websocket_endpoint
_online
(
websocket
:
WebSocket
):
"""PaddleSpeech Online ASR Server api
Args:
...
...
@@ -298,12 +300,14 @@ async def websocket_endpoint(websocket: WebSocket):
except
WebSocketDisconnect
:
pass
######################################################################
########################### NLP 服务 #################################
#####################################################################
@
app
.
post
(
"/nlp/chat"
)
async
def
chatOffline
(
nlp_base
:
NlpBase
):
async
def
chatOffline
(
nlp_base
:
NlpBase
):
chat
=
nlp_base
.
chat
if
not
chat
:
return
ErrorRequest
(
message
=
"传入文本为空"
)
...
...
@@ -311,8 +315,9 @@ async def chatOffline(nlp_base:NlpBase):
res
=
chatbot
.
chat
(
chat
)
return
SuccessRequest
(
result
=
res
)
@
app
.
post
(
"/nlp/ie"
)
async
def
ieOffline
(
nlp_base
:
NlpBase
):
async
def
ieOffline
(
nlp_base
:
NlpBase
):
nlp_text
=
nlp_base
.
chat
if
not
nlp_text
:
return
ErrorRequest
(
message
=
"传入文本为空"
)
...
...
@@ -320,17 +325,20 @@ async def ieOffline(nlp_base:NlpBase):
res
=
chatbot
.
ie
(
nlp_text
)
return
SuccessRequest
(
result
=
res
)
######################################################################
########################### TTS 服务 #################################
#####################################################################
@
app
.
post
(
"/tts/offline"
)
async
def
text2speechOffline
(
tts_base
:
TtsBase
):
async
def
text2speechOffline
(
tts_base
:
TtsBase
):
text
=
tts_base
.
text
if
not
text
:
return
ErrorRequest
(
message
=
"文本为空"
)
else
:
now_name
=
"tts_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
now_name
=
"tts_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
out_file_path
=
os
.
path
.
join
(
WAV_PATH
,
now_name
)
# 保存为文件,再转成base64传输
chatbot
.
text2speech
(
text
,
outpath
=
out_file_path
)
...
...
@@ -339,12 +347,14 @@ async def text2speechOffline(tts_base:TtsBase):
base_str
=
base64
.
b64encode
(
data_bin
)
return
SuccessRequest
(
result
=
base_str
)
# http流式TTS
@
app
.
post
(
"/tts/online"
)
async
def
stream_tts
(
request_body
:
TtsBase
):
text
=
request_body
.
text
return
StreamingResponse
(
chatbot
.
text2speechStreamBytes
(
text
=
text
))
# ws流式TTS
@
app
.
websocket
(
"/ws/tts/online"
)
async
def
stream_ttsWS
(
websocket
:
WebSocket
):
...
...
@@ -356,17 +366,11 @@ async def stream_ttsWS(websocket: WebSocket):
if
text
:
for
sub_wav
in
chatbot
.
text2speechStream
(
text
=
text
):
# print("发送sub wav: ", len(sub_wav))
res
=
{
"wav"
:
sub_wav
,
"done"
:
False
}
res
=
{
"wav"
:
sub_wav
,
"done"
:
False
}
await
websocket
.
send_json
(
res
)
# 输送结束
res
=
{
"wav"
:
sub_wav
,
"done"
:
True
}
res
=
{
"wav"
:
sub_wav
,
"done"
:
True
}
await
websocket
.
send_json
(
res
)
# manager.disconnect(websocket)
...
...
@@ -396,8 +400,9 @@ async def vpr_enroll(table_name: str=None,
return
{
'status'
:
False
,
'msg'
:
"spk_id can not be None"
}
# Save the upload data to server.
content
=
await
audio
.
read
()
now_name
=
"vpr_enroll_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
audio_path
=
os
.
path
.
join
(
UPLOAD_PATH
,
now_name
)
now_name
=
"vpr_enroll_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
audio_path
=
os
.
path
.
join
(
UPLOAD_PATH
,
now_name
)
with
open
(
audio_path
,
"wb+"
)
as
f
:
f
.
write
(
content
)
...
...
@@ -413,20 +418,19 @@ async def vpr_recog(request: Request,
audio
:
UploadFile
=
File
(...)):
# Voice print recognition online
# try:
# Save the upload data to server.
# Save the upload data to server.
content
=
await
audio
.
read
()
now_name
=
"vpr_query_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
query_audio_path
=
os
.
path
.
join
(
UPLOAD_PATH
,
now_name
)
now_name
=
"vpr_query_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
query_audio_path
=
os
.
path
.
join
(
UPLOAD_PATH
,
now_name
)
with
open
(
query_audio_path
,
"wb+"
)
as
f
:
f
.
write
(
content
)
f
.
write
(
content
)
spk_ids
,
paths
,
scores
=
vpr
.
do_search_vpr
(
query_audio_path
)
res
=
dict
(
zip
(
spk_ids
,
zip
(
paths
,
scores
)))
# Sort results by distance metric, closest distances first
res
=
sorted
(
res
.
items
(),
key
=
lambda
item
:
item
[
1
][
1
],
reverse
=
True
)
return
res
# except Exception as e:
# return {'status': False, 'msg': e}, 400
@
app
.
post
(
'/vpr/del'
)
...
...
@@ -460,17 +464,18 @@ async def vpr_database64(vprId: int):
return
{
'status'
:
False
,
'msg'
:
"vpr_id can not be None"
}
audio_path
=
vpr
.
do_get_wav
(
vprId
)
# 返回base64
# 将文件转成16k, 16bit类型的wav文件
wav
,
sr
=
librosa
.
load
(
audio_path
,
sr
=
16000
)
wav
=
float2pcm
(
wav
)
# float32 to int16
wav_bytes
=
wav
.
tobytes
()
# to bytes
wav_base64
=
base64
.
b64encode
(
wav_bytes
).
decode
(
'utf8'
)
return
SuccessRequest
(
result
=
wav_base64
)
except
Exception
as
e
:
return
{
'status'
:
False
,
'msg'
:
e
},
400
@
app
.
get
(
'/vpr/data'
)
async
def
vpr_data
(
vprId
:
int
):
# Get the audio file from path by spk_id in MySQL
...
...
@@ -482,11 +487,6 @@ async def vpr_data(vprId: int):
except
Exception
as
e
:
return
{
'status'
:
False
,
'msg'
:
e
},
400
if
__name__
==
'__main__'
:
uvicorn
.
run
(
app
=
app
,
host
=
'0.0.0.0'
,
port
=
port
)
demos/speech_web/speech_server/requirements.txt
浏览文件 @
795eb7bd
aiofiles
faiss-cpu
fastapi
librosa
numpy
paddlenlp
paddlepaddle
paddlespeech
pydantic
scikit_learn
python-multipart
scikit_learn
SoundFile
starlette
uvicorn
paddlepaddle
paddlespeech
paddlenlp
faiss-cpu
python-multipart
\ No newline at end of file
demos/speech_web/speech_server/src/AudioManeger.py
浏览文件 @
795eb7bd
import
imp
from
queue
import
Queue
import
numpy
as
np
import
datetime
import
os
import
wave
import
random
import
datetime
import
numpy
as
np
from
.util
import
randName
class
AudioMannger
:
def
__init__
(
self
,
robot
,
frame_length
=
160
,
frame
=
10
,
data_width
=
2
,
vad_default
=
300
):
def
__init__
(
self
,
robot
,
frame_length
=
160
,
frame
=
10
,
data_width
=
2
,
vad_default
=
300
):
# 二进制 pcm 流
self
.
audios
=
b
''
self
.
asr_result
=
""
...
...
@@ -20,8 +24,9 @@ class AudioMannger:
os
.
makedirs
(
self
.
file_dir
,
exist_ok
=
True
)
self
.
vad_deafult
=
vad_default
self
.
vad_threshold
=
vad_default
self
.
vad_threshold_path
=
os
.
path
.
join
(
self
.
file_dir
,
"vad_threshold.npy"
)
self
.
vad_threshold_path
=
os
.
path
.
join
(
self
.
file_dir
,
"vad_threshold.npy"
)
# 10ms 一帧
self
.
frame_length
=
frame_length
# 10帧,检测一次 vad
...
...
@@ -30,67 +35,64 @@ class AudioMannger:
self
.
data_width
=
data_width
# window
self
.
window_length
=
frame_length
*
frame
*
data_width
# 是否开始录音
self
.
on_asr
=
False
self
.
silence_cnt
=
0
self
.
silence_cnt
=
0
self
.
max_silence_cnt
=
4
self
.
is_pause
=
False
# 录音暂停与恢复
def
init
(
self
):
if
os
.
path
.
exists
(
self
.
vad_threshold_path
):
# 平均响度文件存在
self
.
vad_threshold
=
np
.
load
(
self
.
vad_threshold_path
)
def
clear_audio
(
self
):
# 清空 pcm 累积片段与 asr 识别结果
self
.
audios
=
b
''
def
clear_asr
(
self
):
self
.
asr_result
=
""
def
compute_chunk_volume
(
self
,
start_index
,
pcm_bins
):
# 根据帧长计算能量平均值
pcm_bin
=
pcm_bins
[
start_index
:
start_index
+
self
.
window_length
]
pcm_bin
=
pcm_bins
[
start_index
:
start_index
+
self
.
window_length
]
# 转成 numpy
pcm_np
=
np
.
frombuffer
(
pcm_bin
,
np
.
int16
)
# 归一化 + 计算响度
x
=
pcm_np
.
astype
(
np
.
float32
)
x
=
np
.
abs
(
x
)
return
np
.
mean
(
x
)
return
np
.
mean
(
x
)
def
is_speech
(
self
,
start_index
,
pcm_bins
):
# 检查是否没
if
start_index
>
len
(
pcm_bins
):
return
False
# 检查从这个 start 开始是否为静音帧
energy
=
self
.
compute_chunk_volume
(
start_index
=
start_index
,
pcm_bins
=
pcm_bins
)
energy
=
self
.
compute_chunk_volume
(
start_index
=
start_index
,
pcm_bins
=
pcm_bins
)
# print(energy)
if
energy
>
self
.
vad_threshold
:
return
True
else
:
return
False
def
compute_env_volume
(
self
,
pcm_bins
):
max_energy
=
0
start
=
0
while
start
<
len
(
pcm_bins
):
energy
=
self
.
compute_chunk_volume
(
start_index
=
start
,
pcm_bins
=
pcm_bins
)
energy
=
self
.
compute_chunk_volume
(
start_index
=
start
,
pcm_bins
=
pcm_bins
)
if
energy
>
max_energy
:
max_energy
=
energy
start
+=
self
.
window_length
self
.
vad_threshold
=
max_energy
+
100
if
max_energy
>
self
.
vad_deafult
else
self
.
vad_deafult
# 保存成文件
np
.
save
(
self
.
vad_threshold_path
,
self
.
vad_threshold
)
print
(
f
"vad 阈值大小:
{
self
.
vad_threshold
}
"
)
print
(
f
"环境采样保存:
{
os
.
path
.
realpath
(
self
.
vad_threshold_path
)
}
"
)
def
stream_asr
(
self
,
pcm_bin
):
# 先把 pcm_bin 送进去做端点检测
start
=
0
...
...
@@ -99,7 +101,7 @@ class AudioMannger:
self
.
on_asr
=
True
self
.
silence_cnt
=
0
print
(
"录音中"
)
self
.
audios
+=
pcm_bin
[
start
:
start
+
self
.
window_length
]
self
.
audios
+=
pcm_bin
[
start
:
start
+
self
.
window_length
]
else
:
if
self
.
on_asr
:
self
.
silence_cnt
+=
1
...
...
@@ -110,41 +112,42 @@ class AudioMannger:
print
(
"录音停止"
)
# audios 保存为 wav, 送入 ASR
if
len
(
self
.
audios
)
>
2
*
16000
:
file_path
=
os
.
path
.
join
(
self
.
file_dir
,
"asr_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
)
file_path
=
os
.
path
.
join
(
self
.
file_dir
,
"asr_"
+
datetime
.
datetime
.
strftime
(
datetime
.
datetime
.
now
(),
'%Y%m%d%H%M%S'
)
+
randName
()
+
".wav"
)
self
.
save_audio
(
file_path
=
file_path
)
self
.
asr_result
=
self
.
robot
.
speech2text
(
file_path
)
self
.
clear_audio
()
return
self
.
asr_result
return
self
.
asr_result
else
:
# 正常接收
print
(
"录音中 静音"
)
self
.
audios
+=
pcm_bin
[
start
:
start
+
self
.
window_length
]
self
.
audios
+=
pcm_bin
[
start
:
start
+
self
.
window_length
]
start
+=
self
.
window_length
return
""
def
save_audio
(
self
,
file_path
):
print
(
"保存音频"
)
wf
=
wave
.
open
(
file_path
,
'wb'
)
# 创建一个音频文件,名字为“01.wav"
wf
.
setnchannels
(
1
)
# 设置声道数为2
wf
.
setsampwidth
(
2
)
# 设置采样深度为
wf
.
setframerate
(
16000
)
# 设置采样率为16000
wf
=
wave
.
open
(
file_path
,
'wb'
)
# 创建一个音频文件,名字为“01.wav"
wf
.
setnchannels
(
1
)
# 设置声道数为2
wf
.
setsampwidth
(
2
)
# 设置采样深度为
wf
.
setframerate
(
16000
)
# 设置采样率为16000
# 将数据写入创建的音频文件
wf
.
writeframes
(
self
.
audios
)
# 写完后将文件关闭
wf
.
close
()
def
end
(
self
):
# audios 保存为 wav, 送入 ASR
file_path
=
os
.
path
.
join
(
self
.
file_dir
,
"asr.wav"
)
self
.
save_audio
(
file_path
=
file_path
)
return
self
.
robot
.
speech2text
(
file_path
)
def
stop
(
self
):
self
.
is_pause
=
True
self
.
audios
=
b
''
def
resume
(
self
):
self
.
is_pause
=
False
\ No newline at end of file
demos/speech_web/speech_server/src/SpeechBase/asr.py
浏览文件 @
795eb7bd
from
re
import
sub
import
numpy
as
np
import
paddle
import
librosa
import
soundfile
from
paddlespeech.server.engine.asr.online.python.asr_engine
import
ASREngine
from
paddlespeech.server.engine.asr.online.python.asr_engine
import
PaddleASRConnectionHanddler
from
paddlespeech.server.utils.config
import
get_config
def
readWave
(
samples
):
x_len
=
len
(
samples
)
...
...
@@ -31,20 +28,23 @@ def readWave(samples):
class
ASR
:
def
__init__
(
self
,
config_path
,
)
->
None
:
def
__init__
(
self
,
config_path
,
)
->
None
:
self
.
config
=
get_config
(
config_path
)[
'asr_online'
]
self
.
engine
=
ASREngine
()
self
.
engine
.
init
(
self
.
config
)
self
.
connection_handler
=
PaddleASRConnectionHanddler
(
self
.
engine
)
def
offlineASR
(
self
,
samples
,
sample_rate
=
16000
):
x_chunk
,
x_chunk_lens
=
self
.
engine
.
preprocess
(
samples
=
samples
,
sample_rate
=
sample_rate
)
x_chunk
,
x_chunk_lens
=
self
.
engine
.
preprocess
(
samples
=
samples
,
sample_rate
=
sample_rate
)
self
.
engine
.
run
(
x_chunk
,
x_chunk_lens
)
result
=
self
.
engine
.
postprocess
()
self
.
engine
.
reset
()
return
result
def
onlineASR
(
self
,
samples
:
bytes
=
None
,
is_finished
=
False
):
def
onlineASR
(
self
,
samples
:
bytes
=
None
,
is_finished
=
False
):
if
not
is_finished
:
# 流式开始
self
.
connection_handler
.
extract_feat
(
samples
)
...
...
@@ -58,5 +58,3 @@ class ASR:
asr_results
=
self
.
connection_handler
.
get_result
()
self
.
connection_handler
.
reset
()
return
asr_results
\ No newline at end of file
demos/speech_web/speech_server/src/SpeechBase/nlp.py
浏览文件 @
795eb7bd
from
paddlenlp
import
Taskflow
class
NLP
:
def
__init__
(
self
,
ie_model_path
=
None
):
schema
=
[
"时间"
,
"出发地"
,
"目的地"
,
"费用"
]
if
ie_model_path
:
self
.
ie_model
=
Taskflow
(
"information_extraction"
,
schema
=
schema
,
task_path
=
ie_model_path
)
self
.
ie_model
=
Taskflow
(
"information_extraction"
,
schema
=
schema
,
task_path
=
ie_model_path
)
else
:
self
.
ie_model
=
Taskflow
(
"information_extraction"
,
schema
=
schema
)
self
.
ie_model
=
Taskflow
(
"information_extraction"
,
schema
=
schema
)
self
.
dialogue_model
=
Taskflow
(
"dialogue"
)
def
chat
(
self
,
text
):
result
=
self
.
dialogue_model
([
text
])
return
result
[
0
]
def
ie
(
self
,
text
):
result
=
self
.
ie_model
(
text
)
return
result
\ No newline at end of file
demos/speech_web/speech_server/src/SpeechBase/sql_helper.py
浏览文件 @
795eb7bd
import
base64
import
sqlite3
import
os
import
sqlite3
import
numpy
as
np
from
pkg_resources
import
resource_stream
def
dict_factory
(
cursor
,
row
):
d
=
{}
for
idx
,
col
in
enumerate
(
cursor
.
description
):
d
[
col
[
0
]]
=
row
[
idx
]
return
d
def
dict_factory
(
cursor
,
row
):
d
=
{}
for
idx
,
col
in
enumerate
(
cursor
.
description
):
d
[
col
[
0
]]
=
row
[
idx
]
return
d
class
DataBase
(
object
):
def
__init__
(
self
,
db_path
:
str
):
def
__init__
(
self
,
db_path
:
str
):
db_path
=
os
.
path
.
realpath
(
db_path
)
if
os
.
path
.
exists
(
db_path
):
...
...
@@ -21,12 +22,12 @@ class DataBase(object):
db_path_dir
=
os
.
path
.
dirname
(
db_path
)
os
.
makedirs
(
db_path_dir
,
exist_ok
=
True
)
self
.
db_path
=
db_path
self
.
conn
=
sqlite3
.
connect
(
self
.
db_path
)
self
.
conn
.
row_factory
=
dict_factory
self
.
cursor
=
self
.
conn
.
cursor
()
self
.
init_database
()
def
init_database
(
self
):
"""
初始化数据库, 若表不存在则创建
...
...
@@ -41,20 +42,21 @@ class DataBase(object):
"""
self
.
cursor
.
execute
(
sql
)
self
.
conn
.
commit
()
def
execute_base
(
self
,
sql
,
data_dict
):
self
.
cursor
.
execute
(
sql
,
data_dict
)
self
.
conn
.
commit
()
def
insert_one
(
self
,
username
,
vector_base64
:
str
,
wav_path
):
def
insert_one
(
self
,
username
,
vector_base64
:
str
,
wav_path
):
if
not
os
.
path
.
exists
(
wav_path
):
return
None
,
"wav not exists"
else
:
sql
=
f
"""
sql
=
"""
insert into
vprtable (username, vector, wavpath)
values (?, ?, ?)
"""
try
:
self
.
cursor
.
execute
(
sql
,
(
username
,
vector_base64
,
wav_path
))
self
.
conn
.
commit
()
...
...
@@ -63,25 +65,27 @@ class DataBase(object):
except
Exception
as
e
:
print
(
e
)
return
None
,
e
def
select_all
(
self
):
sql
=
"""
SELECT * from vprtable
"""
result
=
self
.
cursor
.
execute
(
sql
).
fetchall
()
return
result
def
select_by_id
(
self
,
vpr_id
):
sql
=
f
"""
SELECT * from vprtable WHERE `id` =
{
vpr_id
}
"""
result
=
self
.
cursor
.
execute
(
sql
).
fetchall
()
return
result
def
select_by_username
(
self
,
username
):
sql
=
f
"""
SELECT * from vprtable WHERE `username` = '
{
username
}
'
"""
result
=
self
.
cursor
.
execute
(
sql
).
fetchall
()
return
result
...
...
@@ -89,28 +93,30 @@ class DataBase(object):
sql
=
f
"""
DELETE from vprtable WHERE `username`='
{
username
}
'
"""
self
.
cursor
.
execute
(
sql
)
self
.
conn
.
commit
()
def
drop_all
(
self
):
sql
=
f
"""
sql
=
"""
DELETE from vprtable
"""
self
.
cursor
.
execute
(
sql
)
self
.
conn
.
commit
()
def
drop_table
(
self
):
sql
=
f
"""
sql
=
"""
DROP TABLE vprtable
"""
self
.
cursor
.
execute
(
sql
)
self
.
conn
.
commit
()
def
encode_vector
(
self
,
vector
:
np
.
ndarray
):
def
encode_vector
(
self
,
vector
:
np
.
ndarray
):
return
base64
.
b64encode
(
vector
).
decode
(
'utf8'
)
def
decode_vector
(
self
,
vector_base64
,
dtype
=
np
.
float32
):
b
=
base64
.
b64decode
(
vector_base64
)
vc
=
np
.
frombuffer
(
b
,
dtype
=
dtype
)
return
vc
\ No newline at end of file
demos/speech_web/speech_server/src/SpeechBase/tts.py
浏览文件 @
795eb7bd
...
...
@@ -5,18 +5,19 @@
# 2. 加载模型
# 3. 端到端推理
# 4. 流式推理
import
base64
import
math
import
logging
import
math
import
numpy
as
np
from
paddlespeech.server.utils.onnx_infer
import
get_sess
from
paddlespeech.t2s.frontend.zh_frontend
import
Frontend
from
paddlespeech.server.utils.util
import
denorm
,
get_chunks
from
paddlespeech.server.engine.tts.online.onnx.tts_engine
import
TTSEngine
from
paddlespeech.server.utils.audio_process
import
float2pcm
from
paddlespeech.server.utils.config
import
get_config
from
paddlespeech.server.utils.util
import
denorm
from
paddlespeech.server.utils.util
import
get_chunks
from
paddlespeech.t2s.frontend.zh_frontend
import
Frontend
from
paddlespeech.server.engine.tts.online.onnx.tts_engine
import
TTSEngine
class
TTS
:
def
__init__
(
self
,
config_path
):
...
...
@@ -26,12 +27,12 @@ class TTS:
self
.
engine
.
init
(
self
.
config
)
self
.
executor
=
self
.
engine
.
executor
#self.engine.warm_up()
# 前端初始化
self
.
frontend
=
Frontend
(
phone_vocab_path
=
self
.
engine
.
executor
.
phones_dict
,
tone_vocab_path
=
None
)
phone_vocab_path
=
self
.
engine
.
executor
.
phones_dict
,
tone_vocab_path
=
None
)
def
depadding
(
self
,
data
,
chunk_num
,
chunk_id
,
block
,
pad
,
upsample
):
"""
Streaming inference removes the result of pad inference
...
...
@@ -48,39 +49,37 @@ class TTS:
data
=
data
[
front_pad
*
upsample
:(
front_pad
+
block
)
*
upsample
]
return
data
def
offlineTTS
(
self
,
text
):
get_tone_ids
=
False
merge_sentences
=
False
input_ids
=
self
.
frontend
.
get_input_ids
(
text
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
)
text
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
)
phone_ids
=
input_ids
[
"phone_ids"
]
wav_list
=
[]
for
i
in
range
(
len
(
phone_ids
)):
orig_hs
=
self
.
engine
.
executor
.
am_encoder_infer_sess
.
run
(
None
,
input_feed
=
{
'text'
:
phone_ids
[
i
].
numpy
()}
)
None
,
input_feed
=
{
'text'
:
phone_ids
[
i
].
numpy
()})
hs
=
orig_hs
[
0
]
am_decoder_output
=
self
.
engine
.
executor
.
am_decoder_sess
.
run
(
None
,
input_feed
=
{
'xs'
:
hs
})
None
,
input_feed
=
{
'xs'
:
hs
})
am_postnet_output
=
self
.
engine
.
executor
.
am_postnet_sess
.
run
(
None
,
input_feed
=
{
'xs'
:
np
.
transpose
(
am_decoder_output
[
0
],
(
0
,
2
,
1
))
})
None
,
input_feed
=
{
'xs'
:
np
.
transpose
(
am_decoder_output
[
0
],
(
0
,
2
,
1
))
})
am_output_data
=
am_decoder_output
+
np
.
transpose
(
am_postnet_output
[
0
],
(
0
,
2
,
1
))
normalized_mel
=
am_output_data
[
0
][
0
]
mel
=
denorm
(
normalized_mel
,
self
.
engine
.
executor
.
am_mu
,
self
.
engine
.
executor
.
am_std
)
mel
=
denorm
(
normalized_mel
,
self
.
engine
.
executor
.
am_mu
,
self
.
engine
.
executor
.
am_std
)
wav
=
self
.
engine
.
executor
.
voc_sess
.
run
(
output_names
=
None
,
input_feed
=
{
'logmel'
:
mel
})[
0
]
output_names
=
None
,
input_feed
=
{
'logmel'
:
mel
})[
0
]
wav_list
.
append
(
wav
)
wavs
=
np
.
concatenate
(
wav_list
)
return
wavs
def
streamTTS
(
self
,
text
):
get_tone_ids
=
False
...
...
@@ -88,9 +87,7 @@ class TTS:
# front
input_ids
=
self
.
frontend
.
get_input_ids
(
text
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
)
text
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
)
phone_ids
=
input_ids
[
"phone_ids"
]
for
i
in
range
(
len
(
phone_ids
)):
...
...
@@ -105,14 +102,15 @@ class TTS:
mel
=
mel
[
0
]
# voc streaming
mel_chunks
=
get_chunks
(
mel
,
self
.
config
.
voc_block
,
self
.
config
.
voc_pad
,
"voc"
)
mel_chunks
=
get_chunks
(
mel
,
self
.
config
.
voc_block
,
self
.
config
.
voc_pad
,
"voc"
)
voc_chunk_num
=
len
(
mel_chunks
)
for
i
,
mel_chunk
in
enumerate
(
mel_chunks
):
sub_wav
=
self
.
executor
.
voc_sess
.
run
(
output_names
=
None
,
input_feed
=
{
'logmel'
:
mel_chunk
})
sub_wav
=
self
.
depadding
(
sub_wav
[
0
],
voc_chunk_num
,
i
,
self
.
config
.
voc_block
,
self
.
config
.
voc_pad
,
self
.
config
.
voc_upsample
)
sub_wav
=
self
.
depadding
(
sub_wav
[
0
],
voc_chunk_num
,
i
,
self
.
config
.
voc_block
,
self
.
config
.
voc_pad
,
self
.
config
.
voc_upsample
)
yield
self
.
after_process
(
sub_wav
)
...
...
@@ -130,7 +128,8 @@ class TTS:
end
=
min
(
self
.
config
.
voc_block
+
self
.
config
.
voc_pad
,
mel_len
)
# streaming am
hss
=
get_chunks
(
orig_hs
,
self
.
config
.
am_block
,
self
.
config
.
am_pad
,
"am"
)
hss
=
get_chunks
(
orig_hs
,
self
.
config
.
am_block
,
self
.
config
.
am_pad
,
"am"
)
am_chunk_num
=
len
(
hss
)
for
i
,
hs
in
enumerate
(
hss
):
am_decoder_output
=
self
.
executor
.
am_decoder_sess
.
run
(
...
...
@@ -147,7 +146,8 @@ class TTS:
sub_mel
=
denorm
(
normalized_mel
,
self
.
executor
.
am_mu
,
self
.
executor
.
am_std
)
sub_mel
=
self
.
depadding
(
sub_mel
,
am_chunk_num
,
i
,
self
.
config
.
am_block
,
self
.
config
.
am_pad
,
1
)
self
.
config
.
am_block
,
self
.
config
.
am_pad
,
1
)
if
i
==
0
:
mel_streaming
=
sub_mel
...
...
@@ -165,23 +165,22 @@ class TTS:
output_names
=
None
,
input_feed
=
{
'logmel'
:
voc_chunk
})
sub_wav
=
self
.
depadding
(
sub_wav
[
0
],
voc_chunk_num
,
voc_chunk_id
,
self
.
config
.
voc_block
,
self
.
config
.
voc_pad
,
self
.
config
.
voc_upsample
)
self
.
config
.
voc_block
,
self
.
config
.
voc_pad
,
self
.
config
.
voc_upsample
)
yield
self
.
after_process
(
sub_wav
)
voc_chunk_id
+=
1
start
=
max
(
0
,
voc_chunk_id
*
self
.
config
.
voc_block
-
self
.
config
.
voc_pad
)
end
=
min
(
(
voc_chunk_id
+
1
)
*
self
.
config
.
voc_block
+
self
.
config
.
voc_pad
,
mel_len
)
start
=
max
(
0
,
voc_chunk_id
*
self
.
config
.
voc_block
-
self
.
config
.
voc_pad
)
end
=
min
((
voc_chunk_id
+
1
)
*
self
.
config
.
voc_block
+
self
.
config
.
voc_pad
,
mel_len
)
else
:
logging
.
error
(
"Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts."
)
)
def
streamTTSBytes
(
self
,
text
):
for
wav
in
self
.
engine
.
executor
.
infer
(
text
=
text
,
...
...
@@ -191,19 +190,14 @@ class TTS:
wav
=
float2pcm
(
wav
)
# float32 to int16
wav_bytes
=
wav
.
tobytes
()
# to bytes
yield
wav_bytes
def
after_process
(
self
,
wav
):
# for tvm
wav
=
float2pcm
(
wav
)
# float32 to int16
wav_bytes
=
wav
.
tobytes
()
# to bytes
wav_base64
=
base64
.
b64encode
(
wav_bytes
).
decode
(
'utf8'
)
# to base64
return
wav_base64
def
streamTTS_TVM
(
self
,
text
):
# 用 TVM 优化
pass
\ No newline at end of file
demos/speech_web/speech_server/src/SpeechBase/vpr.py
浏览文件 @
795eb7bd
# vpr Demo 没有使用 mysql 与 muilvs, 仅用于docker演示
import
logging
import
faiss
from
matplotlib
import
use
import
numpy
as
np
from
.sql_helper
import
DataBase
from
.vpr_encode
import
get_audio_embedding
class
VPR
:
def
__init__
(
self
,
db_path
,
dim
,
top_k
)
->
None
:
# 初始化
...
...
@@ -14,15 +16,15 @@ class VPR:
self
.
top_k
=
top_k
self
.
dtype
=
np
.
float32
self
.
vpr_idx
=
0
# db 初始化
self
.
db
=
DataBase
(
db_path
)
# faiss 初始化
index_ip
=
faiss
.
IndexFlatIP
(
dim
)
self
.
index_ip
=
faiss
.
IndexIDMap
(
index_ip
)
self
.
init
()
def
init
(
self
):
# demo 初始化,把 mysql中的向量注册到 faiss 中
sql_dbs
=
self
.
db
.
select_all
()
...
...
@@ -34,12 +36,13 @@ class VPR:
if
len
(
vc
.
shape
)
==
1
:
vc
=
np
.
expand_dims
(
vc
,
axis
=
0
)
# 构建数据库
self
.
index_ip
.
add_with_ids
(
vc
,
np
.
array
((
idx
,)).
astype
(
'int64'
))
self
.
index_ip
.
add_with_ids
(
vc
,
np
.
array
(
(
idx
,
)).
astype
(
'int64'
))
logging
.
info
(
"faiss 构建完毕"
)
def
faiss_enroll
(
self
,
idx
,
vc
):
self
.
index_ip
.
add_with_ids
(
vc
,
np
.
array
((
idx
,)).
astype
(
'int64'
))
self
.
index_ip
.
add_with_ids
(
vc
,
np
.
array
((
idx
,
)).
astype
(
'int64'
))
def
vpr_enroll
(
self
,
username
,
wav_path
):
# 注册声纹
emb
=
get_audio_embedding
(
wav_path
)
...
...
@@ -53,21 +56,22 @@ class VPR:
else
:
last_idx
,
mess
=
None
return
last_idx
def
vpr_recog
(
self
,
wav_path
):
# 识别声纹
emb_search
=
get_audio_embedding
(
wav_path
)
if
emb_search
is
not
None
:
emb_search
=
np
.
expand_dims
(
emb_search
,
axis
=
0
)
D
,
I
=
self
.
index_ip
.
search
(
emb_search
,
self
.
top_k
)
D
=
D
.
tolist
()[
0
]
I
=
I
.
tolist
()[
0
]
return
[(
round
(
D
[
i
]
*
100
,
2
),
I
[
i
])
for
i
in
range
(
len
(
D
))
if
I
[
i
]
!=
-
1
]
I
=
I
.
tolist
()[
0
]
return
[(
round
(
D
[
i
]
*
100
,
2
),
I
[
i
])
for
i
in
range
(
len
(
D
))
if
I
[
i
]
!=
-
1
]
else
:
logging
.
error
(
"识别失败"
)
return
None
def
do_search_vpr
(
self
,
wav_path
):
spk_ids
,
paths
,
scores
=
[],
[],
[]
recog_result
=
self
.
vpr_recog
(
wav_path
)
...
...
@@ -78,41 +82,39 @@ class VPR:
scores
.
append
(
score
)
paths
.
append
(
""
)
return
spk_ids
,
paths
,
scores
def
vpr_del
(
self
,
username
):
# 根据用户username, 删除声纹
# 查用户ID,删除对应向量
res
=
self
.
db
.
select_by_username
(
username
)
for
r
in
res
:
idx
=
r
[
'id'
]
self
.
index_ip
.
remove_ids
(
np
.
array
((
idx
,)).
astype
(
'int64'
))
self
.
index_ip
.
remove_ids
(
np
.
array
((
idx
,
)).
astype
(
'int64'
))
self
.
db
.
drop_by_username
(
username
)
def
vpr_list
(
self
):
# 获取数据列表
return
self
.
db
.
select_all
()
def
do_list
(
self
):
spk_ids
,
vpr_ids
=
[],
[]
for
res
in
self
.
db
.
select_all
():
spk_ids
.
append
(
res
[
'username'
])
vpr_ids
.
append
(
res
[
'id'
])
return
spk_ids
,
vpr_ids
return
spk_ids
,
vpr_ids
def
do_get_wav
(
self
,
vpr_idx
):
res
=
self
.
db
.
select_by_id
(
vpr_idx
)
return
res
[
0
][
'wavpath'
]
res
=
self
.
db
.
select_by_id
(
vpr_idx
)
return
res
[
0
][
'wavpath'
]
def
vpr_data
(
self
,
idx
):
# 获取对应ID的数据
res
=
self
.
db
.
select_by_id
(
idx
)
return
res
def
vpr_droptable
(
self
):
# 删除表
self
.
db
.
drop_table
()
# 清空 faiss
self
.
index_ip
.
reset
()
demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py
浏览文件 @
795eb7bd
from
paddlespeech.cli.vector
import
VectorExecutor
import
numpy
as
np
import
logging
import
numpy
as
np
from
paddlespeech.cli.vector
import
VectorExecutor
vector_executor
=
VectorExecutor
()
def
get_audio_embedding
(
path
):
"""
Use vpr_inference to generate embedding of audio
...
...
@@ -16,5 +19,3 @@ def get_audio_embedding(path):
except
Exception
as
e
:
logging
.
error
(
f
"Error with embedding:
{
e
}
"
)
return
None
\ No newline at end of file
demos/speech_web/speech_server/src/WebsocketManeger.py
浏览文件 @
795eb7bd
...
...
@@ -2,6 +2,7 @@ from typing import List
from
fastapi
import
WebSocket
class
ConnectionManager
:
def
__init__
(
self
):
# 存放激活的ws连接对象
...
...
@@ -28,4 +29,4 @@ class ConnectionManager:
await
connection
.
send_text
(
message
)
manager
=
ConnectionManager
()
\ No newline at end of file
manager
=
ConnectionManager
()
demos/speech_web/speech_server/src/robot.py
浏览文件 @
795eb7bd
from
paddlespeech.cli.asr.infer
import
ASRExecutor
import
soundfile
as
sf
import
os
import
librosa
import
soundfile
as
sf
from
src.SpeechBase.asr
import
ASR
from
src.SpeechBase.tts
import
TTS
from
src.SpeechBase.nlp
import
NLP
from
src.SpeechBase.tts
import
TTS
from
paddlespeech.cli.asr.infer
import
ASRExecutor
class
Robot
:
def
__init__
(
self
,
asr_config
,
tts_config
,
asr_init_path
,
def
__init__
(
self
,
asr_config
,
tts_config
,
asr_init_path
,
ie_model_path
=
None
)
->
None
:
self
.
nlp
=
NLP
(
ie_model_path
=
ie_model_path
)
self
.
asr
=
ASR
(
config_path
=
asr_config
)
self
.
tts
=
TTS
(
config_path
=
tts_config
)
self
.
tts_sample_rate
=
24000
self
.
asr_sample_rate
=
16000
# 流式识别效果不如端到端的模型,这里流式模型与端到端模型分开
self
.
asr_model
=
ASRExecutor
()
self
.
asr_name
=
"conformer_wenetspeech"
self
.
warm_up_asrmodel
(
asr_init_path
)
def
warm_up_asrmodel
(
self
,
asr_init_path
):
def
warm_up_asrmodel
(
self
,
asr_init_path
):
if
not
os
.
path
.
exists
(
asr_init_path
):
path_dir
=
os
.
path
.
dirname
(
asr_init_path
)
if
not
os
.
path
.
exists
(
path_dir
):
os
.
makedirs
(
path_dir
,
exist_ok
=
True
)
# TTS生成,采样率24000
text
=
"生成初始音频"
self
.
text2speech
(
text
,
asr_init_path
)
# asr model初始化
self
.
asr_model
(
asr_init_path
,
model
=
self
.
asr_name
,
lang
=
'zh'
,
sample_rate
=
16000
,
force_yes
=
True
)
self
.
asr_model
(
asr_init_path
,
model
=
self
.
asr_name
,
lang
=
'zh'
,
sample_rate
=
16000
,
force_yes
=
True
)
def
speech2text
(
self
,
audio_file
):
self
.
asr_model
.
preprocess
(
self
.
asr_name
,
audio_file
)
self
.
asr_model
.
infer
(
self
.
asr_name
)
res
=
self
.
asr_model
.
postprocess
()
return
res
def
text2speech
(
self
,
text
,
outpath
):
wav
=
self
.
tts
.
offlineTTS
(
text
)
sf
.
write
(
outpath
,
wav
,
samplerate
=
self
.
tts_sample_rate
)
sf
.
write
(
outpath
,
wav
,
samplerate
=
self
.
tts_sample_rate
)
res
=
wav
return
res
def
text2speechStream
(
self
,
text
):
for
sub_wav_base64
in
self
.
tts
.
streamTTS
(
text
=
text
):
yield
sub_wav_base64
def
text2speechStreamBytes
(
self
,
text
):
for
wav_bytes
in
self
.
tts
.
streamTTSBytes
(
text
=
text
):
yield
wav_bytes
...
...
@@ -66,5 +70,3 @@ class Robot:
def
ie
(
self
,
text
):
result
=
self
.
nlp
.
ie
(
text
)
return
result
\ No newline at end of file
demos/speech_web/speech_server/src/util.py
浏览文件 @
795eb7bd
import
random
def
randName
(
n
=
5
):
return
""
.
join
(
random
.
sample
(
'zyxwvutsrqponmlkjihgfedcba'
,
n
))
return
""
.
join
(
random
.
sample
(
'zyxwvutsrqponmlkjihgfedcba'
,
n
))
def
SuccessRequest
(
result
=
None
,
message
=
"ok"
):
return
{
"code"
:
0
,
"result"
:
result
,
"message"
:
message
}
return
{
"code"
:
0
,
"result"
:
result
,
"message"
:
message
}
def
ErrorRequest
(
result
=
None
,
message
=
"error"
):
return
{
"code"
:
-
1
,
"result"
:
result
,
"message"
:
message
}
\ No newline at end of file
return
{
"code"
:
-
1
,
"result"
:
result
,
"message"
:
message
}
demos/streaming_asr_server/local/rtf_from_log.py
浏览文件 @
795eb7bd
...
...
@@ -34,7 +34,7 @@ if __name__ == '__main__':
n
=
0
for
m
in
rtfs
:
# not accurate, may have duplicate log
n
+=
1
n
+=
1
T
+=
m
[
'T'
]
P
+=
m
[
'P'
]
...
...
docs/requirements.txt
浏览文件 @
795eb7bd
myst-parser
numpydoc
recommonmark>=0.5.0
sphinx
sphinx-autobuild
sphinx-markdown-tables
sphinx_rtd_theme
paddlepaddle>=2.2.2
braceexpandcolorlog
editdistance
fastapi
g2p_en
g2pM
h5py
...
...
@@ -14,40 +8,45 @@ inflect
jieba
jsonlines
kaldiio
keyboard
librosa==0.8.1
loguru
matplotlib
myst-parser
nara_wpe
numpydoc
onnxruntime==1.10.0
opencc
pandas
paddlenlp
paddlepaddle>=2.2.2
paddlespeech_feat
pandas
pathos == 0.2.8
pattern_singleton
Pillow>=9.0.0
praatio==5.0.0
prettytable
pypinyin
pypinyin-dict
python-dateutil
pyworld==0.2.12
recommonmark>=0.5.0
resampy==0.2.2
sacrebleu
scipy
sentencepiece~=0.1.96
soundfile~=0.10
sphinx
sphinx-autobuild
sphinx-markdown-tables
sphinx_rtd_theme
textgrid
timer
tqdm
typeguard
uvicorn
visualdl
webrtcvad
websockets
yacs~=0.1.8
prettytable
zhon
colorlog
pathos == 0.2.8
fastapi
websockets
keyboard
uvicorn
pattern_singleton
braceexpand
\ No newline at end of file
docs/source/conf.py
浏览文件 @
795eb7bd
...
...
@@ -20,10 +20,11 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
import
os
import
sys
import
recommonmark.parser
import
sphinx_rtd_theme
import
sys
import
os
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
'../..'
))
autodoc_mock_imports
=
[
"soundfile"
,
"librosa"
]
...
...
examples/iwslt2012/punc0/local/preprocess.py
浏览文件 @
795eb7bd
import
argparse
import
os
def
process_sentence
(
line
):
if
line
==
''
:
return
''
res
=
line
[
0
]
for
i
in
range
(
1
,
len
(
line
)):
res
+=
(
' '
+
line
[
i
])
return
res
if
line
==
''
:
return
''
res
=
line
[
0
]
for
i
in
range
(
1
,
len
(
line
)):
res
+=
(
' '
+
line
[
i
])
return
res
if
__name__
==
"__main__"
:
paser
=
argparse
.
ArgumentParser
(
description
=
"Input filename"
)
paser
.
add_argument
(
'-input_file'
)
paser
.
add_argument
(
'-output_file'
)
sentence_cnt
=
0
args
=
paser
.
parse_args
()
with
open
(
args
.
input_file
,
'r'
)
as
f
:
with
open
(
args
.
output_file
,
'w'
)
as
write_f
:
while
True
:
line
=
f
.
readline
()
if
line
:
sentence_cnt
+=
1
write_f
.
write
(
process_sentence
(
line
))
else
:
break
print
(
'preprocess over'
)
print
(
'total sentences number:'
,
sentence_cnt
)
paser
=
argparse
.
ArgumentParser
(
description
=
"Input filename"
)
paser
.
add_argument
(
'-input_file'
)
paser
.
add_argument
(
'-output_file'
)
sentence_cnt
=
0
args
=
paser
.
parse_args
()
with
open
(
args
.
input_file
,
'r'
)
as
f
:
with
open
(
args
.
output_file
,
'w'
)
as
write_f
:
while
True
:
line
=
f
.
readline
()
if
line
:
sentence_cnt
+=
1
write_f
.
write
(
process_sentence
(
line
))
else
:
break
print
(
'preprocess over'
)
print
(
'total sentences number:'
,
sentence_cnt
)
examples/other/tts_finetune/tts3/finetune.py
浏览文件 @
795eb7bd
...
...
@@ -17,15 +17,14 @@ from pathlib import Path
from
typing
import
Union
import
yaml
from
paddle
import
distributed
as
dist
from
yacs.config
import
CfgNode
from
paddlespeech.t2s.exps.fastspeech2.train
import
train_sp
from
local.check_oov
import
get_check_result
from
local.extract
import
extract_feature
from
local.label_process
import
get_single_label
from
local.prepare_env
import
generate_finetune_env
from
paddle
import
distributed
as
dist
from
yacs.config
import
CfgNode
from
paddlespeech.t2s.exps.fastspeech2.train
import
train_sp
from
utils.gen_duration_from_textgrid
import
gen_duration_from_textgrid
DICT_EN
=
'tools/aligner/cmudict-0.7b'
...
...
paddlespeech/__init__.py
浏览文件 @
795eb7bd
...
...
@@ -14,5 +14,3 @@
import
_locale
_locale
.
_getdefaultlocale
=
(
lambda
*
args
:
[
'en_US'
,
'utf8'
])
paddlespeech/audio/__init__.py
浏览文件 @
795eb7bd
...
...
@@ -14,12 +14,12 @@
from
.
import
compliance
from
.
import
datasets
from
.
import
features
from
.
import
text
from
.
import
transform
from
.
import
streamdata
from
.
import
functional
from
.
import
io
from
.
import
metric
from
.
import
sox_effects
from
.
import
streamdata
from
.
import
text
from
.
import
transform
from
.backends
import
load
from
.backends
import
save
paddlespeech/audio/streamdata/__init__.py
浏览文件 @
795eb7bd
...
...
@@ -4,67 +4,66 @@
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa
from
.cache
import
(
cached_tarfile_samples
,
cached_tarfile_to_samples
,
lru_cleanup
,
pipe_cleaner
,
)
from
.compat
import
WebDataset
,
WebLoader
,
FluidWrapper
from
.extradatasets
import
MockDataset
,
with_epoch
,
with_length
from
.filters
import
(
associate
,
batched
,
decode
,
detshuffle
,
extract_keys
,
getfirst
,
info
,
map
,
map_dict
,
map_tuple
,
pipelinefilter
,
rename
,
rename_keys
,
audio_resample
,
select
,
shuffle
,
slice
,
to_tuple
,
transform_with
,
unbatched
,
xdecode
,
audio_data_filter
,
audio_tokenize
,
audio_resample
,
audio_compute_fbank
,
audio_spec_aug
,
sort
,
audio_padding
,
audio_cmvn
,
placeholder
,
)
from
.handlers
import
(
ignore_and_continue
,
ignore_and_stop
,
reraise_exception
,
warn_and_continue
,
warn_and_stop
,
)
from
.cache
import
cached_tarfile_samples
from
.cache
import
cached_tarfile_to_samples
from
.cache
import
lru_cleanup
from
.cache
import
pipe_cleaner
from
.compat
import
FluidWrapper
from
.compat
import
WebDataset
from
.compat
import
WebLoader
from
.extradatasets
import
MockDataset
from
.extradatasets
import
with_epoch
from
.extradatasets
import
with_length
from
.filters
import
associate
from
.filters
import
audio_cmvn
from
.filters
import
audio_compute_fbank
from
.filters
import
audio_data_filter
from
.filters
import
audio_padding
from
.filters
import
audio_resample
from
.filters
import
audio_spec_aug
from
.filters
import
audio_tokenize
from
.filters
import
batched
from
.filters
import
decode
from
.filters
import
detshuffle
from
.filters
import
extract_keys
from
.filters
import
getfirst
from
.filters
import
info
from
.filters
import
map
from
.filters
import
map_dict
from
.filters
import
map_tuple
from
.filters
import
pipelinefilter
from
.filters
import
placeholder
from
.filters
import
rename
from
.filters
import
rename_keys
from
.filters
import
select
from
.filters
import
shuffle
from
.filters
import
slice
from
.filters
import
sort
from
.filters
import
to_tuple
from
.filters
import
transform_with
from
.filters
import
unbatched
from
.filters
import
xdecode
from
.handlers
import
ignore_and_continue
from
.handlers
import
ignore_and_stop
from
.handlers
import
reraise_exception
from
.handlers
import
warn_and_continue
from
.handlers
import
warn_and_stop
from
.mix
import
RandomMix
from
.mix
import
RoundRobin
from
.pipeline
import
DataPipeline
from
.shardlists
import
(
MultiShardSample
,
ResampledShards
,
SimpleShardList
,
non_empty
,
resampled
,
shardspec
,
single_node_only
,
split_by_node
,
split_by_worker
,
)
from
.tariterators
import
tarfile_samples
,
tarfile_to_samples
from
.utils
import
PipelineStage
,
repeatedly
from
.writer
import
ShardWriter
,
TarWriter
,
numpy_dumps
from
.mix
import
RandomMix
,
RoundRobin
from
.shardlists
import
MultiShardSample
from
.shardlists
import
non_empty
from
.shardlists
import
resampled
from
.shardlists
import
ResampledShards
from
.shardlists
import
shardspec
from
.shardlists
import
SimpleShardList
from
.shardlists
import
single_node_only
from
.shardlists
import
split_by_node
from
.shardlists
import
split_by_worker
from
.tariterators
import
tarfile_samples
from
.tariterators
import
tarfile_to_samples
from
.utils
import
PipelineStage
from
.utils
import
repeatedly
from
.writer
import
numpy_dumps
from
.writer
import
ShardWriter
from
.writer
import
TarWriter
paddlespeech/audio/streamdata/autodecode.py
浏览文件 @
795eb7bd
...
...
@@ -5,18 +5,19 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Automatically decode webdataset samples."""
import
io
,
json
,
os
,
pickle
,
re
,
tempfile
import
io
import
json
import
os
import
pickle
import
re
import
tempfile
from
functools
import
partial
import
numpy
as
np
"""Extensions passed on to the image decoder."""
image_extensions
=
"jpg jpeg png ppm pgm pbm pnm"
.
split
()
################################################################
# handle basic datatypes
################################################################
...
...
@@ -128,7 +129,7 @@ def call_extension_handler(key, data, f, extensions):
target
=
target
.
split
(
"."
)
if
len
(
target
)
>
len
(
extension
):
continue
if
extension
[
-
len
(
target
)
:]
==
target
:
if
extension
[
-
len
(
target
):]
==
target
:
return
f
(
data
)
return
None
...
...
@@ -268,7 +269,6 @@ def imagehandler(imagespec, extensions=image_extensions):
################################################################
# torch video
################################################################
'''
def torch_video(key, data):
"""Decode video using the torchvideo library.
...
...
@@ -289,7 +289,6 @@ def torch_video(key, data):
return torchvision.io.read_video(fname, pts_unit="sec")
'''
################################################################
# paddlespeech.audio
################################################################
...
...
@@ -359,7 +358,6 @@ def gzfilter(key, data):
# decode entire training amples
################################################################
default_pre_handlers
=
[
gzfilter
]
default_post_handlers
=
[
basichandlers
]
...
...
@@ -387,7 +385,8 @@ class Decoder:
pre
=
default_pre_handlers
if
post
is
None
:
post
=
default_post_handlers
assert
all
(
callable
(
h
)
for
h
in
handlers
),
f
"one of
{
handlers
}
not callable"
assert
all
(
callable
(
h
)
for
h
in
handlers
),
f
"one of
{
handlers
}
not callable"
assert
all
(
callable
(
h
)
for
h
in
pre
),
f
"one of
{
pre
}
not callable"
assert
all
(
callable
(
h
)
for
h
in
post
),
f
"one of
{
post
}
not callable"
self
.
handlers
=
pre
+
handlers
+
post
...
...
paddlespeech/audio/streamdata/cache.py
浏览文件 @
795eb7bd
...
...
@@ -2,7 +2,10 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import
itertools
,
os
,
random
,
re
,
sys
import
os
import
random
import
re
import
sys
from
urllib.parse
import
urlparse
from
.
import
filters
...
...
@@ -40,7 +43,7 @@ def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
os
.
remove
(
fname
)
def
download
(
url
,
dest
,
chunk_size
=
1024
**
2
,
verbose
=
False
):
def
download
(
url
,
dest
,
chunk_size
=
1024
**
2
,
verbose
=
False
):
"""Download a file from `url` to `dest`."""
temp
=
dest
+
f
".temp
{
os
.
getpid
()
}
"
with
gopen
.
gopen
(
url
)
as
stream
:
...
...
@@ -65,12 +68,11 @@ def pipe_cleaner(spec):
def
get_file_cached
(
spec
,
cache_size
=-
1
,
cache_dir
=
None
,
url_to_name
=
pipe_cleaner
,
verbose
=
False
,
):
spec
,
cache_size
=-
1
,
cache_dir
=
None
,
url_to_name
=
pipe_cleaner
,
verbose
=
False
,
):
if
cache_size
==
-
1
:
cache_size
=
default_cache_size
if
cache_dir
is
None
:
...
...
@@ -107,15 +109,14 @@ verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))
def
cached_url_opener
(
data
,
handler
=
reraise_exception
,
cache_size
=-
1
,
cache_dir
=
None
,
url_to_name
=
pipe_cleaner
,
validator
=
check_tar_format
,
verbose
=
False
,
always
=
False
,
):
data
,
handler
=
reraise_exception
,
cache_size
=-
1
,
cache_dir
=
None
,
url_to_name
=
pipe_cleaner
,
validator
=
check_tar_format
,
verbose
=
False
,
always
=
False
,
):
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
verbose
=
verbose
or
verbose_cache
for
sample
in
data
:
...
...
@@ -132,8 +133,7 @@ def cached_url_opener(
cache_size
=
cache_size
,
cache_dir
=
cache_dir
,
url_to_name
=
url_to_name
,
verbose
=
verbose
,
)
verbose
=
verbose
,
)
if
verbose
:
print
(
"# opening %s"
%
dest
,
file
=
sys
.
stderr
)
assert
os
.
path
.
exists
(
dest
)
...
...
@@ -143,9 +143,8 @@ def cached_url_opener(
data
=
f
.
read
(
200
)
os
.
remove
(
dest
)
raise
ValueError
(
"%s (%s) is not a tar archive, but a %s, contains %s"
%
(
dest
,
url
,
ftype
,
repr
(
data
))
)
"%s (%s) is not a tar archive, but a %s, contains %s"
%
(
dest
,
url
,
ftype
,
repr
(
data
)))
try
:
stream
=
open
(
dest
,
"rb"
)
sample
.
update
(
stream
=
stream
)
...
...
@@ -158,7 +157,7 @@ def cached_url_opener(
continue
raise
exn
except
Exception
as
exn
:
exn
.
args
=
exn
.
args
+
(
url
,)
exn
.
args
=
exn
.
args
+
(
url
,
)
if
handler
(
exn
):
continue
else
:
...
...
@@ -166,14 +165,13 @@ def cached_url_opener(
def
cached_tarfile_samples
(
src
,
handler
=
reraise_exception
,
cache_size
=-
1
,
cache_dir
=
None
,
verbose
=
False
,
url_to_name
=
pipe_cleaner
,
always
=
False
,
):
src
,
handler
=
reraise_exception
,
cache_size
=-
1
,
cache_dir
=
None
,
verbose
=
False
,
url_to_name
=
pipe_cleaner
,
always
=
False
,
):
streams
=
cached_url_opener
(
src
,
handler
=
handler
,
...
...
@@ -181,8 +179,7 @@ def cached_tarfile_samples(
cache_dir
=
cache_dir
,
verbose
=
verbose
,
url_to_name
=
url_to_name
,
always
=
always
,
)
always
=
always
,
)
samples
=
tar_file_and_group_expander
(
streams
,
handler
=
handler
)
return
samples
...
...
paddlespeech/audio/streamdata/compat.py
浏览文件 @
795eb7bd
...
...
@@ -2,17 +2,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
from
dataclasses
import
dataclass
from
itertools
import
islice
from
typing
import
List
import
braceexpand
,
yaml
import
yaml
from
.
import
autodecode
from
.
import
cache
,
filters
,
shardlists
,
tariterators
from
.
import
cache
from
.
import
filters
from
.
import
shardlists
from
.
import
tariterators
from
.filters
import
reraise_exception
from
.paddle_utils
import
DataLoader
from
.paddle_utils
import
IterableDataset
from
.pipeline
import
DataPipeline
from
.paddle_utils
import
DataLoader
,
IterableDataset
class
FluidInterface
:
...
...
@@ -26,7 +26,8 @@ class FluidInterface:
return
self
.
compose
(
filters
.
unbatched
())
def
listed
(
self
,
batchsize
,
partial
=
True
):
return
self
.
compose
(
filters
.
batched
(),
batchsize
=
batchsize
,
collation_fn
=
None
)
return
self
.
compose
(
filters
.
batched
(),
batchsize
=
batchsize
,
collation_fn
=
None
)
def
unlisted
(
self
):
return
self
.
compose
(
filters
.
unlisted
())
...
...
@@ -43,9 +44,19 @@ class FluidInterface:
def
map
(
self
,
f
,
handler
=
reraise_exception
):
return
self
.
compose
(
filters
.
map
(
f
,
handler
=
handler
))
def
decode
(
self
,
*
args
,
pre
=
None
,
post
=
None
,
only
=
None
,
partial
=
False
,
handler
=
reraise_exception
):
handlers
=
[
autodecode
.
ImageHandler
(
x
)
if
isinstance
(
x
,
str
)
else
x
for
x
in
args
]
decoder
=
autodecode
.
Decoder
(
handlers
,
pre
=
pre
,
post
=
post
,
only
=
only
,
partial
=
partial
)
def
decode
(
self
,
*
args
,
pre
=
None
,
post
=
None
,
only
=
None
,
partial
=
False
,
handler
=
reraise_exception
):
handlers
=
[
autodecode
.
ImageHandler
(
x
)
if
isinstance
(
x
,
str
)
else
x
for
x
in
args
]
decoder
=
autodecode
.
Decoder
(
handlers
,
pre
=
pre
,
post
=
post
,
only
=
only
,
partial
=
partial
)
return
self
.
map
(
decoder
,
handler
=
handler
)
def
map_dict
(
self
,
handler
=
reraise_exception
,
**
kw
):
...
...
@@ -80,12 +91,12 @@ class FluidInterface:
def
audio_data_filter
(
self
,
*
args
,
**
kw
):
return
self
.
compose
(
filters
.
audio_data_filter
(
*
args
,
**
kw
))
def
audio_tokenize
(
self
,
*
args
,
**
kw
):
return
self
.
compose
(
filters
.
audio_tokenize
(
*
args
,
**
kw
))
def
resample
(
self
,
*
args
,
**
kw
):
return
self
.
compose
(
filters
.
resample
(
*
args
,
**
kw
))
return
self
.
compose
(
filters
.
resample
(
*
args
,
**
kw
))
def
audio_compute_fbank
(
self
,
*
args
,
**
kw
):
return
self
.
compose
(
filters
.
audio_compute_fbank
(
*
args
,
**
kw
))
...
...
@@ -102,27 +113,28 @@ class FluidInterface:
def
audio_cmvn
(
self
,
cmvn_file
):
return
self
.
compose
(
filters
.
audio_cmvn
(
cmvn_file
))
class
WebDataset
(
DataPipeline
,
FluidInterface
):
"""Small fluid-interface wrapper for DataPipeline."""
def
__init__
(
self
,
urls
,
handler
=
reraise_exception
,
resampled
=
False
,
repeat
=
False
,
shardshuffle
=
None
,
cache_size
=
0
,
cache_dir
=
None
,
detshuffle
=
False
,
nodesplitter
=
shardlists
.
single_node_only
,
verbose
=
False
,
):
self
,
urls
,
handler
=
reraise_exception
,
resampled
=
False
,
repeat
=
False
,
shardshuffle
=
None
,
cache_size
=
0
,
cache_dir
=
None
,
detshuffle
=
False
,
nodesplitter
=
shardlists
.
single_node_only
,
verbose
=
False
,
):
super
().
__init__
()
if
isinstance
(
urls
,
IterableDataset
):
assert
not
resampled
self
.
append
(
urls
)
elif
isinstance
(
urls
,
str
)
and
(
urls
.
endswith
(
".yaml"
)
or
urls
.
endswith
(
".yml"
)):
elif
isinstance
(
urls
,
str
)
and
(
urls
.
endswith
(
".yaml"
)
or
urls
.
endswith
(
".yml"
)):
with
(
open
(
urls
))
as
stream
:
spec
=
yaml
.
safe_load
(
stream
)
assert
"datasets"
in
spec
...
...
@@ -152,9 +164,7 @@ class WebDataset(DataPipeline, FluidInterface):
handler
=
handler
,
verbose
=
verbose
,
cache_size
=
cache_size
,
cache_dir
=
cache_dir
,
)
)
cache_dir
=
cache_dir
,
))
class
FluidWrapper
(
DataPipeline
,
FluidInterface
):
...
...
paddlespeech/audio/streamdata/extradatasets.py
浏览文件 @
795eb7bd
...
...
@@ -5,20 +5,10 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
import
itertools
as
itt
import
os
import
random
import
sys
import
braceexpand
from
.
import
utils
from
.paddle_utils
import
IterableDataset
from
.utils
import
PipelineStage
...
...
@@ -63,8 +53,7 @@ class repeatedly(IterableDataset, PipelineStage):
return
utils
.
repeatedly
(
source
,
nepochs
=
self
.
nepochs
,
nbatches
=
self
.
nbatches
,
)
nbatches
=
self
.
nbatches
,
)
class
with_epoch
(
IterableDataset
):
...
...
paddlespeech/audio/streamdata/filters.py
浏览文件 @
795eb7bd
...
...
@@ -3,7 +3,6 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""A collection of iterators for data transformations.
...
...
@@ -12,28 +11,29 @@ These functions are plain iterator functions. You can find curried versions
in webdataset.filters, and you can find IterableDataset wrappers in
webdataset.processing.
"""
import
io
from
fnmatch
import
fnmatch
import
itertools
import
os
import
random
import
re
import
itertools
,
os
,
random
,
sys
,
time
from
functools
import
reduce
,
wraps
import
sys
import
time
from
fnmatch
import
fnmatch
from
functools
import
reduce
import
numpy
as
np
import
paddle
from
.
import
autodecode
from
.
import
utils
from
.paddle_utils
import
PaddleTensor
from
.utils
import
PipelineStage
from
.
import
utils
from
..
import
backends
from
..compliance
import
kaldi
import
paddle
from
..transform.cmvn
import
GlobalCMVN
from
..utils.tensor_utils
import
pad_sequence
from
..transform.spec_augment
import
time_warp
from
..transform.spec_augment
import
time_mask
from
..transform.spec_augment
import
freq_mask
from
..transform.spec_augment
import
time_mask
from
..transform.spec_augment
import
time_warp
from
..utils.tensor_utils
import
pad_sequence
from
.utils
import
PipelineStage
class
FilterFunction
(
object
):
"""Helper class for currying pipeline stages.
...
...
@@ -159,10 +159,12 @@ def transform_with(sample, transformers):
result
[
i
]
=
f
(
sample
[
i
])
return
result
###
# Iterators
###
def
_info
(
data
,
fmt
=
None
,
n
=
3
,
every
=-
1
,
width
=
50
,
stream
=
sys
.
stderr
,
name
=
""
):
"""Print information about the samples that are passing through.
...
...
@@ -278,10 +280,16 @@ def _log_keys(data, logfile=None):
log_keys
=
pipelinefilter
(
_log_keys
)
def
_minedecode
(
x
):
if
isinstance
(
x
,
str
):
return
autodecode
.
imagehandler
(
x
)
else
:
return
x
def
_decode
(
data
,
*
args
,
handler
=
reraise_exception
,
**
kw
):
"""Decode data based on the decoding functions given as arguments."""
decoder
=
lambda
x
:
autodecode
.
imagehandler
(
x
)
if
isinstance
(
x
,
str
)
else
x
decoder
=
_minedecode
handlers
=
[
decoder
(
x
)
for
x
in
args
]
f
=
autodecode
.
Decoder
(
handlers
,
**
kw
)
...
...
@@ -325,15 +333,24 @@ def _rename(data, handler=reraise_exception, keep=True, **kw):
for
sample
in
data
:
try
:
if
not
keep
:
yield
{
k
:
getfirst
(
sample
,
v
,
missing_is_error
=
True
)
for
k
,
v
in
kw
.
items
()}
yield
{
k
:
getfirst
(
sample
,
v
,
missing_is_error
=
True
)
for
k
,
v
in
kw
.
items
()
}
else
:
def
listify
(
v
):
return
v
.
split
(
";"
)
if
isinstance
(
v
,
str
)
else
v
to_be_replaced
=
{
x
for
v
in
kw
.
values
()
for
x
in
listify
(
v
)}
result
=
{
k
:
v
for
k
,
v
in
sample
.
items
()
if
k
not
in
to_be_replaced
}
result
.
update
({
k
:
getfirst
(
sample
,
v
,
missing_is_error
=
True
)
for
k
,
v
in
kw
.
items
()})
result
=
{
k
:
v
for
k
,
v
in
sample
.
items
()
if
k
not
in
to_be_replaced
}
result
.
update
({
k
:
getfirst
(
sample
,
v
,
missing_is_error
=
True
)
for
k
,
v
in
kw
.
items
()
})
yield
result
except
Exception
as
exn
:
if
handler
(
exn
):
...
...
@@ -381,7 +398,11 @@ def _map_dict(data, handler=reraise_exception, **kw):
map_dict
=
pipelinefilter
(
_map_dict
)
def
_to_tuple
(
data
,
*
args
,
handler
=
reraise_exception
,
missing_is_error
=
True
,
none_is_error
=
None
):
def
_to_tuple
(
data
,
*
args
,
handler
=
reraise_exception
,
missing_is_error
=
True
,
none_is_error
=
None
):
"""Convert dict samples to tuples."""
if
none_is_error
is
None
:
none_is_error
=
missing_is_error
...
...
@@ -390,7 +411,10 @@ def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, non
for
sample
in
data
:
try
:
result
=
tuple
([
getfirst
(
sample
,
f
,
missing_is_error
=
missing_is_error
)
for
f
in
args
])
result
=
tuple
([
getfirst
(
sample
,
f
,
missing_is_error
=
missing_is_error
)
for
f
in
args
])
if
none_is_error
and
any
(
x
is
None
for
x
in
result
):
raise
ValueError
(
f
"to_tuple
{
args
}
got
{
sample
.
keys
()
}
"
)
yield
result
...
...
@@ -463,19 +487,28 @@ rsample = pipelinefilter(_rsample)
slice
=
pipelinefilter
(
itertools
.
islice
)
def
_extract_keys
(
source
,
*
patterns
,
duplicate_is_error
=
True
,
ignore_missing
=
False
):
def
_extract_keys
(
source
,
*
patterns
,
duplicate_is_error
=
True
,
ignore_missing
=
False
):
for
sample
in
source
:
result
=
[]
for
pattern
in
patterns
:
pattern
=
pattern
.
split
(
";"
)
if
isinstance
(
pattern
,
str
)
else
pattern
matches
=
[
x
for
x
in
sample
.
keys
()
if
any
(
fnmatch
(
"."
+
x
,
p
)
for
p
in
pattern
)]
pattern
=
pattern
.
split
(
";"
)
if
isinstance
(
pattern
,
str
)
else
pattern
matches
=
[
x
for
x
in
sample
.
keys
()
if
any
(
fnmatch
(
"."
+
x
,
p
)
for
p
in
pattern
)
]
if
len
(
matches
)
==
0
:
if
ignore_missing
:
continue
else
:
raise
ValueError
(
f
"Cannot find
{
pattern
}
in sample keys
{
sample
.
keys
()
}
."
)
raise
ValueError
(
f
"Cannot find
{
pattern
}
in sample keys
{
sample
.
keys
()
}
."
)
if
len
(
matches
)
>
1
and
duplicate_is_error
:
raise
ValueError
(
f
"Multiple sample keys
{
sample
.
keys
()
}
match
{
pattern
}
."
)
raise
ValueError
(
f
"Multiple sample keys
{
sample
.
keys
()
}
match
{
pattern
}
."
)
value
=
sample
[
matches
[
0
]]
result
.
append
(
value
)
yield
tuple
(
result
)
...
...
@@ -484,7 +517,12 @@ def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=Fal
extract_keys
=
pipelinefilter
(
_extract_keys
)
def
_rename_keys
(
source
,
*
args
,
keep_unselected
=
False
,
must_match
=
True
,
duplicate_is_error
=
True
,
**
kw
):
def
_rename_keys
(
source
,
*
args
,
keep_unselected
=
False
,
must_match
=
True
,
duplicate_is_error
=
True
,
**
kw
):
renamings
=
[(
pattern
,
output
)
for
output
,
pattern
in
args
]
renamings
+=
[(
pattern
,
output
)
for
output
,
pattern
in
kw
.
items
()]
for
sample
in
source
:
...
...
@@ -504,11 +542,15 @@ def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicat
continue
if
new_name
in
new_sample
:
if
duplicate_is_error
:
raise
ValueError
(
f
"Duplicate value in sample
{
sample
.
keys
()
}
after rename."
)
raise
ValueError
(
f
"Duplicate value in sample
{
sample
.
keys
()
}
after rename."
)
continue
new_sample
[
new_name
]
=
value
if
must_match
and
not
all
(
matched
.
values
()):
raise
ValueError
(
f
"Not all patterns (
{
matched
}
) matched sample keys (
{
sample
.
keys
()
}
)."
)
raise
ValueError
(
f
"Not all patterns (
{
matched
}
) matched sample keys (
{
sample
.
keys
()
}
)."
)
yield
new_sample
...
...
@@ -541,18 +583,18 @@ def find_decoder(decoders, path):
if
fname
.
startswith
(
"__"
):
return
lambda
x
:
x
for
pattern
,
fun
in
decoders
[::
-
1
]:
if
fnmatch
(
fname
.
lower
(),
pattern
)
or
fnmatch
(
"."
+
fname
.
lower
(),
pattern
):
if
fnmatch
(
fname
.
lower
(),
pattern
)
or
fnmatch
(
"."
+
fname
.
lower
(),
pattern
):
return
fun
return
None
def
_xdecode
(
source
,
*
args
,
must_decode
=
True
,
defaults
=
default_decoders
,
**
kw
,
):
source
,
*
args
,
must_decode
=
True
,
defaults
=
default_decoders
,
**
kw
,
):
decoders
=
list
(
defaults
)
+
list
(
args
)
decoders
+=
[(
"*."
+
k
,
v
)
for
k
,
v
in
kw
.
items
()]
for
sample
in
source
:
...
...
@@ -575,18 +617,18 @@ def _xdecode(
new_sample
[
path
]
=
value
yield
new_sample
xdecode
=
pipelinefilter
(
_xdecode
)
xdecode
=
pipelinefilter
(
_xdecode
)
def
_audio_data_filter
(
source
,
frame_shift
=
10
,
max_length
=
10240
,
min_length
=
10
,
token_max_length
=
200
,
token_min_length
=
1
,
min_output_input_ratio
=
0.0005
,
max_output_input_ratio
=
1
):
frame_shift
=
10
,
max_length
=
10240
,
min_length
=
10
,
token_max_length
=
200
,
token_min_length
=
1
,
min_output_input_ratio
=
0.0005
,
max_output_input_ratio
=
1
):
""" Filter sample according to feature and label length
Inplace operation.
...
...
@@ -613,7 +655,8 @@ def _audio_data_filter(source,
assert
'wav'
in
sample
assert
'label'
in
sample
# sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
num_frames
=
sample
[
'wav'
].
shape
[
1
]
/
sample
[
'sample_rate'
]
*
(
1000
/
frame_shift
)
num_frames
=
sample
[
'wav'
].
shape
[
1
]
/
sample
[
'sample_rate'
]
*
(
1000
/
frame_shift
)
if
num_frames
<
min_length
:
continue
if
num_frames
>
max_length
:
...
...
@@ -629,13 +672,15 @@ def _audio_data_filter(source,
continue
yield
sample
audio_data_filter
=
pipelinefilter
(
_audio_data_filter
)
def
_audio_tokenize
(
source
,
symbol_table
,
bpe_model
=
None
,
non_lang_syms
=
None
,
split_with_space
=
False
):
symbol_table
,
bpe_model
=
None
,
non_lang_syms
=
None
,
split_with_space
=
False
):
""" Decode text to chars or BPE
Inplace operation
...
...
@@ -693,8 +738,10 @@ def _audio_tokenize(source,
sample
[
'label'
]
=
label
yield
sample
audio_tokenize
=
pipelinefilter
(
_audio_tokenize
)
def
_audio_resample
(
source
,
resample_rate
=
16000
):
""" Resample data.
Inplace operation.
...
...
@@ -713,18 +760,22 @@ def _audio_resample(source, resample_rate=16000):
waveform
=
sample
[
'wav'
]
if
sample_rate
!=
resample_rate
:
sample
[
'sample_rate'
]
=
resample_rate
sample
[
'wav'
]
=
paddle
.
to_tensor
(
backends
.
soundfile_backend
.
resample
(
waveform
.
numpy
(),
src_sr
=
sample_rate
,
target_sr
=
resample_rate
))
sample
[
'wav'
]
=
paddle
.
to_tensor
(
backends
.
soundfile_backend
.
resample
(
waveform
.
numpy
(),
src_sr
=
sample_rate
,
target_sr
=
resample_rate
))
yield
sample
audio_resample
=
pipelinefilter
(
_audio_resample
)
def
_audio_compute_fbank
(
source
,
num_mel_bins
=
80
,
frame_length
=
25
,
frame_shift
=
10
,
dither
=
0.0
):
num_mel_bins
=
80
,
frame_length
=
25
,
frame_shift
=
10
,
dither
=
0.0
):
""" Extract fbank
Args:
...
...
@@ -746,30 +797,33 @@ def _audio_compute_fbank(source,
waveform
=
sample
[
'wav'
]
waveform
=
waveform
*
(
1
<<
15
)
# Only keep fname, feat, label
mat
=
kaldi
.
fbank
(
waveform
,
n_mels
=
num_mel_bins
,
frame_length
=
frame_length
,
frame_shift
=
frame_shift
,
dither
=
dither
,
energy_floor
=
0.0
,
sr
=
sample_rate
)
mat
=
kaldi
.
fbank
(
waveform
,
n_mels
=
num_mel_bins
,
frame_length
=
frame_length
,
frame_shift
=
frame_shift
,
dither
=
dither
,
energy_floor
=
0.0
,
sr
=
sample_rate
)
yield
dict
(
fname
=
sample
[
'fname'
],
label
=
sample
[
'label'
],
feat
=
mat
)
audio_compute_fbank
=
pipelinefilter
(
_audio_compute_fbank
)
def
_audio_spec_aug
(
source
,
max_w
=
5
,
w_inplace
=
True
,
w_mode
=
"PIL"
,
max_f
=
30
,
num_f_mask
=
2
,
f_inplace
=
True
,
f_replace_with_zero
=
False
,
max_t
=
40
,
num_t_mask
=
2
,
t_inplace
=
True
,
t_replace_with_zero
=
False
,):
def
_audio_spec_aug
(
source
,
max_w
=
5
,
w_inplace
=
True
,
w_mode
=
"PIL"
,
max_f
=
30
,
num_f_mask
=
2
,
f_inplace
=
True
,
f_replace_with_zero
=
False
,
max_t
=
40
,
num_t_mask
=
2
,
t_inplace
=
True
,
t_replace_with_zero
=
False
,
):
""" Do spec augmentation
Inplace operation
...
...
@@ -793,12 +847,23 @@ def _audio_spec_aug(source,
for
sample
in
source
:
x
=
sample
[
'feat'
]
x
=
x
.
numpy
()
x
=
time_warp
(
x
,
max_time_warp
=
max_w
,
inplace
=
w_inplace
,
mode
=
w_mode
)
x
=
freq_mask
(
x
,
F
=
max_f
,
n_mask
=
num_f_mask
,
inplace
=
f_inplace
,
replace_with_zero
=
f_replace_with_zero
)
x
=
time_mask
(
x
,
T
=
max_t
,
n_mask
=
num_t_mask
,
inplace
=
t_inplace
,
replace_with_zero
=
t_replace_with_zero
)
x
=
time_warp
(
x
,
max_time_warp
=
max_w
,
inplace
=
w_inplace
,
mode
=
w_mode
)
x
=
freq_mask
(
x
,
F
=
max_f
,
n_mask
=
num_f_mask
,
inplace
=
f_inplace
,
replace_with_zero
=
f_replace_with_zero
)
x
=
time_mask
(
x
,
T
=
max_t
,
n_mask
=
num_t_mask
,
inplace
=
t_inplace
,
replace_with_zero
=
t_replace_with_zero
)
sample
[
'feat'
]
=
paddle
.
to_tensor
(
x
,
dtype
=
paddle
.
float32
)
yield
sample
audio_spec_aug
=
pipelinefilter
(
_audio_spec_aug
)
...
...
@@ -829,8 +894,10 @@ def _sort(source, sort_size=500):
for
x
in
buf
:
yield
x
sort
=
pipelinefilter
(
_sort
)
def
_batched
(
source
,
batch_size
=
16
):
""" Static batch the data by `batch_size`
...
...
@@ -850,8 +917,10 @@ def _batched(source, batch_size=16):
if
len
(
buf
)
>
0
:
yield
buf
batched
=
pipelinefilter
(
_batched
)
def
dynamic_batched
(
source
,
max_frames_in_batch
=
12000
):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
...
...
@@ -892,8 +961,8 @@ def _audio_padding(source):
"""
for
sample
in
source
:
assert
isinstance
(
sample
,
list
)
feats_length
=
paddle
.
to_tensor
(
[
x
[
'feat'
].
shape
[
0
]
for
x
in
sample
],
dtype
=
"int64"
)
feats_length
=
paddle
.
to_tensor
(
[
x
[
'feat'
].
shape
[
0
]
for
x
in
sample
],
dtype
=
"int64"
)
order
=
paddle
.
argsort
(
feats_length
,
descending
=
True
)
feats_lengths
=
paddle
.
to_tensor
(
[
sample
[
i
][
'feat'
].
shape
[
0
]
for
i
in
order
],
dtype
=
"int64"
)
...
...
@@ -902,20 +971,20 @@ def _audio_padding(source):
sorted_labels
=
[
paddle
.
to_tensor
(
sample
[
i
][
'label'
],
dtype
=
"int32"
)
for
i
in
order
]
label_lengths
=
paddle
.
to_tensor
([
x
.
shape
[
0
]
for
x
in
sorted_labels
],
dtype
=
"int64"
)
padded_feats
=
pad_sequence
(
sorted_feats
,
batch_first
=
True
,
padding_value
=
0
)
padding_labels
=
pad_sequence
(
sorted_labels
,
batch_first
=
True
,
padding_value
=-
1
)
yield
(
sorted_keys
,
padded_feats
,
feats_lengths
,
padding_labels
,
label_lengths
=
paddle
.
to_tensor
(
[
x
.
shape
[
0
]
for
x
in
sorted_labels
],
dtype
=
"int64"
)
padded_feats
=
pad_sequence
(
sorted_feats
,
batch_first
=
True
,
padding_value
=
0
)
padding_labels
=
pad_sequence
(
sorted_labels
,
batch_first
=
True
,
padding_value
=-
1
)
yield
(
sorted_keys
,
padded_feats
,
feats_lengths
,
padding_labels
,
label_lengths
)
audio_padding
=
pipelinefilter
(
_audio_padding
)
def
_audio_cmvn
(
source
,
cmvn_file
):
global_cmvn
=
GlobalCMVN
(
cmvn_file
)
for
batch
in
source
:
...
...
@@ -923,13 +992,16 @@ def _audio_cmvn(source, cmvn_file):
padded_feats
=
padded_feats
.
numpy
()
padded_feats
=
global_cmvn
(
padded_feats
)
padded_feats
=
paddle
.
to_tensor
(
padded_feats
,
dtype
=
paddle
.
float32
)
yield
(
sorted_keys
,
padded_feats
,
feats_lengths
,
padding_labels
,
label_lengths
)
yield
(
sorted_keys
,
padded_feats
,
feats_lengths
,
padding_labels
,
label_lengths
)
audio_cmvn
=
pipelinefilter
(
_audio_cmvn
)
def
_placeholder
(
source
):
for
data
in
source
:
yield
data
placeholder
=
pipelinefilter
(
_placeholder
)
paddlespeech/audio/streamdata/gopen.py
浏览文件 @
795eb7bd
...
...
@@ -3,12 +3,12 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Open URLs by calling subcommands."""
import
os
,
sys
,
re
from
subprocess
import
PIPE
,
Popen
import
os
import
re
import
sys
from
subprocess
import
PIPE
from
subprocess
import
Popen
from
urllib.parse
import
urlparse
# global used for printing additional node information during verbose output
...
...
@@ -31,14 +31,13 @@ class Pipe:
"""
def
__init__
(
self
,
*
args
,
mode
=
None
,
timeout
=
7200.0
,
ignore_errors
=
False
,
ignore_status
=
[],
**
kw
,
):
self
,
*
args
,
mode
=
None
,
timeout
=
7200.0
,
ignore_errors
=
False
,
ignore_status
=
[],
**
kw
,
):
"""Create an IO Pipe."""
self
.
ignore_errors
=
ignore_errors
self
.
ignore_status
=
[
0
]
+
ignore_status
...
...
@@ -75,8 +74,7 @@ class Pipe:
if
verbose
:
print
(
f
"pipe exit [
{
self
.
status
}
{
os
.
getpid
()
}
:
{
self
.
proc
.
pid
}
]
{
self
.
args
}
{
info
}
"
,
file
=
sys
.
stderr
,
)
file
=
sys
.
stderr
,
)
if
self
.
status
not
in
self
.
ignore_status
and
not
self
.
ignore_errors
:
raise
Exception
(
f
"
{
self
.
args
}
: exit
{
self
.
status
}
(read)
{
info
}
"
)
...
...
@@ -114,9 +112,11 @@ class Pipe:
self
.
close
()
def
set_options
(
obj
,
timeout
=
None
,
ignore_errors
=
None
,
ignore_status
=
None
,
handler
=
None
):
def
set_options
(
obj
,
timeout
=
None
,
ignore_errors
=
None
,
ignore_status
=
None
,
handler
=
None
):
"""Set options for Pipes.
This function can be called on any stream. It will set pipe options only
...
...
@@ -168,16 +168,14 @@ def gopen_pipe(url, mode="rb", bufsize=8192):
mode
=
mode
,
shell
=
True
,
bufsize
=
bufsize
,
ignore_status
=
[
141
],
)
# skipcq: BAN-B604
ignore_status
=
[
141
],
)
# skipcq: BAN-B604
elif
mode
[
0
]
==
"w"
:
return
Pipe
(
cmd
,
mode
=
mode
,
shell
=
True
,
bufsize
=
bufsize
,
ignore_status
=
[
141
],
)
# skipcq: BAN-B604
ignore_status
=
[
141
],
)
# skipcq: BAN-B604
else
:
raise
ValueError
(
f
"
{
mode
}
: unknown mode"
)
...
...
@@ -196,8 +194,7 @@ def gopen_curl(url, mode="rb", bufsize=8192):
mode
=
mode
,
shell
=
True
,
bufsize
=
bufsize
,
ignore_status
=
[
141
,
23
],
)
# skipcq: BAN-B604
ignore_status
=
[
141
,
23
],
)
# skipcq: BAN-B604
elif
mode
[
0
]
==
"w"
:
cmd
=
f
"curl -s -L -T - '
{
url
}
'"
return
Pipe
(
...
...
@@ -205,8 +202,7 @@ def gopen_curl(url, mode="rb", bufsize=8192):
mode
=
mode
,
shell
=
True
,
bufsize
=
bufsize
,
ignore_status
=
[
141
,
26
],
)
# skipcq: BAN-B604
ignore_status
=
[
141
,
26
],
)
# skipcq: BAN-B604
else
:
raise
ValueError
(
f
"
{
mode
}
: unknown mode"
)
...
...
@@ -226,15 +222,13 @@ def gopen_htgs(url, mode="rb", bufsize=8192):
mode
=
mode
,
shell
=
True
,
bufsize
=
bufsize
,
ignore_status
=
[
141
,
23
],
)
# skipcq: BAN-B604
ignore_status
=
[
141
,
23
],
)
# skipcq: BAN-B604
elif
mode
[
0
]
==
"w"
:
raise
ValueError
(
f
"
{
mode
}
: cannot write"
)
else
:
raise
ValueError
(
f
"
{
mode
}
: unknown mode"
)
def
gopen_gsutil
(
url
,
mode
=
"rb"
,
bufsize
=
8192
):
"""Open a URL with `curl`.
...
...
@@ -249,8 +243,7 @@ def gopen_gsutil(url, mode="rb", bufsize=8192):
mode
=
mode
,
shell
=
True
,
bufsize
=
bufsize
,
ignore_status
=
[
141
,
23
],
)
# skipcq: BAN-B604
ignore_status
=
[
141
,
23
],
)
# skipcq: BAN-B604
elif
mode
[
0
]
==
"w"
:
cmd
=
f
"gsutil cp - '
{
url
}
'"
return
Pipe
(
...
...
@@ -258,13 +251,11 @@ def gopen_gsutil(url, mode="rb", bufsize=8192):
mode
=
mode
,
shell
=
True
,
bufsize
=
bufsize
,
ignore_status
=
[
141
,
26
],
)
# skipcq: BAN-B604
ignore_status
=
[
141
,
26
],
)
# skipcq: BAN-B604
else
:
raise
ValueError
(
f
"
{
mode
}
: unknown mode"
)
def
gopen_error
(
url
,
*
args
,
**
kw
):
"""Raise a value error.
...
...
@@ -285,8 +276,7 @@ gopen_schemes = dict(
ftps
=
gopen_curl
,
scp
=
gopen_curl
,
gs
=
gopen_gsutil
,
htgs
=
gopen_htgs
,
)
htgs
=
gopen_htgs
,
)
def
gopen
(
url
,
mode
=
"rb"
,
bufsize
=
8192
,
**
kw
):
...
...
paddlespeech/audio/streamdata/handlers.py
浏览文件 @
795eb7bd
...
...
@@ -3,7 +3,6 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Pluggable exception handlers.
These are functions that take an exception as an argument and then return...
...
...
@@ -14,8 +13,8 @@ These are functions that take an exception as an argument and then return...
They are used as handler= arguments in much of the library.
"""
import
time
,
warnings
import
time
import
warnings
def
reraise_exception
(
exn
):
...
...
paddlespeech/audio/streamdata/mix.py
浏览文件 @
795eb7bd
...
...
@@ -5,17 +5,12 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes for mixing samples from multiple sources."""
import
itertools
,
os
,
random
,
time
,
sys
from
functools
import
reduce
,
wraps
import
random
import
numpy
as
np
from
.
import
autodecode
,
utils
from
.paddle_utils
import
PaddleTensor
,
IterableDataset
from
.utils
import
PipelineStage
from
.paddle_utils
import
IterableDataset
def
round_robin_shortest
(
*
sources
):
...
...
paddlespeech/audio/streamdata/paddle_utils.py
浏览文件 @
795eb7bd
...
...
@@ -5,12 +5,11 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Mock implementations of paddle interfaces when paddle is not available."""
try
:
from
paddle.io
import
DataLoader
,
IterableDataset
from
paddle.io
import
DataLoader
from
paddle.io
import
IterableDataset
except
ModuleNotFoundError
:
class
IterableDataset
:
...
...
@@ -22,12 +21,3 @@ except ModuleNotFoundError:
"""Empty implementation of DataLoader when paddle is not available."""
pass
try
:
from
paddle
import
Tensor
as
PaddleTensor
except
ModuleNotFoundError
:
class
TorchTensor
:
"""Empty implementation of PaddleTensor when paddle is not available."""
pass
paddlespeech/audio/streamdata/pipeline.py
浏览文件 @
795eb7bd
...
...
@@ -3,15 +3,12 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
import
copy
,
os
,
random
,
sys
,
time
from
dataclasses
import
dataclas
s
import
copy
import
sy
s
from
itertools
import
islice
from
typing
import
List
import
braceexpand
,
yaml
from
.handlers
import
reraise_exception
from
.paddle_utils
import
DataLoader
,
IterableDataset
from
.paddle_utils
import
DataLoader
from
.paddle_utils
import
IterableDataset
from
.utils
import
PipelineStage
...
...
@@ -22,8 +19,7 @@ def add_length_method(obj):
Combined
=
type
(
obj
.
__class__
.
__name__
+
"_Length"
,
(
obj
.
__class__
,
IterableDataset
),
{
"__len__"
:
length
},
)
{
"__len__"
:
length
},
)
obj
.
__class__
=
Combined
return
obj
...
...
paddlespeech/audio/streamdata/shardlists.py
浏览文件 @
795eb7bd
...
...
@@ -4,28 +4,30 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
import
os
,
random
,
sys
,
time
from
dataclasses
import
dataclass
,
field
import
os
import
random
import
sys
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
field
from
itertools
import
islice
from
typing
import
List
import
braceexpand
,
yaml
import
braceexpand
import
yaml
from
.
import
utils
from
..utils.log
import
Logger
from
.filters
import
pipelinefilter
from
.paddle_utils
import
IterableDataset
logger
=
Logger
(
__name__
)
from
..utils.log
import
Logger
logger
=
Logger
(
__name__
)
def
expand_urls
(
urls
):
if
isinstance
(
urls
,
str
):
urllist
=
urls
.
split
(
"::"
)
...
...
@@ -64,7 +66,8 @@ class SimpleShardList(IterableDataset):
def
split_by_node
(
src
,
group
=
None
):
rank
,
world_size
,
worker
,
num_workers
=
utils
.
paddle_worker_info
(
group
=
group
)
rank
,
world_size
,
worker
,
num_workers
=
utils
.
paddle_worker_info
(
group
=
group
)
logger
.
info
(
f
"world_size:
{
world_size
}
, rank:
{
rank
}
"
)
if
world_size
>
1
:
for
s
in
islice
(
src
,
rank
,
None
,
world_size
):
...
...
@@ -75,9 +78,11 @@ def split_by_node(src, group=None):
def
single_node_only
(
src
,
group
=
None
):
rank
,
world_size
,
worker
,
num_workers
=
utils
.
paddle_worker_info
(
group
=
group
)
rank
,
world_size
,
worker
,
num_workers
=
utils
.
paddle_worker_info
(
group
=
group
)
if
world_size
>
1
:
raise
ValueError
(
"input pipeline needs to be reconfigured for multinode training"
)
raise
ValueError
(
"input pipeline needs to be reconfigured for multinode training"
)
for
s
in
src
:
yield
s
...
...
@@ -104,7 +109,8 @@ def resampled_(src, n=sys.maxsize):
rng
=
random
.
Random
(
seed
)
print
(
"# resampled loading"
,
file
=
sys
.
stderr
)
items
=
list
(
src
)
print
(
f
"# resampled got
{
len
(
items
)
}
samples, yielding
{
n
}
"
,
file
=
sys
.
stderr
)
print
(
f
"# resampled got
{
len
(
items
)
}
samples, yielding
{
n
}
"
,
file
=
sys
.
stderr
)
for
i
in
range
(
n
):
yield
rng
.
choice
(
items
)
...
...
@@ -118,7 +124,9 @@ def non_empty(src):
yield
s
count
+=
1
if
count
==
0
:
raise
ValueError
(
"pipeline stage received no data at all and this was declared as an error"
)
raise
ValueError
(
"pipeline stage received no data at all and this was declared as an error"
)
@
dataclass
...
...
@@ -138,10 +146,6 @@ def expand(s):
return
os
.
path
.
expanduser
(
os
.
path
.
expandvars
(
s
))
class
MultiShardSample
(
IterableDataset
):
def
__init__
(
self
,
fname
):
"""Construct a shardlist from multiple sources using a YAML spec."""
self
.
epoch
=
-
1
class
MultiShardSample
(
IterableDataset
):
def
__init__
(
self
,
fname
):
"""Construct a shardlist from multiple sources using a YAML spec."""
...
...
@@ -156,20 +160,23 @@ class MultiShardSample(IterableDataset):
else
:
with
open
(
fname
)
as
stream
:
spec
=
yaml
.
safe_load
(
stream
)
assert
set
(
spec
.
keys
()).
issubset
(
set
(
"prefix datasets buckets"
.
split
())),
list
(
spec
.
keys
())
assert
set
(
spec
.
keys
()).
issubset
(
set
(
"prefix datasets buckets"
.
split
())),
list
(
spec
.
keys
())
prefix
=
expand
(
spec
.
get
(
"prefix"
,
""
))
self
.
sources
=
[]
for
ds
in
spec
[
"datasets"
]:
assert
set
(
ds
.
keys
()).
issubset
(
set
(
"buckets name shards resample choose"
.
split
())),
list
(
ds
.
keys
()
)
assert
set
(
ds
.
keys
()).
issubset
(
set
(
"buckets name shards resample choose"
.
split
())),
list
(
ds
.
keys
()
)
buckets
=
ds
.
get
(
"buckets"
,
spec
.
get
(
"buckets"
,
[]))
if
isinstance
(
buckets
,
str
):
buckets
=
[
buckets
]
buckets
=
[
expand
(
s
)
for
s
in
buckets
]
if
buckets
==
[]:
buckets
=
[
""
]
assert
len
(
buckets
)
==
1
,
f
"
{
buckets
}
: FIXME support for multiple buckets unimplemented"
assert
len
(
buckets
)
==
1
,
f
"
{
buckets
}
: FIXME support for multiple buckets unimplemented"
bucket
=
buckets
[
0
]
name
=
ds
.
get
(
"name"
,
"@"
+
bucket
)
urls
=
ds
[
"shards"
]
...
...
@@ -177,15 +184,19 @@ class MultiShardSample(IterableDataset):
urls
=
[
urls
]
# urls = [u for url in urls for u in braceexpand.braceexpand(url)]
urls
=
[
prefix
+
os
.
path
.
join
(
bucket
,
u
)
for
url
in
urls
for
u
in
braceexpand
.
braceexpand
(
expand
(
url
))
prefix
+
os
.
path
.
join
(
bucket
,
u
)
for
url
in
urls
for
u
in
braceexpand
.
braceexpand
(
expand
(
url
))
]
resample
=
ds
.
get
(
"resample"
,
-
1
)
nsample
=
ds
.
get
(
"choose"
,
-
1
)
if
nsample
>
len
(
urls
):
raise
ValueError
(
f
"perepoch
{
nsample
}
must be no greater than the number of shards"
)
raise
ValueError
(
f
"perepoch
{
nsample
}
must be no greater than the number of shards"
)
if
(
nsample
>
0
)
and
(
resample
>
0
):
raise
ValueError
(
"specify only one of perepoch or choose"
)
entry
=
MSSource
(
name
=
name
,
urls
=
urls
,
perepoch
=
nsample
,
resample
=
resample
)
entry
=
MSSource
(
name
=
name
,
urls
=
urls
,
perepoch
=
nsample
,
resample
=
resample
)
self
.
sources
.
append
(
entry
)
print
(
f
"#
{
name
}
{
len
(
urls
)
}
{
nsample
}
"
,
file
=
sys
.
stderr
)
...
...
@@ -203,7 +214,7 @@ class MultiShardSample(IterableDataset):
# sample without replacement
l
=
list
(
source
.
urls
)
self
.
rng
.
shuffle
(
l
)
l
=
l
[:
source
.
perepoch
]
l
=
l
[:
source
.
perepoch
]
else
:
l
=
list
(
source
.
urls
)
result
+=
l
...
...
@@ -227,12 +238,11 @@ class ResampledShards(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def
__init__
(
self
,
urls
,
nshards
=
sys
.
maxsize
,
worker_seed
=
None
,
deterministic
=
False
,
):
self
,
urls
,
nshards
=
sys
.
maxsize
,
worker_seed
=
None
,
deterministic
=
False
,
):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
...
...
@@ -252,7 +262,8 @@ class ResampledShards(IterableDataset):
if
self
.
deterministic
:
seed
=
utils
.
make_seed
(
self
.
worker_seed
(),
self
.
epoch
)
else
:
seed
=
utils
.
make_seed
(
self
.
worker_seed
(),
self
.
epoch
,
os
.
getpid
(),
time
.
time_ns
(),
os
.
urandom
(
4
))
seed
=
utils
.
make_seed
(
self
.
worker_seed
(),
self
.
epoch
,
os
.
getpid
(),
time
.
time_ns
(),
os
.
urandom
(
4
))
if
os
.
environ
.
get
(
"WDS_SHOW_SEED"
,
"0"
)
==
"1"
:
print
(
f
"# ResampledShards seed
{
seed
}
"
)
self
.
rng
=
random
.
Random
(
seed
)
...
...
paddlespeech/audio/streamdata/tariterators.py
浏览文件 @
795eb7bd
...
...
@@ -3,13 +3,12 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Low level iteration functions for tar archives."""
import
random
,
re
,
tarfile
import
random
import
re
import
tarfile
import
braceexpand
...
...
@@ -27,6 +26,7 @@ import numpy as np
AUDIO_FORMAT_SETS
=
set
([
'flac'
,
'mp3'
,
'm4a'
,
'ogg'
,
'opus'
,
'wav'
,
'wma'
])
def
base_plus_ext
(
path
):
"""Split off all file extensions.
...
...
@@ -47,12 +47,8 @@ def valid_sample(sample):
:param sample: sample to be checked
"""
return
(
sample
is
not
None
and
isinstance
(
sample
,
dict
)
and
len
(
list
(
sample
.
keys
()))
>
0
and
not
sample
.
get
(
"__bad__"
,
False
)
)
return
(
sample
is
not
None
and
isinstance
(
sample
,
dict
)
and
len
(
list
(
sample
.
keys
()))
>
0
and
not
sample
.
get
(
"__bad__"
,
False
))
# FIXME: UNUSED
...
...
@@ -79,16 +75,16 @@ def url_opener(data, handler=reraise_exception, **kw):
sample
.
update
(
stream
=
stream
)
yield
sample
except
Exception
as
exn
:
exn
.
args
=
exn
.
args
+
(
url
,)
exn
.
args
=
exn
.
args
+
(
url
,
)
if
handler
(
exn
):
continue
else
:
break
def
tar_file_iterator
(
fileobj
,
skip_meta
=
r
"__[^/]*__($|/)"
,
handler
=
reraise_exception
):
def
tar_file_iterator
(
fileobj
,
skip_meta
=
r
"__[^/]*__($|/)"
,
handler
=
reraise_exception
):
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
:param fileobj: byte stream suitable for tarfile
...
...
@@ -103,11 +99,8 @@ def tar_file_iterator(
continue
if
fname
is
None
:
continue
if
(
"/"
not
in
fname
and
fname
.
startswith
(
meta_prefix
)
and
fname
.
endswith
(
meta_suffix
)
):
if
(
"/"
not
in
fname
and
fname
.
startswith
(
meta_prefix
)
and
fname
.
endswith
(
meta_suffix
)):
# skipping metadata for now
continue
if
skip_meta
is
not
None
and
re
.
match
(
skip_meta
,
fname
):
...
...
@@ -118,8 +111,10 @@ def tar_file_iterator(
assert
pos
>
0
prefix
,
postfix
=
name
[:
pos
],
name
[
pos
+
1
:]
if
postfix
==
'wav'
:
waveform
,
sample_rate
=
paddlespeech
.
audio
.
load
(
stream
.
extractfile
(
tarinfo
),
normal
=
False
)
result
=
dict
(
fname
=
prefix
,
wav
=
waveform
,
sample_rate
=
sample_rate
)
waveform
,
sample_rate
=
paddlespeech
.
audio
.
load
(
stream
.
extractfile
(
tarinfo
),
normal
=
False
)
result
=
dict
(
fname
=
prefix
,
wav
=
waveform
,
sample_rate
=
sample_rate
)
else
:
txt
=
stream
.
extractfile
(
tarinfo
).
read
().
decode
(
'utf8'
).
strip
()
result
=
dict
(
fname
=
prefix
,
txt
=
txt
)
...
...
@@ -128,16 +123,17 @@ def tar_file_iterator(
stream
.
members
=
[]
except
Exception
as
exn
:
if
hasattr
(
exn
,
"args"
)
and
len
(
exn
.
args
)
>
0
:
exn
.
args
=
(
exn
.
args
[
0
]
+
" @ "
+
str
(
fileobj
),)
+
exn
.
args
[
1
:]
exn
.
args
=
(
exn
.
args
[
0
]
+
" @ "
+
str
(
fileobj
),
)
+
exn
.
args
[
1
:]
if
handler
(
exn
):
continue
else
:
break
del
stream
def
tar_file_and_group_iterator
(
fileobj
,
skip_meta
=
r
"__[^/]*__($|/)"
,
handler
=
reraise_exception
):
def
tar_file_and_group_iterator
(
fileobj
,
skip_meta
=
r
"__[^/]*__($|/)"
,
handler
=
reraise_exception
):
""" Expand a stream of open tar files into a stream of tar file contents.
And groups the file with same prefix
...
...
@@ -167,8 +163,11 @@ def tar_file_and_group_iterator(
if
postfix
==
'txt'
:
example
[
'txt'
]
=
file_obj
.
read
().
decode
(
'utf8'
).
strip
()
elif
postfix
in
AUDIO_FORMAT_SETS
:
waveform
,
sample_rate
=
paddlespeech
.
audio
.
load
(
file_obj
,
normal
=
False
)
waveform
=
paddle
.
to_tensor
(
np
.
expand_dims
(
np
.
array
(
waveform
),
0
),
dtype
=
paddle
.
float32
)
waveform
,
sample_rate
=
paddlespeech
.
audio
.
load
(
file_obj
,
normal
=
False
)
waveform
=
paddle
.
to_tensor
(
np
.
expand_dims
(
np
.
array
(
waveform
),
0
),
dtype
=
paddle
.
float32
)
example
[
'wav'
]
=
waveform
example
[
'sample_rate'
]
=
sample_rate
...
...
@@ -176,19 +175,21 @@ def tar_file_and_group_iterator(
example
[
postfix
]
=
file_obj
.
read
()
except
Exception
as
exn
:
if
hasattr
(
exn
,
"args"
)
and
len
(
exn
.
args
)
>
0
:
exn
.
args
=
(
exn
.
args
[
0
]
+
" @ "
+
str
(
fileobj
),)
+
exn
.
args
[
1
:]
exn
.
args
=
(
exn
.
args
[
0
]
+
" @ "
+
str
(
fileobj
),
)
+
exn
.
args
[
1
:]
if
handler
(
exn
):
continue
else
:
break
valid
=
False
# logging.warning('error to parse {}'.format(name))
# logging.warning('error to parse {}'.format(name))
prev_prefix
=
prefix
if
prev_prefix
is
not
None
:
example
[
'fname'
]
=
prev_prefix
yield
example
stream
.
close
()
def
tar_file_expander
(
data
,
handler
=
reraise_exception
):
"""Expand a stream of open tar files into a stream of tar file contents.
...
...
@@ -200,9 +201,8 @@ def tar_file_expander(data, handler=reraise_exception):
assert
isinstance
(
source
,
dict
)
assert
"stream"
in
source
for
sample
in
tar_file_iterator
(
source
[
"stream"
]):
assert
(
isinstance
(
sample
,
dict
)
and
"data"
in
sample
and
"fname"
in
sample
)
assert
(
isinstance
(
sample
,
dict
)
and
"data"
in
sample
and
"fname"
in
sample
)
sample
[
"__url__"
]
=
url
yield
sample
except
Exception
as
exn
:
...
...
@@ -213,8 +213,6 @@ def tar_file_expander(data, handler=reraise_exception):
break
def
tar_file_and_group_expander
(
data
,
handler
=
reraise_exception
):
"""Expand a stream of open tar files into a stream of tar file contents.
...
...
@@ -226,9 +224,8 @@ def tar_file_and_group_expander(data, handler=reraise_exception):
assert
isinstance
(
source
,
dict
)
assert
"stream"
in
source
for
sample
in
tar_file_and_group_iterator
(
source
[
"stream"
]):
assert
(
isinstance
(
sample
,
dict
)
and
"wav"
in
sample
and
"txt"
in
sample
and
"fname"
in
sample
)
assert
(
isinstance
(
sample
,
dict
)
and
"wav"
in
sample
and
"txt"
in
sample
and
"fname"
in
sample
)
sample
[
"__url__"
]
=
url
yield
sample
except
Exception
as
exn
:
...
...
@@ -239,7 +236,11 @@ def tar_file_and_group_expander(data, handler=reraise_exception):
break
def
group_by_keys
(
data
,
keys
=
base_plus_ext
,
lcase
=
True
,
suffixes
=
None
,
handler
=
None
):
def
group_by_keys
(
data
,
keys
=
base_plus_ext
,
lcase
=
True
,
suffixes
=
None
,
handler
=
None
):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
...
...
@@ -254,8 +255,8 @@ def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=N
print
(
prefix
,
suffix
,
current_sample
.
keys
()
if
isinstance
(
current_sample
,
dict
)
else
None
,
)
current_sample
.
keys
()
if
isinstance
(
current_sample
,
dict
)
else
None
,
)
if
prefix
is
None
:
continue
if
lcase
:
...
...
paddlespeech/audio/streamdata/utils.py
浏览文件 @
795eb7bd
...
...
@@ -4,22 +4,23 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions."""
import
importlib
import
itertools
as
itt
import
os
import
re
import
sys
from
typing
import
Any
,
Callable
,
Iterator
,
Optional
,
Union
from
typing
import
Any
from
typing
import
Callable
from
typing
import
Iterator
from
typing
import
Union
from
..utils.log
import
Logger
logger
=
Logger
(
__name__
)
def
make_seed
(
*
args
):
seed
=
0
for
arg
in
args
:
...
...
@@ -37,7 +38,7 @@ def identity(x: Any) -> Any:
return
x
def
safe_eval
(
s
:
str
,
expr
:
str
=
"{}"
):
def
safe_eval
(
s
:
str
,
expr
:
str
=
"{}"
):
"""Evaluate the given expression more safely."""
if
re
.
sub
(
"[^A-Za-z0-9_]"
,
""
,
s
)
!=
s
:
raise
ValueError
(
f
"safe_eval: illegal characters in: '
{
s
}
'"
)
...
...
@@ -54,9 +55,9 @@ def lookup_sym(sym: str, modules: list):
return
None
def
repeatedly0
(
loader
:
Iterator
,
nepochs
:
int
=
sys
.
maxsize
,
nbatches
:
int
=
sys
.
maxsize
):
def
repeatedly0
(
loader
:
Iterator
,
nepochs
:
int
=
sys
.
maxsize
,
nbatches
:
int
=
sys
.
maxsize
):
"""Repeatedly returns batches from a DataLoader."""
for
epoch
in
range
(
nepochs
):
for
sample
in
itt
.
islice
(
loader
,
nbatches
):
...
...
@@ -69,12 +70,11 @@ def guess_batchsize(batch: Union[tuple, list]):
def
repeatedly
(
source
:
Iterator
,
nepochs
:
int
=
None
,
nbatches
:
int
=
None
,
nsamples
:
int
=
None
,
batchsize
:
Callable
[...,
int
]
=
guess_batchsize
,
):
source
:
Iterator
,
nepochs
:
int
=
None
,
nbatches
:
int
=
None
,
nsamples
:
int
=
None
,
batchsize
:
Callable
[...,
int
]
=
guess_batchsize
,
):
"""Repeatedly yield samples from an iterator."""
epoch
=
0
batch
=
0
...
...
@@ -93,6 +93,7 @@ def repeatedly(
if
nepochs
is
not
None
and
epoch
>=
nepochs
:
return
def
paddle_worker_info
(
group
=
None
):
"""Return node and worker info for PyTorch and some distributed environments."""
rank
=
0
...
...
@@ -116,7 +117,7 @@ def paddle_worker_info(group=None):
else
:
try
:
from
paddle.io
import
get_worker_info
worker_info
=
paddle
.
io
.
get_worker_info
()
worker_info
=
get_worker_info
()
if
worker_info
is
not
None
:
worker
=
worker_info
.
id
num_workers
=
worker_info
.
num_workers
...
...
@@ -126,6 +127,7 @@ def paddle_worker_info(group=None):
return
rank
,
world_size
,
worker
,
num_workers
def
paddle_worker_seed
(
group
=
None
):
"""Compute a distinct, deterministic RNG seed for each worker and node."""
rank
,
world_size
,
worker
,
num_workers
=
paddle_worker_info
(
group
=
group
)
...
...
paddlespeech/audio/streamdata/writer.py
浏览文件 @
795eb7bd
...
...
@@ -5,18 +5,24 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes and functions for writing tar files and WebDataset files."""
import
io
,
json
,
pickle
,
re
,
tarfile
,
time
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
io
import
json
import
pickle
import
re
import
tarfile
import
time
from
typing
import
Any
from
typing
import
Callable
from
typing
import
Optional
from
typing
import
Union
import
numpy
as
np
from
.
import
gopen
def
imageencoder
(
image
:
Any
,
format
:
str
=
"PNG"
):
# skipcq: PYL-W0622
def
imageencoder
(
image
:
Any
,
format
:
str
=
"PNG"
):
# skipcq: PYL-W0622
"""Compress an image using PIL and return it as a string.
Can handle float or uint8 images.
...
...
@@ -67,6 +73,7 @@ def bytestr(data: Any):
return
data
.
encode
(
"ascii"
)
return
str
(
data
).
encode
(
"ascii"
)
def
paddle_dumps
(
data
:
Any
):
"""Dump data into a bytestring using paddle.dumps.
...
...
@@ -82,6 +89,7 @@ def paddle_dumps(data: Any):
paddle
.
save
(
data
,
stream
)
return
stream
.
getvalue
()
def
numpy_dumps
(
data
:
np
.
ndarray
):
"""Dump data into a bytestring using numpy npy format.
...
...
@@ -139,9 +147,8 @@ def add_handlers(d, keys, value):
def
make_handlers
():
"""Create a list of handlers for encoding data."""
handlers
=
{}
add_handlers
(
handlers
,
"cls cls2 class count index inx id"
,
lambda
x
:
str
(
x
).
encode
(
"ascii"
)
)
add_handlers
(
handlers
,
"cls cls2 class count index inx id"
,
lambda
x
:
str
(
x
).
encode
(
"ascii"
))
add_handlers
(
handlers
,
"txt text transcript"
,
lambda
x
:
x
.
encode
(
"utf-8"
))
add_handlers
(
handlers
,
"html htm"
,
lambda
x
:
x
.
encode
(
"utf-8"
))
add_handlers
(
handlers
,
"pyd pickle"
,
pickle
.
dumps
)
...
...
@@ -152,7 +159,8 @@ def make_handlers():
add_handlers
(
handlers
,
"json jsn"
,
lambda
x
:
json
.
dumps
(
x
).
encode
(
"utf-8"
))
add_handlers
(
handlers
,
"mp msgpack msg"
,
mp_dumps
)
add_handlers
(
handlers
,
"cbor"
,
cbor_dumps
)
add_handlers
(
handlers
,
"jpg jpeg img image"
,
lambda
data
:
imageencoder
(
data
,
"jpg"
))
add_handlers
(
handlers
,
"jpg jpeg img image"
,
lambda
data
:
imageencoder
(
data
,
"jpg"
))
add_handlers
(
handlers
,
"png"
,
lambda
data
:
imageencoder
(
data
,
"png"
))
add_handlers
(
handlers
,
"pbm"
,
lambda
data
:
imageencoder
(
data
,
"pbm"
))
add_handlers
(
handlers
,
"pgm"
,
lambda
data
:
imageencoder
(
data
,
"pgm"
))
...
...
@@ -192,7 +200,8 @@ def encode_based_on_extension(sample: dict, handlers: dict):
:param handlers: handlers for encoding
"""
return
{
k
:
encode_based_on_extension1
(
v
,
k
,
handlers
)
for
k
,
v
in
list
(
sample
.
items
())
k
:
encode_based_on_extension1
(
v
,
k
,
handlers
)
for
k
,
v
in
list
(
sample
.
items
())
}
...
...
@@ -258,15 +267,14 @@ class TarWriter:
"""
def
__init__
(
self
,
fileobj
,
user
:
str
=
"bigdata"
,
group
:
str
=
"bigdata"
,
mode
:
int
=
0o0444
,
compress
:
Optional
[
bool
]
=
None
,
encoder
:
Union
[
None
,
bool
,
Callable
]
=
True
,
keep_meta
:
bool
=
False
,
):
self
,
fileobj
,
user
:
str
=
"bigdata"
,
group
:
str
=
"bigdata"
,
mode
:
int
=
0o0444
,
compress
:
Optional
[
bool
]
=
None
,
encoder
:
Union
[
None
,
bool
,
Callable
]
=
True
,
keep_meta
:
bool
=
False
,
):
"""Create a tar writer.
:param fileobj: stream to write data to
...
...
@@ -330,8 +338,7 @@ class TarWriter:
continue
if
not
isinstance
(
v
,
(
bytes
,
bytearray
,
memoryview
)):
raise
ValueError
(
f
"
{
k
}
doesn't map to a bytes after encoding (
{
type
(
v
)
}
)"
)
f
"
{
k
}
doesn't map to a bytes after encoding (
{
type
(
v
)
}
)"
)
key
=
obj
[
"__key__"
]
for
k
in
sorted
(
obj
.
keys
()):
if
k
==
"__key__"
:
...
...
@@ -349,7 +356,8 @@ class TarWriter:
ti
.
uname
=
self
.
user
ti
.
gname
=
self
.
group
if
not
isinstance
(
v
,
(
bytes
,
bytearray
,
memoryview
)):
raise
ValueError
(
f
"converter didn't yield bytes:
{
k
}
,
{
type
(
v
)
}
"
)
raise
ValueError
(
f
"converter didn't yield bytes:
{
k
}
,
{
type
(
v
)
}
"
)
stream
=
io
.
BytesIO
(
v
)
self
.
tarstream
.
addfile
(
ti
,
stream
)
total
+=
ti
.
size
...
...
@@ -360,14 +368,13 @@ class ShardWriter:
"""Like TarWriter but splits into multiple shards."""
def
__init__
(
self
,
pattern
:
str
,
maxcount
:
int
=
100000
,
maxsize
:
float
=
3e9
,
post
:
Optional
[
Callable
]
=
None
,
start_shard
:
int
=
0
,
**
kw
,
):
self
,
pattern
:
str
,
maxcount
:
int
=
100000
,
maxsize
:
float
=
3e9
,
post
:
Optional
[
Callable
]
=
None
,
start_shard
:
int
=
0
,
**
kw
,
):
"""Create a ShardWriter.
:param pattern: output file pattern
...
...
@@ -400,8 +407,7 @@ class ShardWriter:
self
.
fname
,
self
.
count
,
"%.1f GB"
%
(
self
.
size
/
1e9
),
self
.
total
,
)
self
.
total
,
)
self
.
shard
+=
1
stream
=
open
(
self
.
fname
,
"wb"
)
self
.
tarstream
=
TarWriter
(
stream
,
**
self
.
kw
)
...
...
@@ -413,11 +419,8 @@ class ShardWriter:
:param obj: sample to be written
"""
if
(
self
.
tarstream
is
None
or
self
.
count
>=
self
.
maxcount
or
self
.
size
>=
self
.
maxsize
):
if
(
self
.
tarstream
is
None
or
self
.
count
>=
self
.
maxcount
or
self
.
size
>=
self
.
maxsize
):
self
.
next_stream
()
size
=
self
.
tarstream
.
write
(
obj
)
self
.
count
+=
1
...
...
paddlespeech/audio/text/text_featurizer.py
浏览文件 @
795eb7bd
...
...
@@ -17,6 +17,7 @@ from typing import Union
import
sentencepiece
as
spm
from
..utils.log
import
Logger
from
.utility
import
BLANK
from
.utility
import
EOS
from
.utility
import
load_dict
...
...
@@ -24,7 +25,6 @@ from .utility import MASKCTC
from
.utility
import
SOS
from
.utility
import
SPACE
from
.utility
import
UNK
from
..utils.log
import
Logger
logger
=
Logger
(
__name__
)
...
...
paddlespeech/audio/transform/perturb.py
浏览文件 @
795eb7bd
...
...
@@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import
io
import
os
import
h5py
import
librosa
import
numpy
import
numpy
as
np
import
scipy
import
soundfile
import
io
import
os
import
h5py
import
numpy
as
np
class
SoundHDF5File
():
"""Collecting sound files to a HDF5 file
...
...
@@ -109,6 +110,7 @@ class SoundHDF5File():
def
close
(
self
):
self
.
file
.
close
()
class
SpeedPerturbation
():
"""SpeedPerturbation
...
...
@@ -558,4 +560,3 @@ class RIRConvolve():
[
scipy
.
convolve
(
x
,
r
,
mode
=
"same"
)
for
r
in
rir
],
axis
=-
1
)
else
:
return
scipy
.
convolve
(
x
,
rir
,
mode
=
"same"
)
paddlespeech/audio/transform/spec_augment.py
浏览文件 @
795eb7bd
...
...
@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
"""Spec Augment module for preprocessing i.e., data augmentation"""
import
random
import
numpy
from
PIL
import
Image
...
...
paddlespeech/cli/executor.py
浏览文件 @
795eb7bd
...
...
@@ -191,7 +191,7 @@ class BaseExecutor(ABC):
line
=
line
.
strip
()
if
not
line
:
continue
k
,
v
=
line
.
split
()
# space or \t
k
,
v
=
line
.
split
()
# space or \t
job_contents
[
k
]
=
v
return
job_contents
...
...
paddlespeech/s2t/__init__.py
浏览文件 @
795eb7bd
...
...
@@ -114,6 +114,7 @@ if not hasattr(paddle.Tensor, 'new_full'):
paddle
.
Tensor
.
new_full
=
new_full
paddle
.
static
.
Variable
.
new_full
=
new_full
def
contiguous
(
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
return
xs
...
...
paddlespeech/s2t/exps/u2/model.py
浏览文件 @
795eb7bd
...
...
@@ -25,8 +25,6 @@ import paddle
from
paddle
import
distributed
as
dist
from
paddlespeech.s2t.frontend.featurizer
import
TextFeaturizer
from
paddlespeech.s2t.io.dataloader
import
BatchDataLoader
from
paddlespeech.s2t.io.dataloader
import
StreamDataLoader
from
paddlespeech.s2t.io.dataloader
import
DataLoaderFactory
from
paddlespeech.s2t.models.u2
import
U2Model
from
paddlespeech.s2t.training.optimizer
import
OptimizerFactory
...
...
@@ -109,7 +107,8 @@ class U2Trainer(Trainer):
def
valid
(
self
):
self
.
model
.
eval
()
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
total_loss
=
0.0
...
...
@@ -136,7 +135,8 @@ class U2Trainer(Trainer):
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
if
not
self
.
use_streamdata
:
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
', '
.
join
(
'{}: {:>.6f}'
.
format
(
k
,
v
)
for
k
,
v
in
valid_dump
.
items
())
logger
.
info
(
msg
)
...
...
@@ -157,7 +157,8 @@ class U2Trainer(Trainer):
self
.
before_train
()
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
n_epoch
:
with
Timer
(
"Epoch-Train Time Cost: {}"
):
self
.
model
.
train
()
...
...
@@ -225,14 +226,18 @@ class U2Trainer(Trainer):
config
=
self
.
config
.
clone
()
self
.
use_streamdata
=
config
.
get
(
"use_stream_data"
,
False
)
if
self
.
train
:
self
.
train_loader
=
DataLoaderFactory
.
get_dataloader
(
'train'
,
config
,
self
.
args
)
self
.
valid_loader
=
DataLoaderFactory
.
get_dataloader
(
'valid'
,
config
,
self
.
args
)
self
.
train_loader
=
DataLoaderFactory
.
get_dataloader
(
'train'
,
config
,
self
.
args
)
self
.
valid_loader
=
DataLoaderFactory
.
get_dataloader
(
'valid'
,
config
,
self
.
args
)
logger
.
info
(
"Setup train/valid Dataloader!"
)
else
:
decode_batch_size
=
config
.
get
(
'decode'
,
dict
()).
get
(
'decode_batch_size'
,
1
)
self
.
test_loader
=
DataLoaderFactory
.
get_dataloader
(
'test'
,
config
,
self
.
args
)
self
.
align_loader
=
DataLoaderFactory
.
get_dataloader
(
'align'
,
config
,
self
.
args
)
self
.
test_loader
=
DataLoaderFactory
.
get_dataloader
(
'test'
,
config
,
self
.
args
)
self
.
align_loader
=
DataLoaderFactory
.
get_dataloader
(
'align'
,
config
,
self
.
args
)
logger
.
info
(
"Setup test/align Dataloader!"
)
def
setup_model
(
self
):
...
...
paddlespeech/s2t/exps/u2_kaldi/model.py
浏览文件 @
795eb7bd
...
...
@@ -105,7 +105,8 @@ class U2Trainer(Trainer):
def
valid
(
self
):
self
.
model
.
eval
()
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
total_loss
=
0.0
...
...
@@ -133,7 +134,8 @@ class U2Trainer(Trainer):
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
if
not
self
.
use_streamdata
:
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
', '
.
join
(
'{}: {:>.6f}'
.
format
(
k
,
v
)
for
k
,
v
in
valid_dump
.
items
())
logger
.
info
(
msg
)
...
...
@@ -153,7 +155,8 @@ class U2Trainer(Trainer):
self
.
before_train
()
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
n_epoch
:
with
Timer
(
"Epoch-Train Time Cost: {}"
):
self
.
model
.
train
()
...
...
@@ -165,8 +168,8 @@ class U2Trainer(Trainer):
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
if
not
self
.
use_streamdata
:
msg
+=
"batch : {}/{}, "
.
format
(
batch_index
+
1
,
len
(
self
.
train_loader
))
msg
+=
"batch : {}/{}, "
.
format
(
batch_index
+
1
,
len
(
self
.
train_loader
))
msg
+=
"lr: {:>.8f}, "
.
format
(
self
.
lr_scheduler
())
msg
+=
"data time: {:>.3f}s, "
.
format
(
dataload_time
)
self
.
train_batch
(
batch_index
,
batch
,
msg
)
...
...
@@ -204,21 +207,24 @@ class U2Trainer(Trainer):
self
.
use_streamdata
=
config
.
get
(
"use_stream_data"
,
False
)
if
self
.
train
:
config
=
self
.
config
.
clone
()
self
.
train_loader
=
DataLoaderFactory
.
get_dataloader
(
'train'
,
config
,
self
.
args
)
self
.
train_loader
=
DataLoaderFactory
.
get_dataloader
(
'train'
,
config
,
self
.
args
)
config
=
self
.
config
.
clone
()
config
[
'preprocess_config'
]
=
None
self
.
valid_loader
=
DataLoaderFactory
.
get_dataloader
(
'valid'
,
config
,
self
.
args
)
self
.
valid_loader
=
DataLoaderFactory
.
get_dataloader
(
'valid'
,
config
,
self
.
args
)
logger
.
info
(
"Setup train/valid Dataloader!"
)
else
:
config
=
self
.
config
.
clone
()
config
[
'preprocess_config'
]
=
None
self
.
test_loader
=
DataLoaderFactory
.
get_dataloader
(
'test'
,
config
,
self
.
args
)
self
.
test_loader
=
DataLoaderFactory
.
get_dataloader
(
'test'
,
config
,
self
.
args
)
config
=
self
.
config
.
clone
()
config
[
'preprocess_config'
]
=
None
self
.
align_loader
=
DataLoaderFactory
.
get_dataloader
(
'align'
,
config
,
self
.
args
)
self
.
align_loader
=
DataLoaderFactory
.
get_dataloader
(
'align'
,
config
,
self
.
args
)
logger
.
info
(
"Setup test/align Dataloader!"
)
def
setup_model
(
self
):
config
=
self
.
config
...
...
paddlespeech/s2t/exps/u2_st/model.py
浏览文件 @
795eb7bd
...
...
@@ -121,7 +121,8 @@ class U2STTrainer(Trainer):
def
valid
(
self
):
self
.
model
.
eval
()
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
total_loss
=
0.0
...
...
@@ -155,7 +156,8 @@ class U2STTrainer(Trainer):
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
if
not
self
.
use_streamdata
:
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
"batch: {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
', '
.
join
(
'{}: {:>.6f}'
.
format
(
k
,
v
)
for
k
,
v
in
valid_dump
.
items
())
logger
.
info
(
msg
)
...
...
@@ -175,7 +177,8 @@ class U2STTrainer(Trainer):
self
.
before_train
()
if
not
self
.
use_streamdata
:
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
logger
.
info
(
f
"Train Total Examples:
{
len
(
self
.
train_loader
.
dataset
)
}
"
)
while
self
.
epoch
<
self
.
config
.
n_epoch
:
with
Timer
(
"Epoch-Train Time Cost: {}"
):
self
.
model
.
train
()
...
...
@@ -248,14 +251,16 @@ class U2STTrainer(Trainer):
config
[
'load_transcript'
]
=
load_transcript
self
.
use_streamdata
=
config
.
get
(
"use_stream_data"
,
False
)
if
self
.
train
:
self
.
train_loader
=
DataLoaderFactory
.
get_dataloader
(
'train'
,
config
,
self
.
args
)
self
.
valid_loader
=
DataLoaderFactory
.
get_dataloader
(
'valid'
,
config
,
self
.
args
)
self
.
train_loader
=
DataLoaderFactory
.
get_dataloader
(
'train'
,
config
,
self
.
args
)
self
.
valid_loader
=
DataLoaderFactory
.
get_dataloader
(
'valid'
,
config
,
self
.
args
)
logger
.
info
(
"Setup train/valid Dataloader!"
)
else
:
self
.
test_loader
=
DataLoaderFactory
.
get_dataloader
(
'test'
,
config
,
self
.
args
)
self
.
test_loader
=
DataLoaderFactory
.
get_dataloader
(
'test'
,
config
,
self
.
args
)
logger
.
info
(
"Setup test Dataloader!"
)
def
setup_model
(
self
):
config
=
self
.
config
model_conf
=
config
...
...
paddlespeech/s2t/io/dataloader.py
浏览文件 @
795eb7bd
...
...
@@ -22,17 +22,16 @@ import paddle
from
paddle.io
import
BatchSampler
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
from
yacs.config
import
CfgNode
import
paddlespeech.audio.streamdata
as
streamdata
from
paddlespeech.audio.text.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.io.batchfy
import
make_batchset
from
paddlespeech.s2t.io.converter
import
CustomConverter
from
paddlespeech.s2t.io.dataset
import
TransformDataset
from
paddlespeech.s2t.io.reader
import
LoadInputsAndTargets
from
paddlespeech.s2t.utils.log
import
Log
import
paddlespeech.audio.streamdata
as
streamdata
from
paddlespeech.audio.text.text_featurizer
import
TextFeaturizer
from
yacs.config
import
CfgNode
__all__
=
[
"BatchDataLoader"
,
"StreamDataLoader"
]
logger
=
Log
(
__name__
).
getlog
()
...
...
@@ -61,6 +60,7 @@ def batch_collate(x):
"""
return
x
[
0
]
def
read_preprocess_cfg
(
preprocess_conf_file
):
augment_conf
=
dict
()
preprocess_cfg
=
CfgNode
(
new_allowed
=
True
)
...
...
@@ -82,7 +82,8 @@ def read_preprocess_cfg(preprocess_conf_file):
augment_conf
[
'num_t_mask'
]
=
process
[
'n_mask'
]
augment_conf
[
't_inplace'
]
=
process
[
'inplace'
]
augment_conf
[
't_replace_with_zero'
]
=
process
[
'replace_with_zero'
]
return
augment_conf
return
augment_conf
class
StreamDataLoader
():
def
__init__
(
self
,
...
...
@@ -95,12 +96,12 @@ class StreamDataLoader():
frame_length
=
25
,
frame_shift
=
10
,
dither
=
0.0
,
minlen_in
:
float
=
0.0
,
minlen_in
:
float
=
0.0
,
maxlen_in
:
float
=
float
(
'inf'
),
minlen_out
:
float
=
0.0
,
maxlen_out
:
float
=
float
(
'inf'
),
resample_rate
:
int
=
16000
,
shuffle_size
:
int
=
10000
,
shuffle_size
:
int
=
10000
,
sort_size
:
int
=
1000
,
n_iter_processes
:
int
=
1
,
prefetch_factor
:
int
=
2
,
...
...
@@ -116,11 +117,11 @@ class StreamDataLoader():
text_featurizer
=
TextFeaturizer
(
unit_type
,
vocab_filepath
)
symbol_table
=
text_featurizer
.
vocab_dict
self
.
feat_dim
=
num_mel_bins
self
.
vocab_size
=
text_featurizer
.
vocab_size
self
.
feat_dim
=
num_mel_bins
self
.
vocab_size
=
text_featurizer
.
vocab_size
augment_conf
=
read_preprocess_cfg
(
preprocess_conf
)
# The list of shard
shardlist
=
[]
with
open
(
manifest_file
,
"r"
)
as
f
:
...
...
@@ -128,58 +129,68 @@ class StreamDataLoader():
shardlist
.
append
(
line
.
strip
())
world_size
=
1
try
:
world_size
=
paddle
.
distributed
.
get_world_size
()
world_size
=
paddle
.
distributed
.
get_world_size
()
except
Exception
as
e
:
logger
.
warninig
(
e
)
logger
.
warninig
(
"can not get world_size using paddle.distributed.get_world_size(), use world_size=1"
)
assert
(
len
(
shardlist
)
>=
world_size
,
"the length of shard list should >= number of gpus/xpus/..."
)
logger
.
warninig
(
"can not get world_size using paddle.distributed.get_world_size(), use world_size=1"
)
assert
len
(
shardlist
)
>=
world_size
,
\
"the length of shard list should >= number of gpus/xpus/..."
update_n_iter_processes
=
int
(
max
(
min
(
len
(
shardlist
)
/
world_size
-
1
,
self
.
n_iter_processes
),
0
))
update_n_iter_processes
=
int
(
max
(
min
(
len
(
shardlist
)
/
world_size
-
1
,
self
.
n_iter_processes
),
0
))
logger
.
info
(
f
"update_n_iter_processes
{
update_n_iter_processes
}
"
)
if
update_n_iter_processes
!=
self
.
n_iter_processes
:
self
.
n_iter_processes
=
update_n_iter_processes
self
.
n_iter_processes
=
update_n_iter_processes
logger
.
info
(
f
"change nun_workers to
{
self
.
n_iter_processes
}
"
)
if
self
.
dist_sampler
:
base_dataset
=
streamdata
.
DataPipeline
(
streamdata
.
SimpleShardList
(
shardlist
),
streamdata
.
split_by_node
if
train_mode
else
streamdata
.
placeholder
(),
streamdata
.
SimpleShardList
(
shardlist
),
streamdata
.
split_by_node
if
train_mode
else
streamdata
.
placeholder
(),
streamdata
.
split_by_worker
,
streamdata
.
tarfile_to_samples
(
streamdata
.
reraise_exception
)
)
streamdata
.
tarfile_to_samples
(
streamdata
.
reraise_exception
))
else
:
base_dataset
=
streamdata
.
DataPipeline
(
streamdata
.
SimpleShardList
(
shardlist
),
streamdata
.
split_by_worker
,
streamdata
.
tarfile_to_samples
(
streamdata
.
reraise_exception
)
)
streamdata
.
tarfile_to_samples
(
streamdata
.
reraise_exception
))
self
.
dataset
=
base_dataset
.
append_list
(
streamdata
.
audio_tokenize
(
symbol_table
),
streamdata
.
audio_data_filter
(
frame_shift
=
frame_shift
,
max_length
=
maxlen_in
,
min_length
=
minlen_in
,
token_max_length
=
maxlen_out
,
token_min_length
=
minlen_out
),
streamdata
.
audio_data_filter
(
frame_shift
=
frame_shift
,
max_length
=
maxlen_in
,
min_length
=
minlen_in
,
token_max_length
=
maxlen_out
,
token_min_length
=
minlen_out
),
streamdata
.
audio_resample
(
resample_rate
=
resample_rate
),
streamdata
.
audio_compute_fbank
(
num_mel_bins
=
num_mel_bins
,
frame_length
=
frame_length
,
frame_shift
=
frame_shift
,
dither
=
dither
),
streamdata
.
audio_spec_aug
(
**
augment_conf
)
if
train_mode
else
streamdata
.
placeholder
(),
# num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
streamdata
.
audio_compute_fbank
(
num_mel_bins
=
num_mel_bins
,
frame_length
=
frame_length
,
frame_shift
=
frame_shift
,
dither
=
dither
),
streamdata
.
audio_spec_aug
(
**
augment_conf
)
if
train_mode
else
streamdata
.
placeholder
(
),
# num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
streamdata
.
shuffle
(
shuffle_size
),
streamdata
.
sort
(
sort_size
=
sort_size
),
streamdata
.
batched
(
batch_size
),
streamdata
.
audio_padding
(),
streamdata
.
audio_cmvn
(
cmvn_file
)
)
streamdata
.
audio_cmvn
(
cmvn_file
))
if
paddle
.
__version__
>=
'2.3.2'
:
self
.
loader
=
streamdata
.
WebLoader
(
self
.
dataset
,
num_workers
=
self
.
n_iter_processes
,
prefetch_factor
=
self
.
prefetch_factor
,
batch_size
=
None
)
self
.
dataset
,
num_workers
=
self
.
n_iter_processes
,
prefetch_factor
=
self
.
prefetch_factor
,
batch_size
=
None
)
else
:
self
.
loader
=
streamdata
.
WebLoader
(
self
.
dataset
,
num_workers
=
self
.
n_iter_processes
,
batch_size
=
None
)
self
.
dataset
,
num_workers
=
self
.
n_iter_processes
,
batch_size
=
None
)
def
__iter__
(
self
):
return
self
.
loader
.
__iter__
()
...
...
@@ -188,7 +199,9 @@ class StreamDataLoader():
return
self
.
__iter__
()
def
__len__
(
self
):
logger
.
info
(
"Stream dataloader does not support calculate the length of the dataset"
)
logger
.
info
(
"Stream dataloader does not support calculate the length of the dataset"
)
return
-
1
...
...
@@ -347,7 +360,7 @@ class DataLoaderFactory():
config
[
'train_mode'
]
=
True
elif
mode
==
'valid'
:
config
[
'manifest'
]
=
config
.
dev_manifest
config
[
'train_mode'
]
=
False
config
[
'train_mode'
]
=
False
elif
model
==
'test'
or
mode
==
'align'
:
config
[
'manifest'
]
=
config
.
test_manifest
config
[
'train_mode'
]
=
False
...
...
@@ -358,30 +371,31 @@ class DataLoaderFactory():
config
[
'maxlen_out'
]
=
float
(
'inf'
)
config
[
'dist_sampler'
]
=
False
else
:
raise
KeyError
(
"not valid mode type!!, please input one of 'train, valid, test, align'"
)
return
StreamDataLoader
(
manifest_file
=
config
.
manifest
,
train_mode
=
config
.
train_mode
,
unit_type
=
config
.
unit_type
,
preprocess_conf
=
config
.
preprocess_config
,
batch_size
=
config
.
batch_size
,
num_mel_bins
=
config
.
feat_dim
,
frame_length
=
config
.
window_ms
,
frame_shift
=
config
.
stride_ms
,
dither
=
config
.
dither
,
minlen_in
=
config
.
minlen_in
,
maxlen_in
=
config
.
maxlen_in
,
minlen_out
=
config
.
minlen_out
,
maxlen_out
=
config
.
maxlen_out
,
resample_rate
=
config
.
resample_rate
,
shuffle_size
=
config
.
shuffle_size
,
sort_size
=
config
.
sort_size
,
n_iter_processes
=
config
.
num_workers
,
prefetch_factor
=
config
.
prefetch_factor
,
dist_sampler
=
config
.
dist_sampler
,
cmvn_file
=
config
.
cmvn_file
,
vocab_filepath
=
config
.
vocab_filepath
,
raise
KeyError
(
"not valid mode type!!, please input one of 'train, valid, test, align'"
)
return
StreamDataLoader
(
manifest_file
=
config
.
manifest
,
train_mode
=
config
.
train_mode
,
unit_type
=
config
.
unit_type
,
preprocess_conf
=
config
.
preprocess_config
,
batch_size
=
config
.
batch_size
,
num_mel_bins
=
config
.
feat_dim
,
frame_length
=
config
.
window_ms
,
frame_shift
=
config
.
stride_ms
,
dither
=
config
.
dither
,
minlen_in
=
config
.
minlen_in
,
maxlen_in
=
config
.
maxlen_in
,
minlen_out
=
config
.
minlen_out
,
maxlen_out
=
config
.
maxlen_out
,
resample_rate
=
config
.
resample_rate
,
shuffle_size
=
config
.
shuffle_size
,
sort_size
=
config
.
sort_size
,
n_iter_processes
=
config
.
num_workers
,
prefetch_factor
=
config
.
prefetch_factor
,
dist_sampler
=
config
.
dist_sampler
,
cmvn_file
=
config
.
cmvn_file
,
vocab_filepath
=
config
.
vocab_filepath
,
)
else
:
if
mode
==
'train'
:
config
[
'manifest'
]
=
config
.
train_manifest
...
...
@@ -411,7 +425,7 @@ class DataLoaderFactory():
config
[
'train_mode'
]
=
False
config
[
'sortagrad'
]
=
False
config
[
'batch_size'
]
=
config
.
get
(
'decode'
,
dict
()).
get
(
'decode_batch_size'
,
1
)
'decode_batch_size'
,
1
)
config
[
'maxlen_in'
]
=
float
(
'inf'
)
config
[
'maxlen_out'
]
=
float
(
'inf'
)
config
[
'minibatches'
]
=
0
...
...
@@ -427,8 +441,10 @@ class DataLoaderFactory():
config
[
'dist_sampler'
]
=
False
config
[
'shortest_first'
]
=
False
else
:
raise
KeyError
(
"not valid mode type!!, please input one of 'train, valid, test, align'"
)
raise
KeyError
(
"not valid mode type!!, please input one of 'train, valid, test, align'"
)
return
BatchDataLoader
(
json_file
=
config
.
manifest
,
train_mode
=
config
.
train_mode
,
...
...
@@ -450,4 +466,3 @@ class DataLoaderFactory():
num_encs
=
config
.
num_encs
,
dist_sampler
=
config
.
dist_sampler
,
shortest_first
=
config
.
shortest_first
)
paddlespeech/s2t/models/u2_st/u2_st.py
浏览文件 @
795eb7bd
...
...
@@ -18,7 +18,6 @@ Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recogni
"""
import
time
from
typing
import
Dict
from
typing
import
List
from
typing
import
Optional
from
typing
import
Tuple
...
...
@@ -26,6 +25,8 @@ import paddle
from
paddle
import
jit
from
paddle
import
nn
from
paddlespeech.audio.utils.tensor_utils
import
add_sos_eos
from
paddlespeech.audio.utils.tensor_utils
import
th_accuracy
from
paddlespeech.s2t.frontend.utility
import
IGNORE_ID
from
paddlespeech.s2t.frontend.utility
import
load_cmvn
from
paddlespeech.s2t.modules.cmvn
import
GlobalCMVN
...
...
@@ -38,8 +39,6 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
from
paddlespeech.s2t.utils
import
checkpoint
from
paddlespeech.s2t.utils
import
layer_tools
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.audio.utils.tensor_utils
import
add_sos_eos
from
paddlespeech.audio.utils.tensor_utils
import
th_accuracy
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
__all__
=
[
"U2STModel"
,
"U2STInferModel"
]
...
...
@@ -401,8 +400,8 @@ class U2STBaseModel(nn.Layer):
xs
:
paddle
.
Tensor
,
offset
:
int
,
required_cache_size
:
int
,
att_cache
:
paddle
.
Tensor
=
paddle
.
zeros
([
0
,
0
,
0
,
0
]),
cnn_cache
:
paddle
.
Tensor
=
paddle
.
zeros
([
0
,
0
,
0
,
0
]),
att_cache
:
paddle
.
Tensor
=
paddle
.
zeros
([
0
,
0
,
0
,
0
]),
cnn_cache
:
paddle
.
Tensor
=
paddle
.
zeros
([
0
,
0
,
0
,
0
]),
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
,
paddle
.
Tensor
]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
...
...
@@ -435,8 +434,8 @@ class U2STBaseModel(nn.Layer):
paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
return
self
.
encoder
.
forward_chunk
(
xs
,
offset
,
required_cache_size
,
att_cache
,
cnn_cache
)
return
self
.
encoder
.
forward_chunk
(
xs
,
offset
,
required_cache_size
,
att_cache
,
cnn_cache
)
# @jit.to_static
def
ctc_activation
(
self
,
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
...
...
paddlespeech/s2t/modules/align.py
浏览文件 @
795eb7bd
...
...
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
paddle
from
paddle
import
nn
import
math
"""
To align the initializer between paddle and torch,
the API below are set defalut initializer with priority higger than global initializer.
...
...
@@ -81,10 +82,18 @@ class Linear(nn.Linear):
name
=
None
):
if
weight_attr
is
None
:
if
global_init_type
==
"kaiming_uniform"
:
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
if
bias_attr
is
None
:
if
global_init_type
==
"kaiming_uniform"
:
bias_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
bias_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
super
(
Linear
,
self
).
__init__
(
in_features
,
out_features
,
weight_attr
,
bias_attr
,
name
)
...
...
@@ -104,10 +113,18 @@ class Conv1D(nn.Conv1D):
data_format
=
'NCL'
):
if
weight_attr
is
None
:
if
global_init_type
==
"kaiming_uniform"
:
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
if
bias_attr
is
None
:
if
global_init_type
==
"kaiming_uniform"
:
bias_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
bias_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
super
(
Conv1D
,
self
).
__init__
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
groups
,
padding_mode
,
weight_attr
,
bias_attr
,
data_format
)
...
...
@@ -128,10 +145,18 @@ class Conv2D(nn.Conv2D):
data_format
=
'NCHW'
):
if
weight_attr
is
None
:
if
global_init_type
==
"kaiming_uniform"
:
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
if
bias_attr
is
None
:
if
global_init_type
==
"kaiming_uniform"
:
bias_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
bias_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
(
fan_in
=
None
,
negative_slope
=
math
.
sqrt
(
5
),
nonlinearity
=
'leaky_relu'
))
super
(
Conv2D
,
self
).
__init__
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
,
groups
,
padding_mode
,
weight_attr
,
bias_attr
,
data_format
)
paddlespeech/s2t/modules/initializer.py
浏览文件 @
795eb7bd
...
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
class
DefaultInitializerContext
(
object
):
"""
...
...
paddlespeech/server/engine/asr/online/ctc_endpoint.py
浏览文件 @
795eb7bd
...
...
@@ -102,8 +102,10 @@ class OnlineCTCEndpoint:
assert
self
.
num_frames_decoded
>=
self
.
trailing_silence_frames
assert
self
.
frame_shift_in_ms
>
0
decoding_something
=
(
self
.
num_frames_decoded
>
self
.
trailing_silence_frames
)
and
decoding_something
decoding_something
=
(
self
.
num_frames_decoded
>
self
.
trailing_silence_frames
)
and
decoding_something
utterance_length
=
self
.
num_frames_decoded
*
self
.
frame_shift_in_ms
trailing_silence
=
self
.
trailing_silence_frames
*
self
.
frame_shift_in_ms
...
...
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
浏览文件 @
795eb7bd
...
...
@@ -21,12 +21,12 @@ import paddle
from
numpy
import
float32
from
yacs.config
import
CfgNode
from
paddlespeech.audio.transform.transformation
import
Transformation
from
paddlespeech.cli.asr.infer
import
ASRExecutor
from
paddlespeech.cli.log
import
logger
from
paddlespeech.resource
import
CommonTaskResource
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.modules.ctc
import
CTCDecoder
from
paddlespeech.audio.transform.transformation
import
Transformation
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.utils
import
onnx_infer
...
...
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py
浏览文件 @
795eb7bd
...
...
@@ -21,10 +21,10 @@ import paddle
from
numpy
import
float32
from
yacs.config
import
CfgNode
from
paddlespeech.audio.transform.transformation
import
Transformation
from
paddlespeech.cli.asr.infer
import
ASRExecutor
from
paddlespeech.cli.log
import
logger
from
paddlespeech.resource
import
CommonTaskResource
from
paddlespeech.audio.transform.transformation
import
Transformation
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.modules.ctc
import
CTCDecoder
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
...
...
paddlespeech/server/engine/asr/python/asr_engine.py
浏览文件 @
795eb7bd
...
...
@@ -66,12 +66,14 @@ class ASREngine(BaseEngine):
)
logger
.
error
(
e
)
return
False
self
.
executor
.
_init_from_path
(
model_type
=
self
.
config
.
model
,
lang
=
self
.
config
.
lang
,
sample_rate
=
self
.
config
.
sample_rate
,
cfg_path
=
self
.
config
.
cfg_path
,
decode_method
=
self
.
config
.
decode_method
,
ckpt_path
=
self
.
config
.
ckpt_path
)
self
.
executor
.
_init_from_path
(
model_type
=
self
.
config
.
model
,
lang
=
self
.
config
.
lang
,
sample_rate
=
self
.
config
.
sample_rate
,
cfg_path
=
self
.
config
.
cfg_path
,
decode_method
=
self
.
config
.
decode_method
,
ckpt_path
=
self
.
config
.
ckpt_path
)
logger
.
info
(
"Initialize ASR server engine successfully on device: %s."
%
(
self
.
device
))
...
...
paddlespeech/t2s/datasets/sampler.py
浏览文件 @
795eb7bd
import
paddle
import
math
import
numpy
as
np
from
paddle.io
import
BatchSampler
class
ErnieSATSampler
(
BatchSampler
):
"""Sampler that restricts data loading to a subset of the dataset.
In such case, each process can pass a DistributedBatchSampler instance
...
...
@@ -110,8 +111,8 @@ class ErnieSATSampler(BatchSampler):
subsampled_indices
.
extend
(
indices
[
i
:
i
+
self
.
batch_size
])
indices
=
indices
[
len
(
indices
)
-
last_batch_size
:]
subsampled_indices
.
extend
(
indices
[
self
.
local_rank
*
last_local_batch_size
:(
subsampled_indices
.
extend
(
indices
[
self
.
local_rank
*
last_local_batch_size
:(
self
.
local_rank
+
1
)
*
last_local_batch_size
])
return
subsampled_indices
...
...
paddlespeech/t2s/exps/ernie_sat/train.py
浏览文件 @
795eb7bd
...
...
@@ -25,7 +25,6 @@ from paddle import DataParallel
from
paddle
import
distributed
as
dist
from
paddle
import
nn
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
from
paddle.optimizer
import
Adam
from
yacs.config
import
CfgNode
...
...
paddlespeech/t2s/exps/ernie_sat/utils.py
浏览文件 @
795eb7bd
...
...
@@ -11,32 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
hashlib
import
os
from
pathlib
import
Path
from
typing
import
Dict
from
typing
import
List
from
typing
import
Union
import
os
import
numpy
as
np
import
paddle
import
yaml
from
yacs.config
import
CfgNode
import
hashlib
from
paddlespeech.t2s.exps.syn_utils
import
get_am_inference
from
paddlespeech.t2s.exps.syn_utils
import
get_voc_inference
def
_get_user
():
return
os
.
path
.
expanduser
(
'~'
).
split
(
'/'
)[
-
1
]
def str2md5(string):
    """Return the hexadecimal MD5 digest of ``string``.

    Args:
        string (str): text to hash; it is encoded as UTF-8 before hashing.

    Returns:
        str: the digest as a hexadecimal string.
    """
    encoded = string.encode('utf8')
    return hashlib.md5(encoded).hexdigest()
def get_tmp_name(text: str):
    """Build a temporary name of the form ``<user>_<pid>_<md5 of text>``.

    Args:
        text (str): text whose MD5 digest forms the last part of the name.

    Returns:
        str: the value of ``_get_user()``, the current process id and
            ``str2md5(text)``, joined with underscores.
    """
    return '_'.join((_get_user(), str(os.getpid()), str2md5(text)))
def
get_dict
(
dictfile
:
str
):
word2phns_dict
=
{}
with
open
(
dictfile
,
'r'
)
as
fid
:
...
...
paddlespeech/t2s/exps/syn_utils.py
浏览文件 @
795eb7bd
...
...
@@ -298,8 +298,8 @@ def am_to_static(am_inference,
am_name
=
am
[:
am
.
rindex
(
'_'
)]
am_dataset
=
am
[
am
.
rindex
(
'_'
)
+
1
:]
if
am_name
==
'fastspeech2'
:
if
am_dataset
in
{
"aishell3"
,
"vctk"
,
"mix"
}
and
speaker_dict
is
not
None
:
if
am_dataset
in
{
"aishell3"
,
"vctk"
,
"mix"
}
and
speaker_dict
is
not
None
:
am_inference
=
jit
.
to_static
(
am_inference
,
input_spec
=
[
...
...
@@ -311,8 +311,8 @@ def am_to_static(am_inference,
am_inference
,
input_spec
=
[
InputSpec
([
-
1
],
dtype
=
paddle
.
int64
)])
elif
am_name
==
'speedyspeech'
:
if
am_dataset
in
{
"aishell3"
,
"vctk"
,
"mix"
}
and
speaker_dict
is
not
None
:
if
am_dataset
in
{
"aishell3"
,
"vctk"
,
"mix"
}
and
speaker_dict
is
not
None
:
am_inference
=
jit
.
to_static
(
am_inference
,
input_spec
=
[
...
...
paddlespeech/t2s/frontend/g2pw/__init__.py
浏览文件 @
795eb7bd
from
paddlespeech.t2s.frontend.g2pw.onnx_api
import
G2PWOnnxConverter
paddlespeech/t2s/frontend/mix_frontend.py
浏览文件 @
795eb7bd
...
...
@@ -61,8 +61,11 @@ class MixFrontend():
return
False
def is_end(self, before_char, after_char) -> bool:
    """Tell whether both neighbouring characters are alphabetic or a space.

    Args:
        before_char: the character on one side of the position.
        after_char: the character on the other side of the position.

    Returns:
        bool: True only when each of the two characters either satisfies
            ``self.is_alphabet`` or equals a single space; False otherwise.
    """
    # One expression replaces both the original compound condition and the
    # counting loop: every neighbour must qualify.
    return all(
        self.is_alphabet(char) or char == " "
        for char in (before_char, after_char))
...
...
paddlespeech/t2s/training/updaters/standard_updater.py
浏览文件 @
795eb7bd
...
...
@@ -24,10 +24,11 @@ from paddle.nn import Layer
from
paddle.optimizer
import
Optimizer
from
timer
import
timer
from
paddlespeech.t2s.datasets.sampler
import
ErnieSATSampler
from
paddlespeech.t2s.training.reporter
import
report
from
paddlespeech.t2s.training.updater
import
UpdaterBase
from
paddlespeech.t2s.training.updater
import
UpdaterState
from
paddlespeech.t2s.datasets.sampler
import
ErnieSATSampler
class
StandardUpdater
(
UpdaterBase
):
"""An example of over-simplification. Things may not be that simple, but
...
...
setup.py
浏览文件 @
795eb7bd
...
...
@@ -77,12 +77,7 @@ base = [
"pybind11"
,
]
server
=
[
"fastapi"
,
"uvicorn"
,
"pattern_singleton"
,
"websockets"
]
server
=
[
"fastapi"
,
"uvicorn"
,
"pattern_singleton"
,
"websockets"
]
requirements
=
{
"install"
:
...
...
@@ -330,4 +325,4 @@ setup_info = dict(
})
with
version_info
():
setup
(
**
setup_info
,
include_package_data
=
True
)
setup
(
**
setup_info
,
include_package_data
=
True
)
speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py
浏览文件 @
795eb7bd
...
...
@@ -490,18 +490,10 @@ class SymbolicShapeInference:
def
_onnx_infer_single_node
(
self
,
node
):
# skip onnx shape inference for some ops, as they are handled in _infer_*
skip_infer
=
node
.
op_type
in
[
'If'
,
'Loop'
,
'Scan'
,
'SplitToSequence'
,
'ZipMap'
,
\
# contrib ops
'Attention'
,
'BiasGelu'
,
\
'EmbedLayerNormalization'
,
\
'FastGelu'
,
'Gelu'
,
'LayerNormalization'
,
\
'LongformerAttention'
,
\
'SkipLayerNormalization'
,
\
'PythonOp'
'If'
,
'Loop'
,
'Scan'
,
'SplitToSequence'
,
'ZipMap'
,
'Attention'
,
'BiasGelu'
,
'EmbedLayerNormalization'
,
'FastGelu'
,
'Gelu'
,
'LayerNormalization'
,
'LongformerAttention'
,
'SkipLayerNormalization'
,
'PythonOp'
]
if
not
skip_infer
:
...
...
@@ -514,8 +506,8 @@ class SymbolicShapeInference:
if
(
get_opset
(
self
.
out_mp_
)
>=
9
)
and
node
.
op_type
in
[
'Unsqueeze'
]:
initializers
=
[
self
.
initializers_
[
name
]
for
name
in
node
.
input
if
(
name
in
self
.
initializers_
and
name
not
in
self
.
graph_inputs_
)
if
(
name
in
self
.
initializers_
and
name
not
in
self
.
graph_inputs_
)
]
# run single node inference with self.known_vi_ shapes
...
...
@@ -601,8 +593,8 @@ class SymbolicShapeInference:
for
o
in
symbolic_shape_inference
.
out_mp_
.
graph
.
output
]
subgraph_new_symbolic_dims
=
set
([
d
for
s
in
subgraph_shapes
if
s
for
d
in
s
if
type
(
d
)
==
str
and
not
d
in
self
.
symbolic_dims_
d
for
s
in
subgraph_shapes
if
s
for
d
in
s
if
type
(
d
)
==
str
and
not
d
in
self
.
symbolic_dims_
])
new_dims
=
{}
for
d
in
subgraph_new_symbolic_dims
:
...
...
@@ -729,8 +721,9 @@ class SymbolicShapeInference:
for
d
,
s
in
zip
(
sympy_shape
[
-
rank
:],
strides
)
]
total_pads
=
[
max
(
0
,
(
k
-
s
)
if
r
==
0
else
(
k
-
r
))
for
k
,
s
,
r
in
zip
(
effective_kernel_shape
,
strides
,
residual
)
max
(
0
,
(
k
-
s
)
if
r
==
0
else
(
k
-
r
))
for
k
,
s
,
r
in
zip
(
effective_kernel_shape
,
strides
,
residual
)
]
except
TypeError
:
# sympy may throw TypeError: cannot determine truth value of Relational
total_pads
=
[
...
...
@@ -1276,8 +1269,9 @@ class SymbolicShapeInference:
if
pads
is
not
None
:
assert
len
(
pads
)
==
2
*
rank
new_sympy_shape
=
[
d
+
pad_up
+
pad_down
for
d
,
pad_up
,
pad_down
in
zip
(
sympy_shape
,
pads
[:
rank
],
pads
[
rank
:])
d
+
pad_up
+
pad_down
for
d
,
pad_up
,
pad_down
in
zip
(
sympy_shape
,
pads
[:
rank
],
pads
[
rank
:])
]
self
.
_update_computed_dims
(
new_sympy_shape
)
else
:
...
...
@@ -1590,8 +1584,8 @@ class SymbolicShapeInference:
scales
=
list
(
scales
)
new_sympy_shape
=
[
sympy
.
simplify
(
sympy
.
floor
(
d
*
(
end
-
start
)
*
scale
))
for
d
,
start
,
end
,
scale
in
zip
(
input_sympy_shape
,
roi_start
,
roi_end
,
scales
)
for
d
,
start
,
end
,
scale
in
zip
(
input_sympy_shape
,
roi_start
,
roi_end
,
scales
)
]
self
.
_update_computed_dims
(
new_sympy_shape
)
else
:
...
...
@@ -2204,8 +2198,9 @@ class SymbolicShapeInference:
# topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate
sorted_nodes
=
[]
sorted_known_vi
=
set
([
i
.
name
for
i
in
list
(
self
.
out_mp_
.
graph
.
input
)
+
list
(
self
.
out_mp_
.
graph
.
initializer
)
i
.
name
for
i
in
list
(
self
.
out_mp_
.
graph
.
input
)
+
list
(
self
.
out_mp_
.
graph
.
initializer
)
])
if
any
([
o
.
name
in
sorted_known_vi
for
o
in
self
.
out_mp_
.
graph
.
output
]):
# Loop/Scan will have some graph output in graph inputs, so don't do topological sort
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录