Commit 795eb7bd (unverified)

format paddlespeech with pre-commit (#2331)

Authored on Sep 01, 2022 by 小湉湉; committed via GitHub on Sep 01, 2022
Parent: 5d5888af

Showing 61 changed files with 1052 additions and 940 deletions (+1052 -940)
demos/audio_searching/src/operations/load.py  +3 -2
demos/speech_web/API.md  +1 -1
demos/speech_web/speech_server/main.py  +80 -80
demos/speech_web/speech_server/requirements.txt  +5 -6
demos/speech_web/speech_server/src/AudioManeger.py  +45 -42
demos/speech_web/speech_server/src/SpeechBase/asr.py  +8 -10
demos/speech_web/speech_server/src/SpeechBase/nlp.py  +9 -9
demos/speech_web/speech_server/src/SpeechBase/sql_helper.py  +31 -25
demos/speech_web/speech_server/src/SpeechBase/tts.py  +43 -49
demos/speech_web/speech_server/src/SpeechBase/vpr.py  +28 -26
demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py  +5 -4
demos/speech_web/speech_server/src/WebsocketManeger.py  +2 -1
demos/speech_web/speech_server/src/robot.py  +23 -21
demos/speech_web/speech_server/src/util.py  +6 -11
demos/streaming_asr_server/local/rtf_from_log.py  +1 -1
docs/requirements.txt  +17 -18
docs/source/conf.py  +3 -2
examples/iwslt2012/punc0/local/preprocess.py  +24 -22
examples/other/tts_finetune/tts3/finetune.py  +4 -5
paddlespeech/__init__.py  +0 -2
paddlespeech/audio/__init__.py  +3 -3
paddlespeech/audio/streamdata/__init__.py  +62 -63
paddlespeech/audio/streamdata/autodecode.py  +9 -10
paddlespeech/audio/streamdata/cache.py  +30 -33
paddlespeech/audio/streamdata/compat.py  +39 -29
paddlespeech/audio/streamdata/extradatasets.py  +1 -12
paddlespeech/audio/streamdata/filters.py  +164 -92
paddlespeech/audio/streamdata/gopen.py  +26 -36
paddlespeech/audio/streamdata/handlers.py  +2 -3
paddlespeech/audio/streamdata/mix.py  +2 -7
paddlespeech/audio/streamdata/paddle_utils.py  +2 -12
paddlespeech/audio/streamdata/pipeline.py  +5 -9
paddlespeech/audio/streamdata/shardlists.py  +44 -33
paddlespeech/audio/streamdata/tariterators.py  +41 -40
paddlespeech/audio/streamdata/utils.py  +17 -15
paddlespeech/audio/streamdata/writer.py  +40 -37
paddlespeech/audio/text/text_featurizer.py  +1 -1
paddlespeech/audio/transform/perturb.py  +6 -5
paddlespeech/audio/transform/spec_augment.py  +1 -0
paddlespeech/cli/executor.py  +1 -1
paddlespeech/s2t/__init__.py  +1 -0
paddlespeech/s2t/exps/u2/model.py  +14 -9
paddlespeech/s2t/exps/u2_kaldi/model.py  +16 -10
paddlespeech/s2t/exps/u2_st/model.py  +12 -7
paddlespeech/s2t/io/dataloader.py  +80 -65
paddlespeech/s2t/models/u2_st/u2_st.py  +6 -7
paddlespeech/s2t/modules/align.py  +32 -7
paddlespeech/s2t/modules/initializer.py  +1 -1
paddlespeech/server/engine/asr/online/ctc_endpoint.py  +4 -2
paddlespeech/server/engine/asr/online/onnx/asr_engine.py  +1 -1
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py  +1 -1
paddlespeech/server/engine/asr/python/asr_engine.py  +7 -5
paddlespeech/t2s/datasets/sampler.py  +4 -3
paddlespeech/t2s/exps/ernie_sat/train.py  +0 -1
paddlespeech/t2s/exps/ernie_sat/utils.py  +7 -4
paddlespeech/t2s/exps/syn_utils.py  +4 -4
paddlespeech/t2s/frontend/g2pw/__init__.py  +0 -1
paddlespeech/t2s/frontend/mix_frontend.py  +5 -2
paddlespeech/t2s/training/updaters/standard_updater.py  +2 -1
setup.py  +2 -7
speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py  +19 -24
demos/audio_searching/src/operations/load.py

@@ -26,8 +26,9 @@ def get_audios(path):
    """
    supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
    return [
        item
        for sublist in [[os.path.join(dir, file) for file in files]
                        for dir, _, files in list(os.walk(path))]
        for item in sublist if os.path.splitext(item)[1] in supported_formats
    ]
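The reformatted comprehension above stays fairly dense; as a reading aid only (not part of the commit), here is a roughly equivalent loop-based sketch of what get_audios does, walking the path and keeping files whose extension is in supported_formats:

import os


def get_audios_sketch(path):
    """Roughly equivalent to the comprehension in load.py (illustration only)."""
    supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
    results = []
    for dir, _, files in os.walk(path):
        for file in files:
            item = os.path.join(dir, file)
            # keep only files whose extension is a supported audio format
            if os.path.splitext(item)[1] in supported_formats:
                results.append(item)
    return results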
demos/speech_web/API.md

@@ -401,4 +401,4 @@ curl -X 'GET' \
  "code": 0,
  "result": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
  "message": "ok"
```
\ No newline at end of file
demos/speech_web/speech_server/main.py

@@ -3,48 +3,48 @@
# 2. 接收录音音频,返回识别结果
# 3. 接收ASR识别结果,返回NLP对话结果
# 4. 接收NLP对话结果,返回TTS音频
import argparse
import base64
import datetime
import json
import os
from typing import List

import aiofiles
import librosa
import soundfile as sf
import uvicorn
from fastapi import FastAPI
from fastapi import File
from fastapi import Form
from fastapi import UploadFile
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from src.AudioManeger import AudioMannger
from src.robot import Robot
from src.SpeechBase.vpr import VPR
from src.util import *
from src.WebsocketManeger import ConnectionManager
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse
from starlette.websockets import WebSocketState as WebSocketState

from paddlespeech.server.engine.asr.online.python.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.utils.audio_process import float2pcm

# 解析配置
parser = argparse.ArgumentParser(
    prog='PaddleSpeechDemo', add_help=True)
parser.add_argument(
    "--port",
    action="store",
    type=int,
    help="port of the app",
    default=8010,
    required=False)
args = parser.parse_args()
port = args.port

@@ -60,39 +60,41 @@ ie_model_path = "source/model"
UPLOAD_PATH = "source/vpr"
WAV_PATH = "source/wav"

base_sources = [UPLOAD_PATH, WAV_PATH]
for path in base_sources:
    os.makedirs(path, exist_ok=True)

# 初始化
app = FastAPI()
chatbot = Robot(
    asr_config, tts_config, asr_init_path, ie_model_path=ie_model_path)
manager = ConnectionManager()
aumanager = AudioMannger(chatbot)
aumanager.init()
vpr = VPR(db_path, dim=192, top_k=5)


# 服务配置
class NlpBase(BaseModel):
    chat: str


class TtsBase(BaseModel):
    text: str


class Audios:
    def __init__(self) -> None:
        self.audios = b""


audios = Audios()

######################################################################
########################### ASR 服务 #################################
#####################################################################


# 接收文件,返回ASR结果
# 上传文件
@app.post("/asr/offline")

@@ -101,7 +103,8 @@ async def speech2textOffline(files: List[UploadFile]):
    asr_res = ""
    for file in files[:1]:
        # 生成时间戳
        now_name = "asr_offline_" + datetime.datetime.strftime(
            datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
        out_file_path = os.path.join(WAV_PATH, now_name)
        async with aiofiles.open(out_file_path, 'wb') as out_file:
            content = await file.read()  # async read

@@ -110,10 +113,9 @@ async def speech2textOffline(files: List[UploadFile]):
        # 返回ASR识别结果
        asr_res = chatbot.speech2text(out_file_path)
        return SuccessRequest(result=asr_res)
    # else:
    #     return ErrorRequest(message="文件不是.wav格式")
    return ErrorRequest(message="上传文件为空")


# 接收文件,同时将wav强制转成16k, int16类型
@app.post("/asr/offlinefile")
async def speech2textOfflineFile(files: List[UploadFile]):

@@ -121,7 +123,8 @@ async def speech2textOfflineFile(files: List[UploadFile]):
    asr_res = ""
    for file in files[:1]:
        # 生成时间戳
        now_name = "asr_offline_" + datetime.datetime.strftime(
            datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
        out_file_path = os.path.join(WAV_PATH, now_name)
        async with aiofiles.open(out_file_path, 'wb') as out_file:
            content = await file.read()  # async read

@@ -132,22 +135,18 @@ async def speech2textOfflineFile(files: List[UploadFile]):
        wav = float2pcm(wav)  # float32 to int16
        wav_bytes = wav.tobytes()  # to bytes
        wav_base64 = base64.b64encode(wav_bytes).decode('utf8')

        # 将文件重新写入
        now_name = now_name[:-4] + "_16k" + ".wav"
        out_file_path = os.path.join(WAV_PATH, now_name)
        sf.write(out_file_path, wav, 16000)

        # 返回ASR识别结果
        asr_res = chatbot.speech2text(out_file_path)
        response_res = {"asr_result": asr_res, "wav_base64": wav_base64}
        return SuccessRequest(result=response_res)

    return ErrorRequest(message="上传文件为空")


# 流式接收测试

@@ -161,15 +160,17 @@ async def speech2textOnlineRecive(files: List[UploadFile]):
    print(f"audios长度变化: {len(audios.audios)}")
    return SuccessRequest(message="接收成功")


# 采集环境噪音大小
@app.post("/asr/collectEnv")
async def collectEnv(files: List[UploadFile]):
    for file in files[:1]:
        content = await file.read()  # async read
        # 初始化, wav 前44字节是头部信息
        aumanager.compute_env_volume(content[44:])
        vad_ = aumanager.vad_threshold
        return SuccessRequest(result=vad_, message="采集环境噪音成功")


# 停止录音
@app.get("/asr/stopRecord")

@@ -179,6 +180,7 @@ async def stopRecord():
    print("Online录音暂停")
    return SuccessRequest(message="停止成功")


# 恢复录音
@app.get("/asr/resumeRecord")
async def resumeRecord():

@@ -187,7 +189,7 @@ async def resumeRecord():
    return SuccessRequest(message="Online录音恢复")


# 聊天用的ASR
@app.websocket("/ws/asr/offlineStream")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)

@@ -210,9 +212,9 @@ async def websocket_endpoint(websocket: WebSocket):
            # print(f"用户-{user}-离开")


# 流式识别的ASR
@app.websocket('/ws/asr/onlineStream')
async def websocket_endpoint_online(websocket: WebSocket):
    """PaddleSpeech Online ASR Server api

    Args:

@@ -298,12 +300,14 @@ async def websocket_endpoint(websocket: WebSocket):
    except WebSocketDisconnect:
        pass


######################################################################
########################### NLP 服务 #################################
#####################################################################


@app.post("/nlp/chat")
async def chatOffline(nlp_base: NlpBase):
    chat = nlp_base.chat
    if not chat:
        return ErrorRequest(message="传入文本为空")

@@ -311,8 +315,9 @@ async def chatOffline(nlp_base:NlpBase):
        res = chatbot.chat(chat)
        return SuccessRequest(result=res)


@app.post("/nlp/ie")
async def ieOffline(nlp_base: NlpBase):
    nlp_text = nlp_base.chat
    if not nlp_text:
        return ErrorRequest(message="传入文本为空")

@@ -320,17 +325,20 @@ async def ieOffline(nlp_base:NlpBase):
        res = chatbot.ie(nlp_text)
        return SuccessRequest(result=res)


######################################################################
########################### TTS 服务 #################################
#####################################################################


@app.post("/tts/offline")
async def text2speechOffline(tts_base: TtsBase):
    text = tts_base.text
    if not text:
        return ErrorRequest(message="文本为空")
    else:
        now_name = "tts_" + datetime.datetime.strftime(
            datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
        out_file_path = os.path.join(WAV_PATH, now_name)
        # 保存为文件,再转成base64传输
        chatbot.text2speech(text, outpath=out_file_path)

@@ -339,12 +347,14 @@ async def text2speechOffline(tts_base:TtsBase):
        base_str = base64.b64encode(data_bin)
        return SuccessRequest(result=base_str)


# http流式TTS
@app.post("/tts/online")
async def stream_tts(request_body: TtsBase):
    text = request_body.text
    return StreamingResponse(chatbot.text2speechStreamBytes(text=text))


# ws流式TTS
@app.websocket("/ws/tts/online")
async def stream_ttsWS(websocket: WebSocket):

@@ -356,17 +366,11 @@ async def stream_ttsWS(websocket: WebSocket):
            if text:
                for sub_wav in chatbot.text2speechStream(text=text):
                    # print("发送sub wav: ", len(sub_wav))
                    res = {"wav": sub_wav, "done": False}
                    await websocket.send_json(res)

                # 输送结束
                res = {"wav": sub_wav, "done": True}
                await websocket.send_json(res)
                # manager.disconnect(websocket)

@@ -396,8 +400,9 @@ async def vpr_enroll(table_name: str=None,
        return {'status': False, 'msg': "spk_id can not be None"}
    # Save the upload data to server.
    content = await audio.read()
    now_name = "vpr_enroll_" + datetime.datetime.strftime(
        datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
    audio_path = os.path.join(UPLOAD_PATH, now_name)
    with open(audio_path, "wb+") as f:
        f.write(content)

@@ -413,20 +418,19 @@ async def vpr_recog(request: Request,
                    audio: UploadFile=File(...)):
    # Voice print recognition online
    # try:
    # Save the upload data to server.
    content = await audio.read()
    now_name = "vpr_query_" + datetime.datetime.strftime(
        datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
    query_audio_path = os.path.join(UPLOAD_PATH, now_name)
    with open(query_audio_path, "wb+") as f:
        f.write(content)
    spk_ids, paths, scores = vpr.do_search_vpr(query_audio_path)

    res = dict(zip(spk_ids, zip(paths, scores)))
    # Sort results by distance metric, closest distances first
    res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
    return res
    # except Exception as e:
    #     return {'status': False, 'msg': e}, 400


@app.post('/vpr/del')

@@ -460,17 +464,18 @@ async def vpr_database64(vprId: int):
            return {'status': False, 'msg': "vpr_id can not be None"}
        audio_path = vpr.do_get_wav(vprId)
        # 返回base64

        # 将文件转成16k, 16bit类型的wav文件
        wav, sr = librosa.load(audio_path, sr=16000)
        wav = float2pcm(wav)  # float32 to int16
        wav_bytes = wav.tobytes()  # to bytes
        wav_base64 = base64.b64encode(wav_bytes).decode('utf8')

        return SuccessRequest(result=wav_base64)
    except Exception as e:
        return {'status': False, 'msg': e}, 400


@app.get('/vpr/data')
async def vpr_data(vprId: int):
    # Get the audio file from path by spk_id in MySQL

@@ -482,11 +487,6 @@ async def vpr_data(vprId: int):
    except Exception as e:
        return {'status': False, 'msg': e}, 400


if __name__ == '__main__':
    uvicorn.run(app=app, host='0.0.0.0', port=port)
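For orientation only, a minimal client sketch for the /tts/offline route defined above follows. It is not part of the commit; it assumes the demo server is running locally on the argparse default port 8010, that the SuccessRequest wrapper serializes to the {"code", "result", "message"} JSON shape shown in demos/speech_web/API.md, and that "result" carries the base64-encoded audio written by chatbot.text2speech.

import base64

import requests  # third-party HTTP client; pip install requests

# Hypothetical local call against the FastAPI app above.
resp = requests.post(
    "http://127.0.0.1:8010/tts/offline", json={"text": "欢迎使用语音合成服务"})
resp.raise_for_status()
payload = resp.json()

# Assumed shape: {"code": 0, "result": "<base64 audio>", "message": "ok"}.
audio_bytes = base64.b64decode(payload["result"])
with open("tts_offline_result.wav", "wb") as f:
    f.write(audio_bytes)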
demos/speech_web/speech_server/requirements.txt

aiofiles
faiss-cpu
fastapi
librosa
numpy
paddlenlp
paddlepaddle
paddlespeech
pydantic
python-multipart
scikit_learn
SoundFile
starlette
uvicorn
demos/speech_web/speech_server/src/AudioManeger.py

import datetime
import os
import random
import wave

import numpy as np

from .util import randName


class AudioMannger:
    def __init__(self,
                 robot,
                 frame_length=160,
                 frame=10,
                 data_width=2,
                 vad_default=300):
        # 二进制 pcm 流
        self.audios = b''
        self.asr_result = ""

@@ -20,8 +24,9 @@ class AudioMannger:
        os.makedirs(self.file_dir, exist_ok=True)
        self.vad_deafult = vad_default
        self.vad_threshold = vad_default
        self.vad_threshold_path = os.path.join(self.file_dir,
                                               "vad_threshold.npy")

        # 10ms 一帧
        self.frame_length = frame_length

        # 10帧,检测一次 vad

@@ -30,67 +35,64 @@ class AudioMannger:
        self.data_width = data_width
        # window
        self.window_length = frame_length * frame * data_width

        # 是否开始录音
        self.on_asr = False
        self.silence_cnt = 0
        self.max_silence_cnt = 4
        self.is_pause = False  # 录音暂停与恢复

    def init(self):
        if os.path.exists(self.vad_threshold_path):
            # 平均响度文件存在
            self.vad_threshold = np.load(self.vad_threshold_path)

    def clear_audio(self):
        # 清空 pcm 累积片段与 asr 识别结果
        self.audios = b''

    def clear_asr(self):
        self.asr_result = ""

    def compute_chunk_volume(self, start_index, pcm_bins):
        # 根据帧长计算能量平均值
        pcm_bin = pcm_bins[start_index:start_index + self.window_length]
        # 转成 numpy
        pcm_np = np.frombuffer(pcm_bin, np.int16)
        # 归一化 + 计算响度
        x = pcm_np.astype(np.float32)
        x = np.abs(x)
        return np.mean(x)

    def is_speech(self, start_index, pcm_bins):
        # 检查是否没
        if start_index > len(pcm_bins):
            return False
        # 检查从这个 start 开始是否为静音帧
        energy = self.compute_chunk_volume(
            start_index=start_index, pcm_bins=pcm_bins)
        # print(energy)
        if energy > self.vad_threshold:
            return True
        else:
            return False

    def compute_env_volume(self, pcm_bins):
        max_energy = 0
        start = 0
        while start < len(pcm_bins):
            energy = self.compute_chunk_volume(
                start_index=start, pcm_bins=pcm_bins)
            if energy > max_energy:
                max_energy = energy
            start += self.window_length
        self.vad_threshold = max_energy + 100 if max_energy > self.vad_deafult else self.vad_deafult

        # 保存成文件
        np.save(self.vad_threshold_path, self.vad_threshold)
        print(f"vad 阈值大小: {self.vad_threshold}")
        print(f"环境采样保存: {os.path.realpath(self.vad_threshold_path)}")

    def stream_asr(self, pcm_bin):
        # 先把 pcm_bin 送进去做端点检测
        start = 0

@@ -99,7 +101,7 @@ class AudioMannger:
                self.on_asr = True
                self.silence_cnt = 0
                print("录音中")
                self.audios += pcm_bin[start:start + self.window_length]
            else:
                if self.on_asr:
                    self.silence_cnt += 1

@@ -110,41 +112,42 @@ class AudioMannger:
                        print("录音停止")
                        # audios 保存为 wav, 送入 ASR
                        if len(self.audios) > 2 * 16000:
                            file_path = os.path.join(
                                self.file_dir,
                                "asr_" + datetime.datetime.strftime(
                                    datetime.datetime.now(),
                                    '%Y%m%d%H%M%S') + randName() + ".wav")
                            self.save_audio(file_path=file_path)
                            self.asr_result = self.robot.speech2text(file_path)
                        self.clear_audio()
                        return self.asr_result
                else:
                    # 正常接收
                    print("录音中 静音")
                    self.audios += pcm_bin[start:start + self.window_length]
            start += self.window_length
        return ""

    def save_audio(self, file_path):
        print("保存音频")
        wf = wave.open(file_path, 'wb')  # 创建一个音频文件,名字为“01.wav"
        wf.setnchannels(1)  # 设置声道数为2
        wf.setsampwidth(2)  # 设置采样深度为
        wf.setframerate(16000)  # 设置采样率为16000
        # 将数据写入创建的音频文件
        wf.writeframes(self.audios)
        # 写完后将文件关闭
        wf.close()

    def end(self):
        # audios 保存为 wav, 送入 ASR
        file_path = os.path.join(self.file_dir, "asr.wav")
        self.save_audio(file_path=file_path)
        return self.robot.speech2text(file_path)

    def stop(self):
        self.is_pause = True
        self.audios = b''

    def resume(self):
        self.is_pause = False
\ No newline at end of file
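As a standalone illustration of the energy rule AudioMannger uses for its VAD threshold (not part of the commit), the sketch below reproduces the same idea: slice the int16 PCM byte stream into fixed windows, take the mean absolute amplitude per window, and set the threshold a margin above the loudest window of an environment sample. The window size and default value mirror the constructor defaults above (frame_length=160, frame=10, data_width=2, vad_default=300).

import numpy as np


def window_energy(pcm_bytes: bytes, start: int, window_length: int) -> float:
    """Mean absolute amplitude of one window of int16 PCM bytes."""
    window = pcm_bytes[start:start + window_length]
    samples = np.frombuffer(window, np.int16).astype(np.float32)
    return float(np.mean(np.abs(samples))) if samples.size else 0.0


def estimate_vad_threshold(env_pcm: bytes,
                           window_length: int=160 * 10 * 2,
                           default: int=300) -> float:
    """Sketch of AudioMannger.compute_env_volume: loudest window plus margin."""
    max_energy = 0.0
    start = 0
    while start < len(env_pcm):
        max_energy = max(max_energy,
                         window_energy(env_pcm, start, window_length))
        start += window_length
    # Same rule as the class above: add a margin of 100 over the loudest
    # environment window, but never drop below the default threshold.
    return max_energy + 100 if max_energy > default else default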
demos/speech_web/speech_server/src/SpeechBase/asr.py

import numpy as np

from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine
from paddlespeech.server.engine.asr.online.python.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.utils.config import get_config


def readWave(samples):
    x_len = len(samples)

@@ -31,20 +28,23 @@ def readWave(samples):
class ASR:
    def __init__(self, config_path, ) -> None:
        self.config = get_config(config_path)['asr_online']
        self.engine = ASREngine()
        self.engine.init(self.config)
        self.connection_handler = PaddleASRConnectionHanddler(self.engine)

    def offlineASR(self, samples, sample_rate=16000):
        x_chunk, x_chunk_lens = self.engine.preprocess(
            samples=samples, sample_rate=sample_rate)
        self.engine.run(x_chunk, x_chunk_lens)
        result = self.engine.postprocess()
        self.engine.reset()
        return result

    def onlineASR(self, samples: bytes=None, is_finished=False):
        if not is_finished:
            # 流式开始
            self.connection_handler.extract_feat(samples)

@@ -58,5 +58,3 @@ class ASR:
            asr_results = self.connection_handler.get_result()
            self.connection_handler.reset()
            return asr_results
\ No newline at end of file
demos/speech_web/speech_server/src/SpeechBase/nlp.py

from paddlenlp import Taskflow


class NLP:
    def __init__(self, ie_model_path=None):
        schema = ["时间", "出发地", "目的地", "费用"]
        if ie_model_path:
            self.ie_model = Taskflow(
                "information_extraction",
                schema=schema,
                task_path=ie_model_path)
        else:
            self.ie_model = Taskflow("information_extraction", schema=schema)

        self.dialogue_model = Taskflow("dialogue")

    def chat(self, text):
        result = self.dialogue_model([text])
        return result[0]

    def ie(self, text):
        result = self.ie_model(text)
        return result
\ No newline at end of file
demos/speech_web/speech_server/src/SpeechBase/sql_helper.py

import base64
import os
import sqlite3

import numpy as np


def dict_factory(cursor, row):
    d = {}
    for idx, col in enumerate(cursor.description):
        d[col[0]] = row[idx]
    return d


class DataBase(object):
    def __init__(self, db_path: str):
        db_path = os.path.realpath(db_path)

        if os.path.exists(db_path):

@@ -21,12 +22,12 @@ class DataBase(object):
            db_path_dir = os.path.dirname(db_path)
            os.makedirs(db_path_dir, exist_ok=True)

        self.db_path = db_path
        self.conn = sqlite3.connect(self.db_path)
        self.conn.row_factory = dict_factory
        self.cursor = self.conn.cursor()
        self.init_database()

    def init_database(self):
        """
        初始化数据库, 若表不存在则创建

@@ -41,20 +42,21 @@ class DataBase(object):
        """
        self.cursor.execute(sql)
        self.conn.commit()

    def execute_base(self, sql, data_dict):
        self.cursor.execute(sql, data_dict)
        self.conn.commit()

    def insert_one(self, username, vector_base64: str, wav_path):
        if not os.path.exists(wav_path):
            return None, "wav not exists"
        else:
            sql = """
            insert into
            vprtable (username, vector, wavpath)
            values (?, ?, ?)
            """
            try:
                self.cursor.execute(sql, (username, vector_base64, wav_path))
                self.conn.commit()

@@ -63,25 +65,27 @@ class DataBase(object):
            except Exception as e:
                print(e)
                return None, e

    def select_all(self):
        sql = """
        SELECT * from vprtable
        """
        result = self.cursor.execute(sql).fetchall()
        return result

    def select_by_id(self, vpr_id):
        sql = f"""
        SELECT * from vprtable WHERE `id` = {vpr_id}
        """
        result = self.cursor.execute(sql).fetchall()
        return result

    def select_by_username(self, username):
        sql = f"""
        SELECT * from vprtable WHERE `username` = '{username}'
        """
        result = self.cursor.execute(sql).fetchall()
        return result

@@ -89,28 +93,30 @@ class DataBase(object):
        sql = f"""
        DELETE from vprtable WHERE `username`='{username}'
        """
        self.cursor.execute(sql)
        self.conn.commit()

    def drop_all(self):
        sql = """
        DELETE from vprtable
        """
        self.cursor.execute(sql)
        self.conn.commit()

    def drop_table(self):
        sql = """
        DROP TABLE vprtable
        """
        self.cursor.execute(sql)
        self.conn.commit()

    def encode_vector(self, vector: np.ndarray):
        return base64.b64encode(vector).decode('utf8')

    def decode_vector(self, vector_base64, dtype=np.float32):
        b = base64.b64decode(vector_base64)
        vc = np.frombuffer(b, dtype=dtype)
        return vc
\ No newline at end of file
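In the hunks above, insert_one now builds a plain SQL string and passes the values as a bound-parameter tuple, while the SELECT and DELETE helpers still interpolate values through f-strings. Purely as an illustration of the bound-parameter pattern (hypothetical code, not something this commit changes), a lookup by username in the same style could look like this:

import sqlite3


def select_by_username_bound(conn: sqlite3.Connection, username: str):
    # sqlite3 substitutes the ? placeholder itself, so the value is never
    # spliced into the SQL text; same style as insert_one above.
    sql = "SELECT * FROM vprtable WHERE username = ?"
    return conn.execute(sql, (username, )).fetchall()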
demos/speech_web/speech_server/src/SpeechBase/tts.py

@@ -5,18 +5,19 @@
# 2. 加载模型
# 3. 端到端推理
# 4. 流式推理
import base64
import logging
import math

import numpy as np

from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend.zh_frontend import Frontend


class TTS:
    def __init__(self, config_path):

@@ -26,12 +27,12 @@ class TTS:
        self.engine.init(self.config)
        self.executor = self.engine.executor
        #self.engine.warm_up()

        # 前端初始化
        self.frontend = Frontend(
            phone_vocab_path=self.engine.executor.phones_dict,
            tone_vocab_path=None)

    def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
        """
        Streaming inference removes the result of pad inference

@@ -48,39 +49,37 @@ class TTS:
            data = data[front_pad * upsample:(front_pad + block) * upsample]
        return data

    def offlineTTS(self, text):
        get_tone_ids = False
        merge_sentences = False

        input_ids = self.frontend.get_input_ids(
            text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
        phone_ids = input_ids["phone_ids"]

        wav_list = []
        for i in range(len(phone_ids)):
            orig_hs = self.engine.executor.am_encoder_infer_sess.run(
                None, input_feed={'text': phone_ids[i].numpy()})
            hs = orig_hs[0]
            am_decoder_output = self.engine.executor.am_decoder_sess.run(
                None, input_feed={'xs': hs})
            am_postnet_output = self.engine.executor.am_postnet_sess.run(
                None,
                input_feed={
                    'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
                })
            am_output_data = am_decoder_output + np.transpose(
                am_postnet_output[0], (0, 2, 1))
            normalized_mel = am_output_data[0][0]
            mel = denorm(normalized_mel, self.engine.executor.am_mu,
                         self.engine.executor.am_std)
            wav = self.engine.executor.voc_sess.run(
                output_names=None, input_feed={'logmel': mel})[0]
            wav_list.append(wav)

        wavs = np.concatenate(wav_list)
        return wavs

    def streamTTS(self, text):
        get_tone_ids = False

@@ -88,9 +87,7 @@ class TTS:
        # front
        input_ids = self.frontend.get_input_ids(
            text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
        phone_ids = input_ids["phone_ids"]

        for i in range(len(phone_ids)):

@@ -105,14 +102,15 @@ class TTS:
                mel = mel[0]

                # voc streaming
                mel_chunks = get_chunks(mel, self.config.voc_block,
                                        self.config.voc_pad, "voc")
                voc_chunk_num = len(mel_chunks)
                for i, mel_chunk in enumerate(mel_chunks):
                    sub_wav = self.executor.voc_sess.run(
                        output_names=None, input_feed={'logmel': mel_chunk})
                    sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i,
                                             self.config.voc_block,
                                             self.config.voc_pad,
                                             self.config.voc_upsample)

                    yield self.after_process(sub_wav)

@@ -130,7 +128,8 @@ class TTS:
                end = min(self.config.voc_block + self.config.voc_pad, mel_len)

                # streaming am
                hss = get_chunks(orig_hs, self.config.am_block,
                                 self.config.am_pad, "am")
                am_chunk_num = len(hss)
                for i, hs in enumerate(hss):
                    am_decoder_output = self.executor.am_decoder_sess.run(

@@ -147,7 +146,8 @@ class TTS:
                    sub_mel = denorm(normalized_mel, self.executor.am_mu,
                                     self.executor.am_std)
                    sub_mel = self.depadding(sub_mel, am_chunk_num, i,
                                             self.config.am_block,
                                             self.config.am_pad, 1)

                    if i == 0:
                        mel_streaming = sub_mel

@@ -165,23 +165,22 @@ class TTS:
                            output_names=None,
                            input_feed={'logmel': voc_chunk})
                        sub_wav = self.depadding(
                            sub_wav[0], voc_chunk_num, voc_chunk_id,
                            self.config.voc_block, self.config.voc_pad,
                            self.config.voc_upsample)

                        yield self.after_process(sub_wav)

                        voc_chunk_id += 1
                        start = max(
                            0, voc_chunk_id * self.config.voc_block -
                            self.config.voc_pad)
                        end = min((voc_chunk_id + 1) * self.config.voc_block +
                                  self.config.voc_pad, mel_len)

        else:
            logging.error(
                "Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts."
            )

    def streamTTSBytes(self, text):
        for wav in self.engine.executor.infer(
                text=text,

@@ -191,19 +190,14 @@ class TTS:
            wav = float2pcm(wav)  # float32 to int16
            wav_bytes = wav.tobytes()  # to bytes
            yield wav_bytes

    def after_process(self, wav):
        # for tvm
        wav = float2pcm(wav)  # float32 to int16
        wav_bytes = wav.tobytes()  # to bytes
        wav_base64 = base64.b64encode(wav_bytes).decode('utf8')  # to base64
        return wav_base64

    def streamTTS_TVM(self, text):
        # 用 TVM 优化
        pass
\ No newline at end of file
demos/speech_web/speech_server/src/SpeechBase/vpr.py
# The vpr demo does not use MySQL or Milvus; it is only for the docker demonstration.
import logging

import faiss
import numpy as np
from matplotlib import use

from .sql_helper import DataBase
from .vpr_encode import get_audio_embedding


class VPR:
    def __init__(self, db_path, dim, top_k) -> None:
        # initialization
        ...
@@ -14,15 +16,15 @@ class VPR:
        self.top_k = top_k
        self.dtype = np.float32
        self.vpr_idx = 0

        # database initialization
        self.db = DataBase(db_path)

        # faiss initialization
        index_ip = faiss.IndexFlatIP(dim)
        self.index_ip = faiss.IndexIDMap(index_ip)
        self.init()

    def init(self):
        # demo initialization: register the vectors stored in MySQL with faiss
        sql_dbs = self.db.select_all()
        ...
@@ -34,12 +36,13 @@ class VPR:
                if len(vc.shape) == 1:
                    vc = np.expand_dims(vc, axis=0)
                # build the index
                self.index_ip.add_with_ids(vc, np.array((idx, )).astype('int64'))
            logging.info("faiss 构建完毕")

    def faiss_enroll(self, idx, vc):
        self.index_ip.add_with_ids(vc, np.array((idx, )).astype('int64'))

    def vpr_enroll(self, username, wav_path):
        # enroll a voiceprint
        emb = get_audio_embedding(wav_path)
        ...
@@ -53,21 +56,22 @@ class VPR:
        else:
            last_idx, mess = None
        return last_idx

    def vpr_recog(self, wav_path):
        # recognize a voiceprint
        emb_search = get_audio_embedding(wav_path)
        if emb_search is not None:
            emb_search = np.expand_dims(emb_search, axis=0)
            D, I = self.index_ip.search(emb_search, self.top_k)
            D = D.tolist()[0]
            I = I.tolist()[0]
            return [(round(D[i] * 100, 2), I[i]) for i in range(len(D))
                    if I[i] != -1]
        else:
            logging.error("识别失败")
            return None

    def do_search_vpr(self, wav_path):
        spk_ids, paths, scores = [], [], []
        recog_result = self.vpr_recog(wav_path)
        ...
@@ -78,41 +82,39 @@ class VPR:
            scores.append(score)
            paths.append("")
        return spk_ids, paths, scores

    def vpr_del(self, username):
        # delete the voiceprints that belong to `username`:
        # look up the user's ids and remove the matching vectors
        res = self.db.select_by_username(username)
        for r in res:
            idx = r['id']
            self.index_ip.remove_ids(np.array((idx, )).astype('int64'))
        self.db.drop_by_username(username)

    def vpr_list(self):
        # list the enrolled data
        return self.db.select_all()

    def do_list(self):
        spk_ids, vpr_ids = [], []
        for res in self.db.select_all():
            spk_ids.append(res['username'])
            vpr_ids.append(res['id'])
        return spk_ids, vpr_ids

    def do_get_wav(self, vpr_idx):
        res = self.db.select_by_id(vpr_idx)
        return res[0]['wavpath']

    def vpr_data(self, idx):
        # fetch the record for the given id
        res = self.db.select_by_id(idx)
        return res

    def vpr_droptable(self):
        # drop the table
        self.db.drop_table()
        # clear faiss
        self.index_ip.reset()
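The enroll-and-search flow above boils down to wrapping an inner-product index in faiss.IndexIDMap so vectors can be added, searched, and removed by explicit integer ids. A minimal self-contained sketch of that pattern; the dimension, ids, and random embeddings are illustrative, not the demo's values:

import faiss
import numpy as np

dim, top_k = 4, 2                                   # illustrative sizes
index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))

# enroll two L2-normalized embeddings under explicit int64 ids
embs = np.random.rand(2, dim).astype('float32')
embs /= np.linalg.norm(embs, axis=1, keepdims=True)
index.add_with_ids(embs, np.array([10, 11], dtype='int64'))

# inner product of normalized vectors is cosine similarity
D, I = index.search(embs[:1], top_k)
print(I[0], D[0])                                   # nearest ids and their scores

# remove one enrollment by id, as vpr_del does per user record
index.remove_ids(np.array([10], dtype='int64'))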
demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py
import logging

import numpy as np

from paddlespeech.cli.vector import VectorExecutor

vector_executor = VectorExecutor()


def get_audio_embedding(path):
    """
    Use vpr_inference to generate embedding of audio
    """
    ...
@@ -16,5 +19,3 @@ def get_audio_embedding(path):
    except Exception as e:
        logging.error(f"Error with embedding:{e}")
        return None
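A short usage sketch of the helper above; the wav path is a placeholder and the import path assumes the demo server's working directory:

from src.SpeechBase.vpr_encode import get_audio_embedding

emb = get_audio_embedding("source/demo/enroll.wav")  # hypothetical enrollment clip
if emb is not None:
    print(emb.shape)   # a speaker embedding on success; None is returned on failure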
demos/speech_web/speech_server/src/WebsocketManeger.py
@@ -2,6 +2,7 @@ from typing import List
from fastapi import WebSocket


class ConnectionManager:
    def __init__(self):
        # holds the active websocket connection objects
        ...
@@ -28,4 +29,4 @@ class ConnectionManager:
            await connection.send_text(message)


manager = ConnectionManager()
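A sketch of how such a module-level manager is commonly wired into a FastAPI websocket route; the endpoint path is illustrative, and it assumes the manager exposes the usual async connect() and sync disconnect() helpers:

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

from src.WebsocketManeger import manager

app = FastAPI()


@app.websocket("/ws/echo")
async def ws_echo(websocket: WebSocket):
    await manager.connect(websocket)       # assumed: accepts and tracks the socket
    try:
        while True:
            text = await websocket.receive_text()
            await websocket.send_text(text)
    except WebSocketDisconnect:
        manager.disconnect(websocket)      # assumed: drops it from the active list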
demos/speech_web/speech_server/src/robot.py
import os

import librosa
import soundfile as sf
from src.SpeechBase.asr import ASR
from src.SpeechBase.nlp import NLP
from src.SpeechBase.tts import TTS

from paddlespeech.cli.asr.infer import ASRExecutor


class Robot:
    def __init__(self, asr_config, tts_config, asr_init_path,
                 ie_model_path=None) -> None:
        self.nlp = NLP(ie_model_path=ie_model_path)
        self.asr = ASR(config_path=asr_config)
        self.tts = TTS(config_path=tts_config)
        self.tts_sample_rate = 24000
        self.asr_sample_rate = 16000

        # The streaming recognizer is not as accurate as the end-to-end model,
        # so the streaming model and the end-to-end model are kept separate here.
        self.asr_model = ASRExecutor()
        self.asr_name = "conformer_wenetspeech"
        self.warm_up_asrmodel(asr_init_path)

    def warm_up_asrmodel(self, asr_init_path):
        if not os.path.exists(asr_init_path):
            path_dir = os.path.dirname(asr_init_path)
            if not os.path.exists(path_dir):
                os.makedirs(path_dir, exist_ok=True)

            # generate the warm-up audio with TTS (24000 Hz sample rate)
            text = "生成初始音频"
            self.text2speech(text, asr_init_path)

        # initialize the asr model
        self.asr_model(
            asr_init_path,
            model=self.asr_name,
            lang='zh',
            sample_rate=16000,
            force_yes=True)

    def speech2text(self, audio_file):
        self.asr_model.preprocess(self.asr_name, audio_file)
        self.asr_model.infer(self.asr_name)
        res = self.asr_model.postprocess()
        return res

    def text2speech(self, text, outpath):
        wav = self.tts.offlineTTS(text)
        sf.write(outpath, wav, samplerate=self.tts_sample_rate)
        res = wav
        return res

    def text2speechStream(self, text):
        for sub_wav_base64 in self.tts.streamTTS(text=text):
            yield sub_wav_base64

    def text2speechStreamBytes(self, text):
        for wav_bytes in self.tts.streamTTSBytes(text=text):
            yield wav_bytes
...
@@ -66,5 +70,3 @@ class Robot:
    def ie(self, text):
        result = self.nlp.ie(text)
        return result
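An illustrative wiring of the Robot class above; the config and wav paths are placeholders for whatever the demo server ships with:

from src.robot import Robot

robot = Robot(
    asr_config="conf/asr_application.yaml",       # placeholder config paths
    tts_config="conf/tts_application.yaml",
    asr_init_path="source/demo/init.wav")

text = robot.speech2text("source/demo/query_16k.wav")   # offline end-to-end ASR
robot.text2speech(text, "source/demo/reply.wav")         # offline TTS, 24 kHz output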
demos/speech_web/speech_server/src/util.py
import random


def randName(n=5):
    return "".join(random.sample('zyxwvutsrqponmlkjihgfedcba', n))


def SuccessRequest(result=None, message="ok"):
    return {"code": 0, "result": result, "message": message}


def ErrorRequest(result=None, message="error"):
    return {"code": -1, "result": result, "message": message}
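A sketch of how these response helpers read inside a FastAPI route; the route itself is illustrative:

from fastapi import FastAPI

from src.util import ErrorRequest, SuccessRequest, randName

app = FastAPI()


@app.get("/api/name")
def new_name(n: int = 5):
    if n < 1 or n > 26:
        # random.sample draws without replacement from 26 letters, so cap n
        return ErrorRequest(message="n must be between 1 and 26")
    return SuccessRequest(result=randName(n))  # {"code": 0, "result": ..., "message": "ok"}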
demos/streaming_asr_server/local/rtf_from_log.py
@@ -34,7 +34,7 @@ if __name__ == '__main__':
    n = 0
    for m in rtfs:
        # not accurate, may have duplicate log
        n += 1
        T += m['T']
        P += m['P']
...
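The accumulation above feeds a real-time-factor estimate: total processing time divided by total audio duration. A standalone sketch with made-up measurements:

measurements = [{'T': 10.0, 'P': 2.5}, {'T': 8.0, 'P': 2.3}]  # seconds, illustrative

T = sum(m['T'] for m in measurements)   # audio duration
P = sum(m['P'] for m in measurements)   # processing time
print(f"RTF = {P / T:.3f}")             # below 1.0 means faster than real time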
docs/requirements.txt
braceexpand
colorlog
editdistance
fastapi
g2p_en
g2pM
h5py
inflect
jieba
jsonlines
kaldiio
keyboard
librosa==0.8.1
loguru
matplotlib
myst-parser
nara_wpe
numpydoc
onnxruntime==1.10.0
opencc
paddlenlp
paddlepaddle>=2.2.2
paddlespeech_feat
pandas
pathos == 0.2.8
pattern_singleton
Pillow>=9.0.0
praatio==5.0.0
prettytable
pypinyin
pypinyin-dict
python-dateutil
pyworld==0.2.12
recommonmark>=0.5.0
resampy==0.2.2
sacrebleu
scipy
sentencepiece~=0.1.96
soundfile~=0.10
sphinx
sphinx-autobuild
sphinx-markdown-tables
sphinx_rtd_theme
textgrid
timer
tqdm
typeguard
uvicorn
visualdl
webrtcvad
websockets
yacs~=0.1.8
zhon
docs/source/conf.py
@@ -20,10 +20,11 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
import os
import sys

import recommonmark.parser
import sphinx_rtd_theme

sys.path.insert(0, os.path.abspath('../..'))
autodoc_mock_imports = ["soundfile", "librosa"]
...
examples/iwslt2012/punc0/local/preprocess.py
import argparse
import os


def process_sentence(line):
    if line == '':
        return ''
    res = line[0]
    for i in range(1, len(line)):
        res += (' ' + line[i])
    return res


if __name__ == "__main__":
    paser = argparse.ArgumentParser(description="Input filename")
    paser.add_argument('-input_file')
    paser.add_argument('-output_file')
    sentence_cnt = 0
    args = paser.parse_args()
    with open(args.input_file, 'r') as f:
        with open(args.output_file, 'w') as write_f:
            while True:
                line = f.readline()
                if line:
                    sentence_cnt += 1
                    write_f.write(process_sentence(line))
                else:
                    break
    print('preprocess over')
    print('total sentences number:', sentence_cnt)
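process_sentence simply rejoins the characters of each input line with single spaces, which is what the punctuation-restoration training data expects. A quick sanity check mirroring the function above on a made-up input:

def process_sentence(line):
    if line == '':
        return ''
    return ' '.join(line)   # equivalent to the character-by-character loop above

print(process_sentence("今天天气"))   # -> "今 天 天 气"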
examples/other/tts_finetune/tts3/finetune.py
@@ -17,15 +17,14 @@ from pathlib import Path
from typing import Union

import yaml
from local.check_oov import get_check_result
from local.extract import extract_feature
from local.label_process import get_single_label
from local.prepare_env import generate_finetune_env
from paddle import distributed as dist
from yacs.config import CfgNode

from paddlespeech.t2s.exps.fastspeech2.train import train_sp
from utils.gen_duration_from_textgrid import gen_duration_from_textgrid

DICT_EN = 'tools/aligner/cmudict-0.7b'
...
paddlespeech/__init__.py
@@ -14,5 +14,3 @@
import _locale

_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
paddlespeech/audio/__init__.py
@@ -14,12 +14,12 @@
from . import compliance
from . import datasets
from . import features
from . import functional
from . import io
from . import metric
from . import sox_effects
from . import streamdata
from . import text
from . import transform
from .backends import load
from .backends import save
paddlespeech/audio/streamdata/__init__.py
@@ -4,67 +4,66 @@
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa
from .cache import cached_tarfile_samples
from .cache import cached_tarfile_to_samples
from .cache import lru_cleanup
from .cache import pipe_cleaner
from .compat import FluidWrapper
from .compat import WebDataset
from .compat import WebLoader
from .extradatasets import MockDataset
from .extradatasets import with_epoch
from .extradatasets import with_length
from .filters import associate
from .filters import audio_cmvn
from .filters import audio_compute_fbank
from .filters import audio_data_filter
from .filters import audio_padding
from .filters import audio_resample
from .filters import audio_spec_aug
from .filters import audio_tokenize
from .filters import batched
from .filters import decode
from .filters import detshuffle
from .filters import extract_keys
from .filters import getfirst
from .filters import info
from .filters import map
from .filters import map_dict
from .filters import map_tuple
from .filters import pipelinefilter
from .filters import placeholder
from .filters import rename
from .filters import rename_keys
from .filters import select
from .filters import shuffle
from .filters import slice
from .filters import sort
from .filters import to_tuple
from .filters import transform_with
from .filters import unbatched
from .filters import xdecode
from .handlers import ignore_and_continue
from .handlers import ignore_and_stop
from .handlers import reraise_exception
from .handlers import warn_and_continue
from .handlers import warn_and_stop
from .mix import RandomMix
from .mix import RoundRobin
from .pipeline import DataPipeline
from .shardlists import MultiShardSample
from .shardlists import non_empty
from .shardlists import resampled
from .shardlists import ResampledShards
from .shardlists import shardspec
from .shardlists import SimpleShardList
from .shardlists import single_node_only
from .shardlists import split_by_node
from .shardlists import split_by_worker
from .tariterators import tarfile_samples
from .tariterators import tarfile_to_samples
from .utils import PipelineStage
from .utils import repeatedly
from .writer import numpy_dumps
from .writer import ShardWriter
from .writer import TarWriter
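The re-exports above make the streaming-data API reachable from a single namespace. A minimal, lazily-evaluated sketch (the shard brace pattern is a placeholder; nothing is read until the dataset is iterated):

import paddlespeech.audio.streamdata as streamdata

# placeholder shard pattern; point it at real tar shards before iterating
dataset = streamdata.WebDataset("shards/train-{000000..000009}.tar",
                                shardshuffle=True)
dataset = dataset.decode()   # attach the default sample decoders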
paddlespeech/audio/streamdata/autodecode.py
@@ -5,18 +5,19 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Automatically decode webdataset samples."""
import io
import json
import os
import pickle
import re
import tempfile
from functools import partial

import numpy as np
"""Extensions passed on to the image decoder."""
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()

################################################################
# handle basic datatypes
################################################################
...
@@ -128,7 +129,7 @@ def call_extension_handler(key, data, f, extensions):
        target = target.split(".")
        if len(target) > len(extension):
            continue
        if extension[-len(target):] == target:
            return f(data)
    return None
...
@@ -268,7 +269,6 @@ def imagehandler(imagespec, extensions=image_extensions):
################################################################
# torch video
################################################################
'''
def torch_video(key, data):
    """Decode video using the torchvideo library.
...
@@ -289,7 +289,6 @@ def torch_video(key, data):
    return torchvision.io.read_video(fname, pts_unit="sec")
'''

################################################################
# paddlespeech.audio
################################################################
...
@@ -359,7 +358,6 @@ def gzfilter(key, data):
# decode entire training amples
################################################################

default_pre_handlers = [gzfilter]
default_post_handlers = [basichandlers]
...
@@ -387,7 +385,8 @@ class Decoder:
            pre = default_pre_handlers
        if post is None:
            post = default_post_handlers
        assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
        assert all(callable(h) for h in pre), f"one of {pre} not callable"
        assert all(callable(h) for h in post), f"one of {post} not callable"
        self.handlers = pre + handlers + post
...
paddlespeech/audio/streamdata/cache.py
@@ -2,7 +2,10 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import os
import random
import re
import sys
from urllib.parse import urlparse

from . import filters
...
@@ -40,7 +43,7 @@ def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
        os.remove(fname)


def download(url, dest, chunk_size=1024**2, verbose=False):
    """Download a file from `url` to `dest`."""
    temp = dest + f".temp{os.getpid()}"
    with gopen.gopen(url) as stream:
...
@@ -65,12 +68,11 @@ def pipe_cleaner(spec):
def get_file_cached(
        spec,
        cache_size=-1,
        cache_dir=None,
        url_to_name=pipe_cleaner,
        verbose=False, ):
    if cache_size == -1:
        cache_size = default_cache_size
    if cache_dir is None:
...
@@ -107,15 +109,14 @@ verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))
def cached_url_opener(
        data,
        handler=reraise_exception,
        cache_size=-1,
        cache_dir=None,
        url_to_name=pipe_cleaner,
        validator=check_tar_format,
        verbose=False,
        always=False, ):
    """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
    verbose = verbose or verbose_cache
    for sample in data:
...
@@ -132,8 +133,7 @@ def cached_url_opener(
                cache_size=cache_size,
                cache_dir=cache_dir,
                url_to_name=url_to_name,
                verbose=verbose, )
            if verbose:
                print("# opening %s" % dest, file=sys.stderr)
            assert os.path.exists(dest)
...
@@ -143,9 +143,8 @@ def cached_url_opener(
                    data = f.read(200)
                os.remove(dest)
                raise ValueError(
                    "%s (%s) is not a tar archive, but a %s, contains %s" %
                    (dest, url, ftype, repr(data)))
            try:
                stream = open(dest, "rb")
                sample.update(stream=stream)
...
@@ -158,7 +157,7 @@ def cached_url_opener(
                    continue
                raise exn
        except Exception as exn:
            exn.args = exn.args + (url, )
            if handler(exn):
                continue
            else:
...
@@ -166,14 +165,13 @@ def cached_url_opener(
def cached_tarfile_samples(
        src,
        handler=reraise_exception,
        cache_size=-1,
        cache_dir=None,
        verbose=False,
        url_to_name=pipe_cleaner,
        always=False, ):
    streams = cached_url_opener(
        src,
        handler=handler,
...
@@ -181,8 +179,7 @@ def cached_tarfile_samples(
        cache_dir=cache_dir,
        verbose=verbose,
        url_to_name=url_to_name,
        always=always, )
    samples = tar_file_and_group_expander(streams, handler=handler)
    return samples
...
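A hedged sketch of trimming a local shard cache with the lru_cleanup helper whose signature appears above; the cache directory and size limit are illustrative:

import os

from paddlespeech.audio.streamdata import lru_cleanup

cache_dir = "/tmp/wds_cache"            # illustrative location
os.makedirs(cache_dir, exist_ok=True)   # make sure the directory exists first
lru_cleanup(cache_dir, 10 * 1024**3)    # evict oldest files down to roughly 10 GB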
paddlespeech/audio/streamdata/compat.py
@@ -2,17 +2,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import yaml

from . import autodecode
from . import cache
from . import filters
from . import shardlists
from . import tariterators
from .filters import reraise_exception
from .paddle_utils import DataLoader
from .paddle_utils import IterableDataset
from .pipeline import DataPipeline


class FluidInterface:
...
@@ -26,7 +26,8 @@ class FluidInterface:
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        return self.compose(
            filters.batched(), batchsize=batchsize, collation_fn=None)

    def unlisted(self):
        return self.compose(filters.unlisted())
...
@@ -43,9 +44,19 @@ class FluidInterface:
    def map(self, f, handler=reraise_exception):
        return self.compose(filters.map(f, handler=handler))

    def decode(self,
               *args,
               pre=None,
               post=None,
               only=None,
               partial=False,
               handler=reraise_exception):
        handlers = [
            autodecode.ImageHandler(x) if isinstance(x, str) else x
            for x in args
        ]
        decoder = autodecode.Decoder(
            handlers, pre=pre, post=post, only=only, partial=partial)
        return self.map(decoder, handler=handler)

    def map_dict(self, handler=reraise_exception, **kw):
...
@@ -80,12 +91,12 @@ class FluidInterface:
    def audio_data_filter(self, *args, **kw):
        return self.compose(filters.audio_data_filter(*args, **kw))

    def audio_tokenize(self, *args, **kw):
        return self.compose(filters.audio_tokenize(*args, **kw))

    def resample(self, *args, **kw):
        return self.compose(filters.resample(*args, **kw))

    def audio_compute_fbank(self, *args, **kw):
        return self.compose(filters.audio_compute_fbank(*args, **kw))
...
@@ -102,27 +113,28 @@ class FluidInterface:
    def audio_cmvn(self, cmvn_file):
        return self.compose(filters.audio_cmvn(cmvn_file))


class WebDataset(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(
            self,
            urls,
            handler=reraise_exception,
            resampled=False,
            repeat=False,
            shardshuffle=None,
            cache_size=0,
            cache_dir=None,
            detshuffle=False,
            nodesplitter=shardlists.single_node_only,
            verbose=False, ):
        super().__init__()
        if isinstance(urls, IterableDataset):
            assert not resampled
            self.append(urls)
        elif isinstance(urls, str) and (urls.endswith(".yaml") or
                                        urls.endswith(".yml")):
            with (open(urls)) as stream:
                spec = yaml.safe_load(stream)
            assert "datasets" in spec
...
@@ -152,9 +164,7 @@ class WebDataset(DataPipeline, FluidInterface):
                    handler=handler,
                    verbose=verbose,
                    cache_size=cache_size,
                    cache_dir=cache_dir, ))


class FluidWrapper(DataPipeline, FluidInterface):
...
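The fluid interface above only composes the curried filters from filters.py, so audio stages can be chained in one expression. A sketch of such a chain; the shard path is a placeholder, the keyword values echo the filter defaults, and the pipeline is only composed here, not iterated (iteration would additionally require tar samples with the expected wav/label entries):

import paddlespeech.audio.streamdata as streamdata

dataset = (streamdata.WebDataset("shards/train-{000000..000009}.tar")
           .audio_data_filter(max_length=10240, min_length=10)
           .audio_compute_fbank(num_mel_bins=80, frame_shift=10))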
paddlespeech/audio/streamdata/extradatasets.py
@@ -5,20 +5,10 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Train PyTorch models directly from POSIX tar archive.

Code works locally or over HTTP connections.
"""
from . import utils
from .paddle_utils import IterableDataset
from .utils import PipelineStage
...
@@ -63,8 +53,7 @@ class repeatedly(IterableDataset, PipelineStage):
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches, )


class with_epoch(IterableDataset):
...
paddlespeech/audio/streamdata/filters.py
@@ -3,7 +3,6 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""A collection of iterators for data transformations.
...
@@ -12,28 +11,29 @@ These functions are plain iterator functions. You can find curried versions
in webdataset.filters, and you can find IterableDataset wrappers in
webdataset.processing.
"""
import io
import itertools
import os
import random
import re
import sys
import time
from fnmatch import fnmatch
from functools import reduce

import numpy as np
import paddle

from . import autodecode
from . import utils
from .. import backends
from ..compliance import kaldi
from ..transform.cmvn import GlobalCMVN
from ..transform.spec_augment import freq_mask
from ..transform.spec_augment import time_mask
from ..transform.spec_augment import time_warp
from ..utils.tensor_utils import pad_sequence
from .utils import PipelineStage


class FilterFunction(object):
    """Helper class for currying pipeline stages.
...
@@ -159,10 +159,12 @@ def transform_with(sample, transformers):
            result[i] = f(sample[i])
    return result


###
# Iterators
###


def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
    """Print information about the samples that are passing through.
...
@@ -278,10 +280,16 @@ def _log_keys(data, logfile=None):
log_keys = pipelinefilter(_log_keys)


def _minedecode(x):
    if isinstance(x, str):
        return autodecode.imagehandler(x)
    else:
        return x


def _decode(data, *args, handler=reraise_exception, **kw):
    """Decode data based on the decoding functions given as arguments."""
    decoder = _minedecode
    handlers = [decoder(x) for x in args]
    f = autodecode.Decoder(handlers, **kw)
...
@@ -325,15 +333,24 @@ def _rename(data, handler=reraise_exception, keep=True, **kw):
    for sample in data:
        try:
            if not keep:
                yield {
                    k: getfirst(sample, v, missing_is_error=True)
                    for k, v in kw.items()
                }
            else:

                def listify(v):
                    return v.split(";") if isinstance(v, str) else v

                to_be_replaced = {x for v in kw.values() for x in listify(v)}
                result = {
                    k: v
                    for k, v in sample.items() if k not in to_be_replaced
                }
                result.update({
                    k: getfirst(sample, v, missing_is_error=True)
                    for k, v in kw.items()
                })
                yield result
        except Exception as exn:
            if handler(exn):
...
@@ -381,7 +398,11 @@ def _map_dict(data, handler=reraise_exception, **kw):
map_dict = pipelinefilter(_map_dict)


def _to_tuple(data,
              *args,
              handler=reraise_exception,
              missing_is_error=True,
              none_is_error=None):
    """Convert dict samples to tuples."""
    if none_is_error is None:
        none_is_error = missing_is_error
...
@@ -390,7 +411,10 @@ def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, non
    for sample in data:
        try:
            result = tuple([
                getfirst(sample, f, missing_is_error=missing_is_error)
                for f in args
            ])
            if none_is_error and any(x is None for x in result):
                raise ValueError(f"to_tuple {args} got {sample.keys()}")
            yield result
...
@@ -463,19 +487,28 @@ rsample = pipelinefilter(_rsample)

slice = pipelinefilter(itertools.islice)


def _extract_keys(source,
                  *patterns,
                  duplicate_is_error=True,
                  ignore_missing=False):
    for sample in source:
        result = []
        for pattern in patterns:
            pattern = pattern.split(";") if isinstance(pattern,
                                                       str) else pattern
            matches = [
                x for x in sample.keys()
                if any(fnmatch("." + x, p) for p in pattern)
            ]
            if len(matches) == 0:
                if ignore_missing:
                    continue
                else:
                    raise ValueError(
                        f"Cannot find {pattern} in sample keys {sample.keys()}.")
            if len(matches) > 1 and duplicate_is_error:
                raise ValueError(
                    f"Multiple sample keys {sample.keys()} match {pattern}.")
            value = sample[matches[0]]
            result.append(value)
        yield tuple(result)
...
@@ -484,7 +517,12 @@ def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=Fal
extract_keys = pipelinefilter(_extract_keys)


def _rename_keys(source,
                 *args,
                 keep_unselected=False,
                 must_match=True,
                 duplicate_is_error=True,
                 **kw):
    renamings = [(pattern, output) for output, pattern in args]
    renamings += [(pattern, output) for output, pattern in kw.items()]
    for sample in source:
...
@@ -504,11 +542,15 @@ def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicat
                continue
            if new_name in new_sample:
                if duplicate_is_error:
                    raise ValueError(
                        f"Duplicate value in sample {sample.keys()} after rename.")
                continue
            new_sample[new_name] = value
        if must_match and not all(matched.values()):
            raise ValueError(
                f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).")
        yield new_sample
...
@@ -541,18 +583,18 @@ def find_decoder(decoders, path):
    if fname.startswith("__"):
        return lambda x: x
    for pattern, fun in decoders[::-1]:
        if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(),
                                                      pattern):
            return fun
    return None


def _xdecode(
        source,
        *args,
        must_decode=True,
        defaults=default_decoders,
        **kw, ):
    decoders = list(defaults) + list(args)
    decoders += [("*." + k, v) for k, v in kw.items()]
    for sample in source:
...
@@ -575,18 +617,18 @@ def _xdecode(
            new_sample[path] = value
        yield new_sample


xdecode = pipelinefilter(_xdecode)


def _audio_data_filter(source,
                       frame_shift=10,
                       max_length=10240,
                       min_length=10,
                       token_max_length=200,
                       token_min_length=1,
                       min_output_input_ratio=0.0005,
                       max_output_input_ratio=1):
    """ Filter sample according to feature and label length
    Inplace operation.
...
@@ -613,7 +655,8 @@ def _audio_data_filter(source,
        assert 'wav' in sample
        assert 'label' in sample
        # sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
        num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (
            1000 / frame_shift)
        if num_frames < min_length:
            continue
        if num_frames > max_length:
...
@@ -629,13 +672,15 @@ def _audio_data_filter(source,
                continue
        yield sample


audio_data_filter = pipelinefilter(_audio_data_filter)


def _audio_tokenize(source,
                    symbol_table,
                    bpe_model=None,
                    non_lang_syms=None,
                    split_with_space=False):
    """ Decode text to chars or BPE
    Inplace operation
...
@@ -693,8 +738,10 @@ def _audio_tokenize(source,
        sample['label'] = label
        yield sample


audio_tokenize = pipelinefilter(_audio_tokenize)


def _audio_resample(source, resample_rate=16000):
    """ Resample data.
    Inplace operation.
...
@@ -713,18 +760,22 @@ def _audio_resample(source, resample_rate=16000):
        waveform = sample['wav']
        if sample_rate != resample_rate:
            sample['sample_rate'] = resample_rate
            sample['wav'] = paddle.to_tensor(
                backends.soundfile_backend.resample(
                    waveform.numpy(),
                    src_sr=sample_rate,
                    target_sr=resample_rate))
        yield sample


audio_resample = pipelinefilter(_audio_resample)


def _audio_compute_fbank(source,
                         num_mel_bins=80,
                         frame_length=25,
                         frame_shift=10,
                         dither=0.0):
    """ Extract fbank
    Args:
...
@@ -746,30 +797,33 @@ def _audio_compute_fbank(source,
        waveform = sample['wav']
        waveform = waveform * (1 << 15)
        # Only keep fname, feat, label
        mat = kaldi.fbank(
            waveform,
            n_mels=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            energy_floor=0.0,
            sr=sample_rate)
        yield dict(fname=sample['fname'], label=sample['label'], feat=mat)


audio_compute_fbank = pipelinefilter(_audio_compute_fbank)


def _audio_spec_aug(
        source,
        max_w=5,
        w_inplace=True,
        w_mode="PIL",
        max_f=30,
        num_f_mask=2,
        f_inplace=True,
        f_replace_with_zero=False,
        max_t=40,
        num_t_mask=2,
        t_inplace=True,
        t_replace_with_zero=False, ):
    """ Do spec augmentation
    Inplace operation
...
@@ -793,12 +847,23 @@ def _audio_spec_aug(source,
    for sample in source:
        x = sample['feat']
        x = x.numpy()
        x = time_warp(x, max_time_warp=max_w, inplace=w_inplace, mode=w_mode)
        x = freq_mask(
            x,
            F=max_f,
            n_mask=num_f_mask,
            inplace=f_inplace,
            replace_with_zero=f_replace_with_zero)
        x = time_mask(
            x,
            T=max_t,
            n_mask=num_t_mask,
            inplace=t_inplace,
            replace_with_zero=t_replace_with_zero)
        sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
        yield sample


audio_spec_aug = pipelinefilter(_audio_spec_aug)
...
@@ -829,8 +894,10 @@ def _sort(source, sort_size=500):
    for x in buf:
        yield x


sort = pipelinefilter(_sort)


def _batched(source, batch_size=16):
    """ Static batch the data by `batch_size`
...
@@ -850,8 +917,10 @@ def _batched(source, batch_size=16):
    if len(buf) > 0:
        yield buf


batched = pipelinefilter(_batched)


def dynamic_batched(source, max_frames_in_batch=12000):
    """ Dynamic batch the data until the total frames in batch
    reach `max_frames_in_batch`
...
@@ -892,8 +961,8 @@ def _audio_padding(source):
    """
    for sample in source:
        assert isinstance(sample, list)
        feats_length = paddle.to_tensor(
            [x['feat'].shape[0] for x in sample], dtype="int64")
        order = paddle.argsort(feats_length, descending=True)
        feats_lengths = paddle.to_tensor(
            [sample[i]['feat'].shape[0] for i in order], dtype="int64")
...
@@ -902,20 +971,20 @@ def _audio_padding(source):
        sorted_labels = [
            paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
        ]
        label_lengths = paddle.to_tensor(
            [x.shape[0] for x in sorted_labels], dtype="int64")
        padded_feats = pad_sequence(
            sorted_feats, batch_first=True, padding_value=0)
        padding_labels = pad_sequence(
            sorted_labels, batch_first=True, padding_value=-1)

        yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
               label_lengths)


audio_padding = pipelinefilter(_audio_padding)


def _audio_cmvn(source, cmvn_file):
    global_cmvn = GlobalCMVN(cmvn_file)
    for batch in source:
...
@@ -923,13 +992,16 @@ def _audio_cmvn(source, cmvn_file):
        padded_feats = padded_feats.numpy()
        padded_feats = global_cmvn(padded_feats)
        padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32)
        yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
               label_lengths)


audio_cmvn = pipelinefilter(_audio_cmvn)


def _placeholder(source):
    for data in source:
        yield data


placeholder = pipelinefilter(_placeholder)
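Every stage in this module follows the same pattern: a plain iterator function `_foo(source, ...)` is wrapped by pipelinefilter so it can be configured first and handed data later. A small sketch of that currying behavior, relying on the pattern documented in the module; the `_double` stage is made up for illustration:

from paddlespeech.audio.streamdata import pipelinefilter


def _double(source, factor=2):
    # a toy stage: pass every sample through, scaled by `factor`
    for sample in source:
        yield sample * factor


double = pipelinefilter(_double)   # configure now, consume data later
stage = double(factor=3)
print(list(stage(range(4))))       # -> [0, 3, 6, 9]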
paddlespeech/audio/streamdata/gopen.py
@@ -3,12 +3,12 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Open URLs by calling subcommands."""
import os
import re
import sys
from subprocess import PIPE
from subprocess import Popen
from urllib.parse import urlparse

# global used for printing additional node information during verbose output
...
@@ -31,14 +31,13 @@ class Pipe:
    """

    def __init__(self,
                 *args,
                 mode=None,
                 timeout=7200.0,
                 ignore_errors=False,
                 ignore_status=[],
                 **kw):
        """Create an IO Pipe."""
        self.ignore_errors = ignore_errors
        self.ignore_status = [0] + ignore_status
...
@@ -75,8 +74,7 @@ class Pipe:
        if verbose:
            print(
                f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
                file=sys.stderr, )
        if self.status not in self.ignore_status and not self.ignore_errors:
            raise Exception(f"{self.args}: exit {self.status} (read) {info}")
...
@@ -114,9 +112,11 @@ class Pipe:
        self.close()


def set_options(obj,
                timeout=None,
                ignore_errors=None,
                ignore_status=None,
                handler=None):
    """Set options for Pipes.

    This function can be called on any stream. It will set pipe options only
...
@@ -168,16 +168,14 @@ def gopen_pipe(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141], )  # skipcq: BAN-B604
    elif mode[0] == "w":
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141], )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")
...
@@ -196,8 +194,7 @@ def gopen_curl(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23], )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"curl -s -L -T - '{url}'"
        return Pipe(
...
@@ -205,8 +202,7 @@ def gopen_curl(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26], )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")
...
@@ -226,15 +222,13 @@ def gopen_htgs(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23], )  # skipcq: BAN-B604
    elif mode[0] == "w":
        raise ValueError(f"{mode}: cannot write")
    else:
        raise ValueError(f"{mode}: unknown mode")


def gopen_gsutil(url, mode="rb", bufsize=8192):
    """Open a URL with `curl`.
...
@@ -249,8 +243,7 @@ def gopen_gsutil(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 23], )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"gsutil cp - '{url}'"
        return Pipe(
...
@@ -258,13 +251,11 @@ def gopen_gsutil(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
            ignore_status=[141, 26], )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")
def
gopen_error
(
url
,
*
args
,
**
kw
):
def
gopen_error
(
url
,
*
args
,
**
kw
):
"""Raise a value error.
"""Raise a value error.
...
@@ -285,8 +276,7 @@ gopen_schemes = dict(
...
@@ -285,8 +276,7 @@ gopen_schemes = dict(
ftps
=
gopen_curl
,
ftps
=
gopen_curl
,
scp
=
gopen_curl
,
scp
=
gopen_curl
,
gs
=
gopen_gsutil
,
gs
=
gopen_gsutil
,
htgs
=
gopen_htgs
,
htgs
=
gopen_htgs
,
)
)
def
gopen
(
url
,
mode
=
"rb"
,
bufsize
=
8192
,
**
kw
):
def
gopen
(
url
,
mode
=
"rb"
,
bufsize
=
8192
,
**
kw
):
...
...
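The openers above all build Pipe objects around shell commands and are registered per URL scheme in gopen_schemes, so gopen only needs to dispatch on urlparse(url).scheme. A stripped-down sketch of that dispatch (the opener names here are placeholders; the real openers shell out to curl, gsutil, etc. as shown above):

from urllib.parse import urlparse


def open_local(url, mode="rb", bufsize=8192):
    # bare paths map straight to the local filesystem in this sketch
    return open(url, mode, buffering=bufsize)


def open_via_subcommand(url, mode="rb", bufsize=8192):
    # placeholder for the Pipe-based openers (curl, gsutil, ...) above
    raise NotImplementedError(f"would spawn a subprocess for {url}")


gopen_schemes_sketch = {
    "": open_local,
    "http": open_via_subcommand,
    "https": open_via_subcommand,
}


def gopen_sketch(url, mode="rb", bufsize=8192):
    scheme = urlparse(url).scheme
    handler = gopen_schemes_sketch.get(scheme, open_via_subcommand)
    return handler(url, mode=mode, bufsize=bufsize)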
paddlespeech/audio/streamdata/handlers.py  (view file @ 795eb7bd)
...
@@ -3,7 +3,6 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Pluggable exception handlers.

These are functions that take an exception as an argument and then return...
...
@@ -14,8 +13,8 @@ These are functions that take an exception as an argument and then return...
They are used as handler= arguments in much of the library.
"""
-import time, warnings
+import time
+import warnings


def reraise_exception(exn):
...
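Handlers follow the contract visible at the call sites later in this diff: they receive the exception and return a truthy value to keep iterating or a falsy one to stop. A hypothetical handler in the same spirit as reraise_exception:

import warnings


def warn_and_continue(exn):
    """Log the exception and tell the caller to keep going."""
    warnings.warn(f"ignoring exception: {exn!r}")
    return True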
paddlespeech/audio/streamdata/mix.py  (view file @ 795eb7bd)
...
@@ -5,17 +5,12 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes for mixing samples from multiple sources."""
-import itertools, os, random, time, sys
-from functools import reduce, wraps
+import random

import numpy as np

-from . import autodecode, utils
-from .paddle_utils import PaddleTensor, IterableDataset
+from .paddle_utils import IterableDataset
from .utils import PipelineStage


def round_robin_shortest(*sources):
...
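For reference, the behaviour round_robin_shortest is named after can be sketched in a few lines (an illustration of the mixing strategy, not the class-based implementation in the file): alternate between sources and stop as soon as the shortest one is exhausted.

def round_robin_shortest_sketch(*sources):
    iterators = [iter(s) for s in sources]
    while True:
        for it in iterators:
            try:
                yield next(it)
            except StopIteration:
                return


print(list(round_robin_shortest_sketch([1, 2, 3], "ab")))  # [1, 'a', 2, 'b']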
paddlespeech/audio/streamdata/paddle_utils.py  (view file @ 795eb7bd)
...
@@ -5,12 +5,11 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Mock implementations of paddle interfaces when paddle is not available."""
try:
-    from paddle.io import DataLoader, IterableDataset
+    from paddle.io import DataLoader
+    from paddle.io import IterableDataset
except ModuleNotFoundError:

    class IterableDataset:
...
@@ -22,12 +21,3 @@ except ModuleNotFoundError:
        """Empty implementation of DataLoader when paddle is not available."""
        pass
-
-
-try:
-    from paddle import Tensor as PaddleTensor
-except ModuleNotFoundError:
-
-    class TorchTensor:
-        """Empty implementation of PaddleTensor when paddle is not available."""
-        pass
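The guard pattern kept in this file is worth seeing in isolation: import the real paddle classes when they are available, otherwise define empty stand-ins so the rest of the package can still be imported. A minimal sketch of the same idea:

try:
    from paddle.io import DataLoader  # noqa: F401
    from paddle.io import IterableDataset  # noqa: F401
except ModuleNotFoundError:

    class IterableDataset:
        """Empty stand-in used only when paddle is not installed."""
        pass

    class DataLoader:
        """Empty stand-in used only when paddle is not installed."""
        pass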
paddlespeech/audio/streamdata/pipeline.py  (view file @ 795eb7bd)
...
@@ -3,15 +3,12 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
-import copy, os, random, sys, time
+import copy
from dataclasses import dataclass
+import sys
from itertools import islice
from typing import List
import braceexpand, yaml
+from .paddle_utils import DataLoader
+from .paddle_utils import IterableDataset
from .handlers import reraise_exception
-from .paddle_utils import DataLoader, IterableDataset
from .utils import PipelineStage
...
@@ -22,8 +19,7 @@ def add_length_method(obj):
    Combined = type(
        obj.__class__.__name__ + "_Length",
        (obj.__class__, IterableDataset),
        {"__len__": length}, )
    obj.__class__ = Combined
    return obj
...
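The add_length_method trick in the hunk above grafts a __len__ onto an existing instance by swapping its class for a dynamically built subtype. A standalone illustration (the length value is passed in explicitly here; the library obtains it differently):

class Stream:
    def __iter__(self):
        return iter(range(4))


def add_length_method_sketch(obj, n):
    def length(self):
        return n

    Combined = type(obj.__class__.__name__ + "_Length", (obj.__class__, ),
                    {"__len__": length})
    obj.__class__ = Combined
    return obj


s = add_length_method_sketch(Stream(), 4)
print(len(s), list(s))  # 4 [0, 1, 2, 3]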
paddlespeech/audio/streamdata/shardlists.py  (view file @ 795eb7bd)
...
@@ -4,28 +4,30 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Train PyTorch models directly from POSIX tar archive.

Code works locally or over HTTP connections.
"""
-import os, random, sys, time
-from dataclasses import dataclass, field
+import os
+import random
+import sys
+import time
+from dataclasses import dataclass
+from dataclasses import field
from itertools import islice
from typing import List
-import braceexpand, yaml
+import braceexpand
+import yaml
from . import utils
+from ..utils.log import Logger
from .filters import pipelinefilter
from .paddle_utils import IterableDataset
-from ..utils.log import Logger

logger = Logger(__name__)


def expand_urls(urls):
    if isinstance(urls, str):
        urllist = urls.split("::")
...
@@ -64,7 +66,8 @@ class SimpleShardList(IterableDataset):
def split_by_node(src, group=None):
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    logger.info(f"world_size:{world_size}, rank:{rank}")
    if world_size > 1:
        for s in islice(src, rank, None, world_size):
...
@@ -75,9 +78,11 @@ def split_by_node(src, group=None):
def single_node_only(src, group=None):
    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
    if world_size > 1:
        raise ValueError(
            "input pipeline needs to be reconfigured for multinode training")
    for s in src:
        yield s
...
@@ -104,7 +109,8 @@ def resampled_(src, n=sys.maxsize):
    rng = random.Random(seed)
    print("# resampled loading", file=sys.stderr)
    items = list(src)
    print(
        f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
    for i in range(n):
        yield rng.choice(items)
...
@@ -118,7 +124,9 @@ def non_empty(src):
        yield s
        count += 1
    if count == 0:
        raise ValueError(
            "pipeline stage received no data at all and this was declared as an error"
        )


@dataclass
...
@@ -138,10 +146,6 @@ def expand(s):
    return os.path.expanduser(os.path.expandvars(s))


-class MultiShardSample(IterableDataset):
-    def __init__(self, fname):
-        """Construct a shardlist from multiple sources using a YAML spec."""
-        self.epoch = -1
class MultiShardSample(IterableDataset):
    def __init__(self, fname):
        """Construct a shardlist from multiple sources using a YAML spec."""
...
@@ -156,20 +160,23 @@ class MultiShardSample(IterableDataset):
        else:
            with open(fname) as stream:
                spec = yaml.safe_load(stream)
        assert set(spec.keys()).issubset(
            set("prefix datasets buckets".split())), list(spec.keys())
        prefix = expand(spec.get("prefix", ""))
        self.sources = []
        for ds in spec["datasets"]:
            assert set(ds.keys()).issubset(
                set("buckets name shards resample choose".split())), list(
                    ds.keys())
            buckets = ds.get("buckets", spec.get("buckets", []))
            if isinstance(buckets, str):
                buckets = [buckets]
            buckets = [expand(s) for s in buckets]
            if buckets == []:
                buckets = [""]
            assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
            bucket = buckets[0]
            name = ds.get("name", "@" + bucket)
            urls = ds["shards"]
...
@@ -177,15 +184,19 @@ class MultiShardSample(IterableDataset):
                urls = [urls]
            # urls = [u for url in urls for u in braceexpand.braceexpand(url)]
            urls = [
                prefix + os.path.join(bucket, u) for url in urls
                for u in braceexpand.braceexpand(expand(url))
            ]
            resample = ds.get("resample", -1)
            nsample = ds.get("choose", -1)
            if nsample > len(urls):
                raise ValueError(
                    f"perepoch {nsample} must be no greater than the number of shards"
                )
            if (nsample > 0) and (resample > 0):
                raise ValueError("specify only one of perepoch or choose")
            entry = MSSource(
                name=name, urls=urls, perepoch=nsample, resample=resample)
            self.sources.append(entry)
            print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)
...
@@ -203,7 +214,7 @@ class MultiShardSample(IterableDataset):
            # sample without replacement
            l = list(source.urls)
            self.rng.shuffle(l)
            l = l[:source.perepoch]
        else:
            l = list(source.urls)
        result += l
...
@@ -227,12 +238,11 @@ class ResampledShards(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
            self,
            urls,
            nshards=sys.maxsize,
            worker_seed=None,
            deterministic=False, ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
...
@@ -252,7 +262,8 @@ class ResampledShards(IterableDataset):
        if self.deterministic:
            seed = utils.make_seed(self.worker_seed(), self.epoch)
        else:
            seed = utils.make_seed(self.worker_seed(), self.epoch,
                                   os.getpid(), time.time_ns(), os.urandom(4))
        if os.environ.get("WDS_SHOW_SEED", "0") == "1":
            print(f"# ResampledShards seed {seed}")
        self.rng = random.Random(seed)
...
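The split_by_node logic above reduces to an islice over the shard list: each of world_size ranks takes every world_size-th shard starting at its own rank. For example:

from itertools import islice

shards = [f"shard-{i:04d}.tar" for i in range(10)]
world_size, rank = 4, 1
print(list(islice(shards, rank, None, world_size)))
# ['shard-0001.tar', 'shard-0005.tar', 'shard-0009.tar']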
paddlespeech/audio/streamdata/tariterators.py  (view file @ 795eb7bd)
...
@@ -3,13 +3,12 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Low level iteration functions for tar archives."""
-import random, re, tarfile
+import random
+import re
+import tarfile

import braceexpand
...
@@ -27,6 +26,7 @@ import numpy as np
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


def base_plus_ext(path):
    """Split off all file extensions.
...
@@ -47,12 +47,8 @@ def valid_sample(sample):
    :param sample: sample to be checked
    """
    return (sample is not None and isinstance(sample, dict) and
            len(list(sample.keys())) > 0 and not sample.get("__bad__", False))


# FIXME: UNUSED
...
@@ -79,16 +75,16 @@ def url_opener(data, handler=reraise_exception, **kw):
            sample.update(stream=stream)
            yield sample
        except Exception as exn:
            exn.args = exn.args + (url, )
            if handler(exn):
                continue
            else:
                break


def tar_file_iterator(fileobj,
                      skip_meta=r"__[^/]*__($|/)",
                      handler=reraise_exception):
    """Iterate over tar file, yielding filename, content pairs for the given tar stream.

    :param fileobj: byte stream suitable for tarfile
...
@@ -103,11 +99,8 @@ def tar_file_iterator(
                continue
            if fname is None:
                continue
            if ("/" not in fname and fname.startswith(meta_prefix) and
                    fname.endswith(meta_suffix)):
                # skipping metadata for now
                continue
            if skip_meta is not None and re.match(skip_meta, fname):
...
@@ -118,8 +111,10 @@ def tar_file_iterator(
            assert pos > 0
            prefix, postfix = name[:pos], name[pos + 1:]
            if postfix == 'wav':
                waveform, sample_rate = paddlespeech.audio.load(
                    stream.extractfile(tarinfo), normal=False)
                result = dict(fname=prefix, wav=waveform, sample_rate=sample_rate)
            else:
                txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
                result = dict(fname=prefix, txt=txt)
...
@@ -128,16 +123,17 @@ def tar_file_iterator(
            stream.members = []
        except Exception as exn:
            if hasattr(exn, "args") and len(exn.args) > 0:
                exn.args = (exn.args[0] + " @ " + str(fileobj), ) + exn.args[1:]
            if handler(exn):
                continue
            else:
                break
    del stream


def tar_file_and_group_iterator(fileobj,
                                skip_meta=r"__[^/]*__($|/)",
                                handler=reraise_exception):
    """ Expand a stream of open tar files into a stream of tar file contents.
        And groups the file with same prefix
...
@@ -167,8 +163,11 @@ def tar_file_and_group_iterator(
                        if postfix == 'txt':
                            example['txt'] = file_obj.read().decode('utf8').strip()
                        elif postfix in AUDIO_FORMAT_SETS:
                            waveform, sample_rate = paddlespeech.audio.load(
                                file_obj, normal=False)
                            waveform = paddle.to_tensor(
                                np.expand_dims(np.array(waveform), 0),
                                dtype=paddle.float32)
                            example['wav'] = waveform
                            example['sample_rate'] = sample_rate
...
@@ -176,19 +175,21 @@ def tar_file_and_group_iterator(
                            example[postfix] = file_obj.read()
                    except Exception as exn:
                        if hasattr(exn, "args") and len(exn.args) > 0:
                            exn.args = (exn.args[0] + " @ " + str(fileobj),
                                        ) + exn.args[1:]
                        if handler(exn):
                            continue
                        else:
                            break
                        valid = False
                        # logging.warning('error to parse {}'.format(name))
                prev_prefix = prefix
            if prev_prefix is not None:
                example['fname'] = prev_prefix
                yield example
            stream.close()


def tar_file_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.
...
@@ -200,9 +201,8 @@ def tar_file_expander(data, handler=reraise_exception):
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_iterator(source["stream"]):
                assert (isinstance(sample, dict) and "data" in sample and
                        "fname" in sample)
                sample["__url__"] = url
                yield sample
        except Exception as exn:
...
@@ -213,8 +213,6 @@ def tar_file_expander(data, handler=reraise_exception):
            break


def tar_file_and_group_expander(data, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.
...
@@ -226,9 +224,8 @@ def tar_file_and_group_expander(data, handler=reraise_exception):
            assert isinstance(source, dict)
            assert "stream" in source
            for sample in tar_file_and_group_iterator(source["stream"]):
                assert (isinstance(sample, dict) and "wav" in sample and
                        "txt" in sample and "fname" in sample)
                sample["__url__"] = url
                yield sample
        except Exception as exn:
...
@@ -239,7 +236,11 @@ def tar_file_and_group_expander(data, handler=reraise_exception):
            break


def group_by_keys(data,
                  keys=base_plus_ext,
                  lcase=True,
                  suffixes=None,
                  handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
...
@@ -254,8 +255,8 @@ def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=N
            print(
                prefix,
                suffix,
                current_sample.keys() if isinstance(current_sample, dict) else None, )
        if prefix is None:
            continue
        if lcase:
...
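group_by_keys relies on base_plus_ext to split a member name into a sample prefix and an extension, so that utt1.wav and utt1.txt collapse into one sample dict. A simplified stand-in (os.path.splitext instead of the library's regex, and an already-sorted input) shows the grouping idea:

import os


def base_plus_ext_sketch(path):
    base, ext = os.path.splitext(path)
    return base, ext.lstrip(".")


def group_by_keys_sketch(pairs):
    current, sample = None, {}
    for fname, value in pairs:
        prefix, suffix = base_plus_ext_sketch(fname)
        if prefix != current:
            if sample:
                yield sample
            current, sample = prefix, {"__key__": prefix}
        sample[suffix] = value
    if sample:
        yield sample


files = [("utt1.wav", b"<pcm>"), ("utt1.txt", b"hello"), ("utt2.wav", b"<pcm>")]
print(list(group_by_keys_sketch(files)))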
paddlespeech/audio/streamdata/utils.py  (view file @ 795eb7bd)
...
@@ -4,22 +4,23 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
# Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions."""
import importlib
import itertools as itt
import os
import re
import sys
-from typing import Any, Callable, Iterator, Optional, Union
+from typing import Any
+from typing import Callable
+from typing import Iterator
+from typing import Union

from ..utils.log import Logger

logger = Logger(__name__)


def make_seed(*args):
    seed = 0
    for arg in args:
...
@@ -37,7 +38,7 @@ def identity(x: Any) -> Any:
    return x


def safe_eval(s: str, expr: str="{}"):
    """Evaluate the given expression more safely."""
    if re.sub("[^A-Za-z0-9_]", "", s) != s:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
...
@@ -54,9 +55,9 @@ def lookup_sym(sym: str, modules: list):
    return None


def repeatedly0(loader: Iterator,
                nepochs: int=sys.maxsize,
                nbatches: int=sys.maxsize):
    """Repeatedly returns batches from a DataLoader."""
    for epoch in range(nepochs):
        for sample in itt.islice(loader, nbatches):
...
@@ -69,12 +70,11 @@ def guess_batchsize(batch: Union[tuple, list]):
def repeatedly(
        source: Iterator,
        nepochs: int=None,
        nbatches: int=None,
        nsamples: int=None,
        batchsize: Callable[..., int]=guess_batchsize, ):
    """Repeatedly yield samples from an iterator."""
    epoch = 0
    batch = 0
...
@@ -93,6 +93,7 @@ def repeatedly(
        if nepochs is not None and epoch >= nepochs:
            return


def paddle_worker_info(group=None):
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
...
@@ -116,7 +117,7 @@ def paddle_worker_info(group=None):
    else:
        try:
            from paddle.io import get_worker_info
-            worker_info = paddle.io.get_worker_info()
+            worker_info = get_worker_info()
            if worker_info is not None:
                worker = worker_info.id
                num_workers = worker_info.num_workers
...
@@ -126,6 +127,7 @@ def paddle_worker_info(group=None):
    return rank, world_size, worker, num_workers


def paddle_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = paddle_worker_info(group=group)
...
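paddle_worker_seed composes the values returned by paddle_worker_info into one deterministic seed via make_seed. A sketch of that composition (the mixing constants below are illustrative, not the library's):

def make_seed_sketch(*args):
    seed = 0
    for arg in args:
        seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
    return seed


rank, world_size, worker, num_workers = 1, 4, 3, 8
print(make_seed_sketch(rank, worker))  # distinct per (rank, worker) pair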
paddlespeech/audio/streamdata/writer.py  (view file @ 795eb7bd)
...
@@ -5,18 +5,24 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
"""Classes and functions for writing tar files and WebDataset files."""
-import io, json, pickle, re, tarfile, time
-from typing import Any, Callable, Optional, Union
+import io
+import json
+import pickle
+import re
+import tarfile
+import time
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import Union

import numpy as np

from . import gopen


def imageencoder(image: Any, format: str="PNG"):  # skipcq: PYL-W0622
    """Compress an image using PIL and return it as a string.

    Can handle float or uint8 images.
...
@@ -67,6 +73,7 @@ def bytestr(data: Any):
        return data.encode("ascii")
    return str(data).encode("ascii")


def paddle_dumps(data: Any):
    """Dump data into a bytestring using paddle.dumps.
...
@@ -82,6 +89,7 @@ def paddle_dumps(data: Any):
    paddle.save(data, stream)
    return stream.getvalue()


def numpy_dumps(data: np.ndarray):
    """Dump data into a bytestring using numpy npy format.
...
@@ -139,9 +147,8 @@ def add_handlers(d, keys, value):
def make_handlers():
    """Create a list of handlers for encoding data."""
    handlers = {}
    add_handlers(handlers, "cls cls2 class count index inx id",
                 lambda x: str(x).encode("ascii"))
    add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
    add_handlers(handlers, "pyd pickle", pickle.dumps)
...
@@ -152,7 +159,8 @@ def make_handlers():
    add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
    add_handlers(handlers, "mp msgpack msg", mp_dumps)
    add_handlers(handlers, "cbor", cbor_dumps)
    add_handlers(handlers, "jpg jpeg img image",
                 lambda data: imageencoder(data, "jpg"))
    add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
    add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
    add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
...
@@ -192,7 +200,8 @@ def encode_based_on_extension(sample: dict, handlers: dict):
    :param handlers: handlers for encoding
    """
    return {
        k: encode_based_on_extension1(v, k, handlers)
        for k, v in list(sample.items())
    }
...
@@ -258,15 +267,14 @@ class TarWriter:
    """

    def __init__(
            self,
            fileobj,
            user: str="bigdata",
            group: str="bigdata",
            mode: int=0o0444,
            compress: Optional[bool]=None,
            encoder: Union[None, bool, Callable]=True,
            keep_meta: bool=False, ):
        """Create a tar writer.

        :param fileobj: stream to write data to
...
@@ -330,8 +338,7 @@ class TarWriter:
                continue
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"{k} doesn't map to a bytes after encoding ({type(v)})")
        key = obj["__key__"]
        for k in sorted(obj.keys()):
            if k == "__key__":
...
@@ -349,7 +356,8 @@ class TarWriter:
            ti.uname = self.user
            ti.gname = self.group
            if not isinstance(v, (bytes, bytearray, memoryview)):
                raise ValueError(
                    f"converter didn't yield bytes: {k}, {type(v)}")
            stream = io.BytesIO(v)
            self.tarstream.addfile(ti, stream)
            total += ti.size
...
@@ -360,14 +368,13 @@ class ShardWriter:
    """Like TarWriter but splits into multiple shards."""

    def __init__(
            self,
            pattern: str,
            maxcount: int=100000,
            maxsize: float=3e9,
            post: Optional[Callable]=None,
            start_shard: int=0,
            **kw, ):
        """Create a ShardWriter.

        :param pattern: output file pattern
...
@@ -400,8 +407,7 @@ class ShardWriter:
            self.fname,
            self.count,
            "%.1f GB" % (self.size / 1e9),
            self.total, )
        self.shard += 1
        stream = open(self.fname, "wb")
        self.tarstream = TarWriter(stream, **self.kw)
...
@@ -413,11 +419,8 @@ class ShardWriter:
        :param obj: sample to be written
        """
        if (self.tarstream is None or self.count >= self.maxcount or
                self.size >= self.maxsize):
            self.next_stream()
        size = self.tarstream.write(obj)
        self.count += 1
...
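Putting the writer classes together, a typical use is to stream samples into size- and count-limited shards. The sketch below assumes the webdataset-style write()/close() interface that this module adapts, and relies only on the constructor arguments visible in the diff (pattern, maxcount, maxsize); treat it as an illustration rather than a verified recipe:

from paddlespeech.audio.streamdata.writer import ShardWriter

writer = ShardWriter("shard-%06d.tar", maxcount=2, maxsize=3e9)
for i in range(4):
    writer.write({
        "__key__": f"utt{i:04d}",   # sample prefix inside the tar
        "wav": b"<pcm bytes>",      # raw bytes are written as-is
        "txt": f"transcript {i}",   # strings go through the text handler
    })
writer.close()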
paddlespeech/audio/text/text_featurizer.py  (view file @ 795eb7bd)
...
@@ -17,6 +17,7 @@ from typing import Union
import sentencepiece as spm

+from ..utils.log import Logger
from .utility import BLANK
from .utility import EOS
from .utility import load_dict
...
@@ -24,7 +25,6 @@ from .utility import MASKCTC
from .utility import SOS
from .utility import SPACE
from .utility import UNK
-from ..utils.log import Logger

logger = Logger(__name__)
...
paddlespeech/audio/transform/perturb.py  (view file @ 795eb7bd)
...
@@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import io
import os

import h5py
import librosa
import numpy
import numpy as np
import scipy
import soundfile


class SoundHDF5File():
    """Collecting sound files to a HDF5 file
...
@@ -109,6 +110,7 @@ class SoundHDF5File():
    def close(self):
        self.file.close()


class SpeedPerturbation():
    """SpeedPerturbation
...
@@ -558,4 +560,3 @@ class RIRConvolve():
                [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
        else:
            return scipy.convolve(x, rir, mode="same")
paddlespeech/audio/transform/spec_augment.py  (view file @ 795eb7bd)
...
@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
"""Spec Augment module for preprocessing i.e., data augmentation"""
import random

import numpy
from PIL import Image
...
paddlespeech/cli/executor.py  (view file @ 795eb7bd)
...
@@ -191,7 +191,7 @@ class BaseExecutor(ABC):
            line = line.strip()
            if not line:
                continue
            k, v = line.split()  # space or \t
            job_contents[k] = v
        return job_contents
...
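The annotated line documents the expected job-file layout: one "<key> <value>" pair per non-empty line, separated by a space or a tab. In isolation, the parsing loop behaves like this small stand-alone helper:

def parse_job_file(text):
    job_contents = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        k, v = line.split()  # space or \t separated
        job_contents[k] = v
    return job_contents


print(parse_job_file("utt1 a.wav\nutt2 b.wav\n"))
# {'utt1': 'a.wav', 'utt2': 'b.wav'}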
paddlespeech/s2t/__init__.py  (view file @ 795eb7bd)
...
@@ -114,6 +114,7 @@ if not hasattr(paddle.Tensor, 'new_full'):
    paddle.Tensor.new_full = new_full
    paddle.static.Variable.new_full = new_full


def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
    return xs
...
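This module patches missing helpers onto paddle types at import time (paddle.Tensor.new_full = new_full, and contiguous below it). The same pattern, demonstrated on a stand-in class so the snippet runs without paddle installed:

class FakeTensor:
    """Stand-in for paddle.Tensor, used only for this illustration."""
    pass


def contiguous(xs):
    # paddle tensors are already contiguous, so the shim is a no-op
    return xs


if not hasattr(FakeTensor, "contiguous"):
    FakeTensor.contiguous = contiguous

t = FakeTensor()
print(t.contiguous() is t)  # True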
paddlespeech/s2t/exps/u2/model.py  (view file @ 795eb7bd)
...
@@ -25,8 +25,6 @@ import paddle
from paddle import distributed as dist

from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
-from paddlespeech.s2t.io.dataloader import BatchDataLoader
-from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory
...
@@ -109,7 +107,8 @@ class U2Trainer(Trainer):
    def valid(self):
        self.model.eval()
        if not self.use_streamdata:
            logger.info(
                f"Valid Total Examples: {len(self.valid_loader.dataset)}")
        valid_losses = defaultdict(list)
        num_seen_utts = 1
        total_loss = 0.0
...
@@ -136,7 +135,8 @@ class U2Trainer(Trainer):
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                if not self.use_streamdata:
                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                 for k, v in valid_dump.items())
                logger.info(msg)
...
@@ -157,7 +157,8 @@ class U2Trainer(Trainer):
        self.before_train()
        if not self.use_streamdata:
            logger.info(
                f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.n_epoch:
            with Timer("Epoch-Train Time Cost: {}"):
                self.model.train()
...
@@ -225,14 +226,18 @@ class U2Trainer(Trainer):
        config = self.config.clone()
        self.use_streamdata = config.get("use_stream_data", False)
        if self.train:
            self.train_loader = DataLoaderFactory.get_dataloader(
                'train', config, self.args)
            self.valid_loader = DataLoaderFactory.get_dataloader(
                'valid', config, self.args)
            logger.info("Setup train/valid Dataloader!")
        else:
            decode_batch_size = config.get('decode', dict()).get(
                'decode_batch_size', 1)
            self.test_loader = DataLoaderFactory.get_dataloader('test', config,
                                                                self.args)
            self.align_loader = DataLoaderFactory.get_dataloader(
                'align', config, self.args)
            logger.info("Setup test/align Dataloader!")

    def setup_model(self):
...
paddlespeech/s2t/exps/u2_kaldi/model.py  (view file @ 795eb7bd)
...
@@ -105,7 +105,8 @@ class U2Trainer(Trainer):
    def valid(self):
        self.model.eval()
        if not self.use_streamdata:
            logger.info(
                f"Valid Total Examples: {len(self.valid_loader.dataset)}")
        valid_losses = defaultdict(list)
        num_seen_utts = 1
        total_loss = 0.0
...
@@ -133,7 +134,8 @@ class U2Trainer(Trainer):
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                if not self.use_streamdata:
                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                 for k, v in valid_dump.items())
                logger.info(msg)
...
@@ -153,7 +155,8 @@ class U2Trainer(Trainer):
        self.before_train()
        if not self.use_streamdata:
            logger.info(
                f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.n_epoch:
            with Timer("Epoch-Train Time Cost: {}"):
                self.model.train()
...
@@ -165,8 +168,8 @@ class U2Trainer(Trainer):
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                if not self.use_streamdata:
                    msg += "batch : {}/{}, ".format(batch_index + 1,
                                                    len(self.train_loader))
                msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                msg += "data time: {:>.3f}s, ".format(dataload_time)
                self.train_batch(batch_index, batch, msg)
...
@@ -204,21 +207,24 @@ class U2Trainer(Trainer):
        self.use_streamdata = config.get("use_stream_data", False)
        if self.train:
            config = self.config.clone()
            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
            config = self.config.clone()
            config['preprocess_config'] = None
            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
            logger.info("Setup train/valid Dataloader!")
        else:
            config = self.config.clone()
            config['preprocess_config'] = None
            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
            config = self.config.clone()
            config['preprocess_config'] = None
            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
            logger.info("Setup test/align Dataloader!")

    def setup_model(self):
        config = self.config
...
paddlespeech/s2t/exps/u2_st/model.py  (view file @ 795eb7bd)
...
@@ -121,7 +121,8 @@ class U2STTrainer(Trainer):
    def valid(self):
        self.model.eval()
        if not self.use_streamdata:
            logger.info(
                f"Valid Total Examples: {len(self.valid_loader.dataset)}")
        valid_losses = defaultdict(list)
        num_seen_utts = 1
        total_loss = 0.0
...
@@ -155,7 +156,8 @@ class U2STTrainer(Trainer):
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                if not self.use_streamdata:
                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                 for k, v in valid_dump.items())
                logger.info(msg)
...
@@ -175,7 +177,8 @@ class U2STTrainer(Trainer):
        self.before_train()
        if not self.use_streamdata:
            logger.info(
                f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.n_epoch:
            with Timer("Epoch-Train Time Cost: {}"):
                self.model.train()
...
@@ -248,14 +251,16 @@ class U2STTrainer(Trainer):
        config['load_transcript'] = load_transcript
        self.use_streamdata = config.get("use_stream_data", False)
        if self.train:
            self.train_loader = DataLoaderFactory.get_dataloader(
                'train', config, self.args)
            self.valid_loader = DataLoaderFactory.get_dataloader(
                'valid', config, self.args)
            logger.info("Setup train/valid Dataloader!")
        else:
            self.test_loader = DataLoaderFactory.get_dataloader('test', config,
                                                                self.args)
            logger.info("Setup test Dataloader!")

    def setup_model(self):
        config = self.config
        model_conf = config
...
paddlespeech/s2t/io/dataloader.py  (View file @ 795eb7bd)
...
@@ -22,17 +22,16 @@ import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode

import paddlespeech.audio.streamdata as streamdata
from paddlespeech.audio.text.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.batchfy import make_batchset
from paddlespeech.s2t.io.converter import CustomConverter
from paddlespeech.s2t.io.dataset import TransformDataset
from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.utils.log import Log
(The CfgNode, streamdata and TextFeaturizer imports previously sat after the Log import; the formatting commit re-sorts them upward.)

__all__ = ["BatchDataLoader", "StreamDataLoader"]

logger = Log(__name__).getlog()
...
@@ -61,6 +60,7 @@ def batch_collate(x):
    """
    return x[0]


def read_preprocess_cfg(preprocess_conf_file):
    augment_conf = dict()
    preprocess_cfg = CfgNode(new_allowed=True)
...
@@ -82,7 +82,8 @@ def read_preprocess_cfg(preprocess_conf_file):
            augment_conf['num_t_mask'] = process['n_mask']
            augment_conf['t_inplace'] = process['inplace']
            augment_conf['t_replace_with_zero'] = process['replace_with_zero']
    return augment_conf
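read_preprocess_cfg walks the augmentation section of the preprocess config and renames the time-mask keys into the names the streamdata spec-augment op expects (n_mask -> num_t_mask, inplace -> t_inplace, replace_with_zero -> t_replace_with_zero). The real file is a YAML config read through CfgNode; the snippet below is only a hedged sketch of that key mapping with a hypothetical process entry:

# Hypothetical illustration of the renaming done in read_preprocess_cfg.
# Only the source/target key names are taken from the diff; the dict values
# and the surrounding config structure are assumptions.
process = {"n_mask": 2, "inplace": True, "replace_with_zero": False}

augment_conf = {}
augment_conf['num_t_mask'] = process['n_mask']
augment_conf['t_inplace'] = process['inplace']
augment_conf['t_replace_with_zero'] = process['replace_with_zero']
print(augment_conf)
# {'num_t_mask': 2, 't_inplace': True, 't_replace_with_zero': False}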
class StreamDataLoader():
    def __init__(self,
...
@@ -95,12 +96,12 @@ class StreamDataLoader():
                 frame_length=25,
                 frame_shift=10,
                 dither=0.0,
                 minlen_in: float=0.0,
                 maxlen_in: float=float('inf'),
                 minlen_out: float=0.0,
                 maxlen_out: float=float('inf'),
                 resample_rate: int=16000,
                 shuffle_size: int=10000,
                 sort_size: int=1000,
                 n_iter_processes: int=1,
                 prefetch_factor: int=2,
...
@@ -116,11 +117,11 @@ class StreamDataLoader():
        text_featurizer = TextFeaturizer(unit_type, vocab_filepath)
        symbol_table = text_featurizer.vocab_dict
        self.feat_dim = num_mel_bins
        self.vocab_size = text_featurizer.vocab_size

        augment_conf = read_preprocess_cfg(preprocess_conf)

        # The list of shard
        shardlist = []
        with open(manifest_file, "r") as f:
...
@@ -128,58 +129,68 @@ class StreamDataLoader():
                shardlist.append(line.strip())

        world_size = 1
        try:
            world_size = paddle.distributed.get_world_size()
        except Exception as e:
            logger.warninig(e)
            logger.warninig(
                "can not get world_size using paddle.distributed.get_world_size(), use world_size=1")
        assert len(shardlist) >= world_size, \
            "the length of shard list should >= number of gpus/xpus/..."
(Before this commit the check was written as `assert(len(shardlist) >= world_size, "...")`, an always-true assert on a tuple; the reformatted version separates condition and message correctly. Note that `logger.warninig` is a pre-existing typo for `logger.warning` in the source.)

        update_n_iter_processes = int(
            max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0))
        logger.info(f"update_n_iter_processes {update_n_iter_processes}")
        if update_n_iter_processes != self.n_iter_processes:
            self.n_iter_processes = update_n_iter_processes
            logger.info(f"change nun_workers to {self.n_iter_processes}")
("nun_workers" is another source-side typo for "num_workers".)
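The worker-count cap above keeps the number of loader workers below the number of shards each rank will actually see, so no worker starves. A worked example of the same arithmetic, with values chosen purely for illustration:

# Illustrative values only: 16 shards, 4 ranks, 8 requested workers.
shard_count = 16
world_size = 4
requested_workers = 8

# Same formula as in StreamDataLoader.__init__: at most shards-per-rank minus
# one workers, never below zero.
update_n_iter_processes = int(
    max(min(shard_count / world_size - 1, requested_workers), 0))
print(update_n_iter_processes)  # 3 -> num_workers is reduced from 8 to 3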
        if self.dist_sampler:
            base_dataset = streamdata.DataPipeline(
                streamdata.SimpleShardList(shardlist),
                streamdata.split_by_node
                if train_mode else streamdata.placeholder(),
                streamdata.split_by_worker,
                streamdata.tarfile_to_samples(streamdata.reraise_exception))
        else:
            base_dataset = streamdata.DataPipeline(
                streamdata.SimpleShardList(shardlist),
                streamdata.split_by_worker,
                streamdata.tarfile_to_samples(streamdata.reraise_exception))

        self.dataset = base_dataset.append_list(
            streamdata.audio_tokenize(symbol_table),
            streamdata.audio_data_filter(
                frame_shift=frame_shift,
                max_length=maxlen_in,
                min_length=minlen_in,
                token_max_length=maxlen_out,
                token_min_length=minlen_out),
            streamdata.audio_resample(resample_rate=resample_rate),
            streamdata.audio_compute_fbank(
                num_mel_bins=num_mel_bins,
                frame_length=frame_length,
                frame_shift=frame_shift,
                dither=dither),
            streamdata.audio_spec_aug(**augment_conf)
            if train_mode else streamdata.placeholder(),
            # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
            streamdata.shuffle(shuffle_size),
            streamdata.sort(sort_size=sort_size),
            streamdata.batched(batch_size),
            streamdata.audio_padding(),
            streamdata.audio_cmvn(cmvn_file))
        if paddle.__version__ >= '2.3.2':
            self.loader = streamdata.WebLoader(
                self.dataset,
                num_workers=self.n_iter_processes,
                prefetch_factor=self.prefetch_factor,
                batch_size=None)
        else:
            self.loader = streamdata.WebLoader(
                self.dataset,
                num_workers=self.n_iter_processes,
                batch_size=None)

    def __iter__(self):
        return self.loader.__iter__()
...
@@ -188,7 +199,9 @@ class StreamDataLoader():
        return self.__iter__()

    def __len__(self):
        logger.info(
            "Stream dataloader does not support calculate the length of the dataset")
        return -1
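Because the stream loader reports a length of -1, callers have to treat it as an iterable of ready-made padded batches rather than a sized loader. A self-contained sketch of that consumption pattern; the tiny class below only mimics the __iter__/__len__ contract shown above and is not the real StreamDataLoader:

# Sketch only: FakeStreamLoader stands in for StreamDataLoader.
class FakeStreamLoader:
    def __iter__(self):
        # The real loader yields already batched, padded utterances.
        yield from (["batch-0"], ["batch-1"], ["batch-2"])

    def __len__(self):
        return -1  # streaming: length unknown

loader = FakeStreamLoader()
assert len(loader) == -1
for step, batch in enumerate(loader, start=1):
    print(f"step {step}: {batch}")  # progress is tracked by step, not by epoch size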
...
@@ -347,7 +360,7 @@ class DataLoaderFactory():
                config['train_mode'] = True
            elif mode == 'valid':
                config['manifest'] = config.dev_manifest
                config['train_mode'] = False
            elif model == 'test' or mode == 'align':
                config['manifest'] = config.test_manifest
                config['train_mode'] = False
(The `model ==` in the test/align branch is a pre-existing typo for `mode ==`; it is unchanged by this commit.)
...
@@ -358,30 +371,31 @@ class DataLoaderFactory():
                config['maxlen_out'] = float('inf')
                config['dist_sampler'] = False
            else:
                raise KeyError(
                    "not valid mode type!!, please input one of 'train, valid, test, align'"
                )
            return StreamDataLoader(
                manifest_file=config.manifest,
                train_mode=config.train_mode,
                unit_type=config.unit_type,
                preprocess_conf=config.preprocess_config,
                batch_size=config.batch_size,
                num_mel_bins=config.feat_dim,
                frame_length=config.window_ms,
                frame_shift=config.stride_ms,
                dither=config.dither,
                minlen_in=config.minlen_in,
                maxlen_in=config.maxlen_in,
                minlen_out=config.minlen_out,
                maxlen_out=config.maxlen_out,
                resample_rate=config.resample_rate,
                shuffle_size=config.shuffle_size,
                sort_size=config.sort_size,
                n_iter_processes=config.num_workers,
                prefetch_factor=config.prefetch_factor,
                dist_sampler=config.dist_sampler,
                cmvn_file=config.cmvn_file,
                vocab_filepath=config.vocab_filepath, )
        else:
            if mode == 'train':
                config['manifest'] = config.train_manifest
...
@@ -411,7 +425,7 @@ class DataLoaderFactory():
                config['train_mode'] = False
                config['sortagrad'] = False
                config['batch_size'] = config.get('decode', dict()).get(
                    'decode_batch_size', 1)
                config['maxlen_in'] = float('inf')
                config['maxlen_out'] = float('inf')
                config['minibatches'] = 0
...
@@ -427,8 +441,10 @@ class DataLoaderFactory():
                config['dist_sampler'] = False
                config['shortest_first'] = False
            else:
                raise KeyError(
                    "not valid mode type!!, please input one of 'train, valid, test, align'"
                )
            return BatchDataLoader(
                json_file=config.manifest,
                train_mode=config.train_mode,
...
@@ -450,4 +466,3 @@ class DataLoaderFactory():
                num_encs=config.num_encs,
                dist_sampler=config.dist_sampler,
                shortest_first=config.shortest_first)
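A hedged usage sketch of the factory: the caller only picks a mode, and the mode-specific overrides above (manifest choice, train_mode, decode batch size, dist_sampler) are applied to the config before the concrete loader is built. It is not runnable as-is: `config` and `args` are assumed to be the trainer's CfgNode and argparse namespace, and a real config also needs manifests, vocab, cmvn and feature settings.

# Sketch only (assumed inputs): how a trainer asks the factory for loaders.
from paddlespeech.s2t.io.dataloader import DataLoaderFactory

def build_loaders(config, args, training=True):
    if training:
        train_loader = DataLoaderFactory.get_dataloader('train', config, args)
        valid_loader = DataLoaderFactory.get_dataloader('valid', config, args)
        return train_loader, valid_loader
    # 'test' (and 'align') switch to test_manifest and decoding-friendly settings
    return DataLoaderFactory.get_dataloader('test', config, args)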
paddlespeech/s2t/models/u2_st/u2_st.py  (View file @ 795eb7bd)
...
@@ -18,7 +18,6 @@ Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recogni
"""
import time
from typing import Dict
from typing import Optional
from typing import Tuple
...
@@ -26,6 +25,8 @@ import paddle
from paddle import jit
from paddle import nn

from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
...
@@ -38,8 +39,6 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
(The two tensor_utils imports, add_sos_eos and th_accuracy, were moved up into the paddlespeech.audio group; `from typing import List` appears on only one side of the diff and is dropped.)

__all__ = ["U2STModel", "U2STInferModel"]
...
@@ -401,8 +400,8 @@ class U2STBaseModel(nn.Layer):
            xs: paddle.Tensor,
            offset: int,
            required_cache_size: int,
            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """ Export interface for c++ call, give input chunk xs, and return
            output from time 0 to current chunk.
...
@@ -435,8 +434,8 @@ class U2STBaseModel(nn.Layer):
            paddle.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        return self.encoder.forward_chunk(xs, offset, required_cache_size,
                                          att_cache, cnn_cache)

    # @jit.to_static
    def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
...
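The chunk-forward export point above is the streaming interface: each call feeds one feature chunk plus the attention/conv caches from the previous call and gets back the new encoder output and updated caches. A hedged sketch of the calling pattern follows; the hunk only shows the parameters, so the method name (forward_encoder_chunk, as used in the U2 model family), the chunk source and all sizes are assumptions for illustration, not values from this file.

import paddle

# Sketch only: `model` is assumed to be a built U2ST-style network and
# `feature_chunks` an iterable of [1, T_chunk, feat_dim] tensors.
att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache = paddle.zeros([0, 0, 0, 0])
offset = 0
required_cache_size = 16  # placeholder value

for chunk in feature_chunks:
    ys, att_cache, cnn_cache = model.forward_encoder_chunk(
        chunk, offset, required_cache_size, att_cache, cnn_cache)
    offset += chunk.shape[1]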
paddlespeech/s2t/modules/align.py  (View file @ 795eb7bd)
...
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import paddle
from paddle import nn
(Previously `import math` sat after the paddle imports; the formatter moves the stdlib import first.)
"""
To align the initializer between paddle and torch,
the API below are set defalut initializer with priority higger than global initializer.
"""
...
@@ -81,10 +82,18 @@ class Linear(nn.Linear):
                 name=None):
        if weight_attr is None:
            if global_init_type == "kaiming_uniform":
                weight_attr = paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(
                        fan_in=None,
                        negative_slope=math.sqrt(5),
                        nonlinearity='leaky_relu'))
        if bias_attr is None:
            if global_init_type == "kaiming_uniform":
                bias_attr = paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(
                        fan_in=None,
                        negative_slope=math.sqrt(5),
                        nonlinearity='leaky_relu'))
        super(Linear, self).__init__(in_features, out_features, weight_attr,
                                     bias_attr, name)
...
@@ -104,10 +113,18 @@ class Conv1D(nn.Conv1D):
                 data_format='NCL'):
        if weight_attr is None:
            if global_init_type == "kaiming_uniform":
                weight_attr = paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(
                        fan_in=None,
                        negative_slope=math.sqrt(5),
                        nonlinearity='leaky_relu'))
        if bias_attr is None:
            if global_init_type == "kaiming_uniform":
                bias_attr = paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(
                        fan_in=None,
                        negative_slope=math.sqrt(5),
                        nonlinearity='leaky_relu'))
        super(Conv1D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, padding_mode, weight_attr, bias_attr, data_format)
...
@@ -128,10 +145,18 @@ class Conv2D(nn.Conv2D):
                 data_format='NCHW'):
        if weight_attr is None:
            if global_init_type == "kaiming_uniform":
                weight_attr = paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(
                        fan_in=None,
                        negative_slope=math.sqrt(5),
                        nonlinearity='leaky_relu'))
        if bias_attr is None:
            if global_init_type == "kaiming_uniform":
                bias_attr = paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(
                        fan_in=None,
                        negative_slope=math.sqrt(5),
                        nonlinearity='leaky_relu'))
        super(Conv2D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, padding_mode, weight_attr, bias_attr, data_format)
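The `negative_slope=math.sqrt(5)` with `nonlinearity='leaky_relu'` pairing mirrors PyTorch's default `kaiming_uniform_(a=sqrt(5))` initialization for Linear/Conv layers: with fan-in-mode Kaiming the gain works out to sqrt(2 / (1 + 5)) = sqrt(1/3), so the uniform bound collapses to 1/sqrt(fan_in). A small worked check of that bound (pure arithmetic, no framework code):

import math

# Kaiming-uniform bound used by torch's default Linear init, which the aligned
# paddle layers above reproduce: bound = gain * sqrt(3 / fan_in),
# gain = sqrt(2 / (1 + a^2)) with a = sqrt(5).
def kaiming_uniform_bound(fan_in, a=math.sqrt(5)):
    gain = math.sqrt(2.0 / (1.0 + a * a))
    return gain * math.sqrt(3.0 / fan_in)

print(kaiming_uniform_bound(256))  # 0.0625, i.e. 1 / sqrt(256)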
paddlespeech/s2t/modules/initializer.py  (View file @ 795eb7bd)
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np


class DefaultInitializerContext(object):
    """
...
paddlespeech/server/engine/asr/online/ctc_endpoint.py  (View file @ 795eb7bd)
...
@@ -102,8 +102,10 @@ class OnlineCTCEndpoint:
        assert self.num_frames_decoded >= self.trailing_silence_frames
        assert self.frame_shift_in_ms > 0

        decoding_something = (self.num_frames_decoded >
                              self.trailing_silence_frames) and decoding_something

        utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
        trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms
...
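The endpoint check works in milliseconds: the frames decoded so far and the trailing-silence frames are both scaled by the frame shift before the silence rules are applied. A worked example with illustrative numbers (simplified: in the real code `decoding_something` is also ANDed with whether any non-blank token has been emitted, and the thresholds live elsewhere in the class):

# Illustrative values: 10 ms frame shift, 240 frames decoded, of which the
# last 80 were classified as blank/silence by the CTC decoder.
frame_shift_in_ms = 10
num_frames_decoded = 240
trailing_silence_frames = 80

decoding_something = num_frames_decoded > trailing_silence_frames  # True
utterance_length = num_frames_decoded * frame_shift_in_ms           # 2400 ms
trailing_silence = trailing_silence_frames * frame_shift_in_ms      # 800 ms
print(decoding_something, utterance_length, trailing_silence)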
paddlespeech/server/engine/asr/online/onnx/asr_engine.py  (View file @ 795eb7bd)
...
@@ -21,12 +21,12 @@ import paddle
from numpy import float32
from yacs.config import CfgNode

from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils import onnx_infer
(The Transformation import is re-sorted upward in the paddlespeech import group.)
...
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py  (View file @ 795eb7bd)
...
@@ -21,10 +21,10 @@ import paddle
from numpy import float32
from yacs.config import CfgNode

from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
(Same re-sort: the Transformation import moves above the s2t imports.)
...
paddlespeech/server/engine/asr/python/asr_engine.py  (View file @ 795eb7bd)
...
@@ -66,12 +66,14 @@ class ASREngine(BaseEngine):
            )
            logger.error(e)
            return False

        self.executor._init_from_path(
            model_type=self.config.model,
            lang=self.config.lang,
            sample_rate=self.config.sample_rate,
            cfg_path=self.config.cfg_path,
            decode_method=self.config.decode_method,
            ckpt_path=self.config.ckpt_path)

        logger.info("Initialize ASR server engine successfully on device: %s."
                    % (self.device))
...
paddlespeech/t2s/datasets/sampler.py  (View file @ 795eb7bd)
import math

import numpy as np
import paddle
from paddle.io import BatchSampler
(Import order adjusted by the formatting hooks.)


class ErnieSATSampler(BatchSampler):
    """Sampler that restricts data loading to a subset of the dataset.
    In such case, each process can pass a DistributedBatchSampler instance
...
@@ -110,8 +111,8 @@ class ErnieSATSampler(BatchSampler):
                subsampled_indices.extend(indices[i:i + self.batch_size])

            indices = indices[len(indices) - last_batch_size:]
            subsampled_indices.extend(
                indices[self.local_rank * last_local_batch_size:(
                    self.local_rank + 1) * last_local_batch_size])
            return subsampled_indices
...
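The tail of the method above handles the last, possibly short batch: once the full batches are dealt out, the remaining indices are sliced per rank so every process still receives the same number of samples. A worked example of that slice with made-up sizes:

# Illustrative only: 2 ranks, batch_size 4, and 6 leftover samples, so each
# rank takes a "last local batch" of 3 from the tail of the index list.
indices = list(range(22))          # all sample indices after shuffling
last_batch_size = 6                # 22 - 2 full global batches of 8
last_local_batch_size = 3          # last_batch_size // nranks

tail = indices[len(indices) - last_batch_size:]   # [16, 17, 18, 19, 20, 21]
for local_rank in range(2):
    part = tail[local_rank * last_local_batch_size:(local_rank + 1) *
                last_local_batch_size]
    print(local_rank, part)        # rank 0 -> [16, 17, 18], rank 1 -> [19, 20, 21]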
paddlespeech/t2s/exps/ernie_sat/train.py  (View file @ 795eb7bd)
...
@@ -25,7 +25,6 @@ from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader
from paddle.optimizer import Adam
from yacs.config import CfgNode
(`from paddle.io import DistributedBatchSampler` is removed in this hunk.)
...
paddlespeech/t2s/exps/ernie_sat/utils.py  (View file @ 795eb7bd)
...
@@ -11,32 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union

import numpy as np
import paddle
import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
(Previously `import os` and `import hashlib` appeared further down; the formatter groups the stdlib imports first.)


def _get_user():
    return os.path.expanduser('~').split('/')[-1]


def str2md5(string):
    md5_val = hashlib.md5(string.encode('utf8')).hexdigest()
    return md5_val


def get_tmp_name(text: str):
    return _get_user() + '_' + str(os.getpid()) + '_' + str2md5(text)


def get_dict(dictfile: str):
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
...
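get_tmp_name combines the current user, the process id and an MD5 digest of the text, which gives collision-resistant names for per-request temporary files. A quick self-contained sketch of the same recipe (a standalone copy of the two helpers above, not an import of the module):

import hashlib
import os

def str2md5(string):
    return hashlib.md5(string.encode('utf8')).hexdigest()

def get_tmp_name(text: str) -> str:
    user = os.path.expanduser('~').split('/')[-1]
    return user + '_' + str(os.getpid()) + '_' + str2md5(text)

# Prints something like "alice_4242_<32-hex-digit-md5>"; user and pid vary per run.
print(get_tmp_name("hello world"))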
paddlespeech/t2s/exps/syn_utils.py  (View file @ 795eb7bd)
...
@@ -298,8 +298,8 @@ def am_to_static(am_inference,
    am_name = am[:am.rindex('_')]
    am_dataset = am[am.rindex('_') + 1:]
    if am_name == 'fastspeech2':
        if am_dataset in {"aishell3", "vctk",
                          "mix"} and speaker_dict is not None:
            am_inference = jit.to_static(
                am_inference,
                input_spec=[
...
@@ -311,8 +311,8 @@ def am_to_static(am_inference,
                am_inference,
                input_spec=[InputSpec([-1], dtype=paddle.int64)])
    elif am_name == 'speedyspeech':
        if am_dataset in {"aishell3", "vctk",
                          "mix"} and speaker_dict is not None:
            am_inference = jit.to_static(
                am_inference,
                input_spec=[
...
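The extra branch for the multi-speaker datasets exists because those FastSpeech2/SpeedySpeech checkpoints take a speaker id alongside the phone ids, so the static-graph export needs one more InputSpec. A hedged sketch of what such an export call can look like; the exact spec list is elided by the diff context above, so the two-spec form and shapes below are assumptions for illustration only:

# Sketch only: multi-speaker export is assumed to declare a variable-length
# int64 phone-id vector plus a single int64 speaker id. Single-speaker export
# keeps just the phone-id spec, as in the hunk above.
import paddle
from paddle import jit
from paddle.static import InputSpec

def export_am(am_inference, multi_speaker: bool):
    if multi_speaker:
        input_spec = [
            InputSpec([-1], dtype=paddle.int64),  # phone ids
            InputSpec([1], dtype=paddle.int64),   # speaker id (assumed shape)
        ]
    else:
        input_spec = [InputSpec([-1], dtype=paddle.int64)]
    return jit.to_static(am_inference, input_spec=input_spec)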
paddlespeech/t2s/frontend/g2pw/__init__.py  (View file @ 795eb7bd)
from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter
paddlespeech/t2s/frontend/mix_frontend.py  (View file @ 795eb7bd)
...
@@ -61,8 +61,11 @@ class MixFrontend():
            return False

    def is_end(self, before_char, after_char) -> bool:
        flag = 0
        for char in (before_char, after_char):
            if self.is_alphabet(char) or char == " ":
                flag += 1
        if flag == 2:
            return True
        else:
            return False
(Previously written as a single compound condition:
    if ((self.is_alphabet(before_char) or before_char == " ") and
            (self.is_alphabet(after_char) or after_char == " ")):
        return True
    else:
        return False
the behaviour is unchanged.)
...
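The rewritten is_end simply counts how many of the two neighbouring characters are ASCII letters or spaces and reports an English-segment boundary only when both are. A standalone toy version with example calls; is_alphabet is reduced to str.isascii/isalpha here for the demo, the real method lives on MixFrontend and uses its own helper:

def is_alphabet(ch: str) -> bool:
    return ch.isascii() and ch.isalpha()

def is_end(before_char: str, after_char: str) -> bool:
    flag = 0
    for char in (before_char, after_char):
        if is_alphabet(char) or char == " ":
            flag += 1
    return flag == 2

print(is_end("o", " "))   # True  - both neighbours belong to an English span
print(is_end("o", "你"))  # False - the right neighbour is a Chinese character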
paddlespeech/t2s/training/updaters/standard_updater.py  (View file @ 795eb7bd)
...
@@ -24,10 +24,11 @@ from paddle.nn import Layer
from paddle.optimizer import Optimizer
from timer import timer

from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updater import UpdaterBase
from paddlespeech.t2s.training.updater import UpdaterState
(The ErnieSATSampler import is re-sorted to the top of the paddlespeech group.)


class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but
...
setup.py  (View file @ 795eb7bd)
...
@@ -77,12 +77,7 @@ base = [
    "pybind11",
]

server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"]

requirements = {
    "install":
...
@@ -330,4 +325,4 @@ setup_info = dict(
    })

with version_info():
    setup(**setup_info, include_package_data=True)
speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py  (View file @ 795eb7bd)
...
@@ -490,18 +490,10 @@ class SymbolicShapeInference:
    def _onnx_infer_single_node(self, node):
        # skip onnx shape inference for some ops, as they are handled in _infer_*
        skip_infer = node.op_type in [
            'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \
            'Attention',  # contrib ops
            'BiasGelu', 'EmbedLayerNormalization', 'FastGelu', 'Gelu',
            'LayerNormalization', 'LongformerAttention',
            'SkipLayerNormalization', 'PythonOp'
        ]
        (The same entries were previously spread over more backslash-continued lines; the list itself is unchanged.)

        if not skip_infer:
...
@@ -514,8 +506,8 @@ class SymbolicShapeInference:
        if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']:
            initializers = [
                self.initializers_[name] for name in node.input
                if (name in self.initializers_ and
                    name not in self.graph_inputs_)
            ]
        # run single node inference with self.known_vi_ shapes
...
@@ -601,8 +593,8 @@ class SymbolicShapeInference:
            for o in symbolic_shape_inference.out_mp_.graph.output
        ]
        subgraph_new_symbolic_dims = set([
            d for s in subgraph_shapes if s for d in s
            if type(d) == str and not d in self.symbolic_dims_
        ])
        new_dims = {}
        for d in subgraph_new_symbolic_dims:
...
@@ -729,8 +721,9 @@ class SymbolicShapeInference:
                    for d, s in zip(sympy_shape[-rank:], strides)
                ]
                total_pads = [
                    max(0, (k - s) if r == 0 else (k - r))
                    for k, s, r in zip(effective_kernel_shape, strides, residual)
                ]
            except TypeError:  # sympy may throw TypeError: cannot determine truth value of Relational
                total_pads = [
...
@@ -1276,8 +1269,9 @@ class SymbolicShapeInference:
        if pads is not None:
            assert len(pads) == 2 * rank
            new_sympy_shape = [
                d + pad_up + pad_down
                for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:])
            ]
            self._update_computed_dims(new_sympy_shape)
        else:
...
@@ -1590,8 +1584,8 @@ class SymbolicShapeInference:
            scales = list(scales)
            new_sympy_shape = [
                sympy.simplify(sympy.floor(d * (end - start) * scale))
                for d, start, end, scale in zip(input_sympy_shape, roi_start,
                                                roi_end, scales)
            ]
            self._update_computed_dims(new_sympy_shape)
        else:
...
@@ -2204,8 +2198,9 @@ class SymbolicShapeInference:
        # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate
        sorted_nodes = []
        sorted_known_vi = set([
            i.name
            for i in list(self.out_mp_.graph.input) +
            list(self.out_mp_.graph.initializer)
        ])
        if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
            # Loop/Scan will have some graph output in graph inputs, so don't do topological sort
...
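The total_pads comprehension above computes SAME-style padding per spatial axis: when the residual for an axis is zero the pad is kernel minus stride, otherwise kernel minus residual, floored at zero (the residual expression itself is outside the hunk shown, but it is roughly the input length modulo the stride). A worked numeric example of the same expression:

# Illustrative numbers only: effective kernel 3, stride 2, over two axes whose
# residuals are 0 and 1 respectively.
effective_kernel_shape = [3, 3]
strides = [2, 2]
residual = [0, 1]

total_pads = [
    max(0, (k - s) if r == 0 else (k - r))
    for k, s, r in zip(effective_kernel_shape, strides, residual)
]
print(total_pads)  # [1, 2]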