Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
8828210f
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8828210f
编写于
12月 02, 2021
作者:
K
KP
提交者:
GitHub
12月 02, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3 from Jackwaterveg/cli_infer
revise the sample rate
上级
a19e51d7
a9d206c1
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
29 addition
and
34 deletion
+29
-34
paddlespeech/cli/asr/infer.py
paddlespeech/cli/asr/infer.py
+29
-34
未找到文件。
paddlespeech/cli/asr/infer.py
浏览文件 @
8828210f
...
@@ -22,6 +22,7 @@ import librosa
...
@@ -22,6 +22,7 @@ import librosa
import
paddle
import
paddle
import
soundfile
import
soundfile
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
import
numpy
as
np
from
..executor
import
BaseExecutor
from
..executor
import
BaseExecutor
from
..utils
import
cli_register
from
..utils
import
cli_register
...
@@ -78,9 +79,10 @@ class ASRExecutor(BaseExecutor):
...
@@ -78,9 +79,10 @@ class ASRExecutor(BaseExecutor):
default
=
'zh'
,
default
=
'zh'
,
help
=
'Choose model language. zh or en'
)
help
=
'Choose model language. zh or en'
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
"--
model_sample_rate
"
,
"--
sr
"
,
type
=
int
,
type
=
int
,
default
=
16000
,
default
=
16000
,
choices
=
[
8000
,
16000
],
help
=
'Choose the audio sample rate of the model. 8000 or 16000'
)
help
=
'Choose the audio sample rate of the model. 8000 or 16000'
)
self
.
parser
.
add_argument
(
self
.
parser
.
add_argument
(
'--config'
,
'--config'
,
...
@@ -117,26 +119,27 @@ class ASRExecutor(BaseExecutor):
...
@@ -117,26 +119,27 @@ class ASRExecutor(BaseExecutor):
def
_init_from_path
(
self
,
def
_init_from_path
(
self
,
model_type
:
str
=
'wenetspeech'
,
model_type
:
str
=
'wenetspeech'
,
lang
:
str
=
'zh'
,
lang
:
str
=
'zh'
,
model_
sample_rate
:
int
=
16000
,
sample_rate
:
int
=
16000
,
cfg_path
:
Optional
[
os
.
PathLike
]
=
None
,
cfg_path
:
Optional
[
os
.
PathLike
]
=
None
,
ckpt_path
:
Optional
[
os
.
PathLike
]
=
None
):
ckpt_path
:
Optional
[
os
.
PathLike
]
=
None
):
"""
"""
Init model and other resources from a specific path.
Init model and other resources from a specific path.
"""
"""
if
cfg_path
is
None
or
ckpt_path
is
None
:
if
cfg_path
is
None
or
ckpt_path
is
None
:
model_sample_rate_str
=
'16k'
if
model_
sample_rate
==
16000
else
'8k'
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'_'
+
lang
+
'_'
+
model_
sample_rate_str
tag
=
model_type
+
'_'
+
lang
+
'_'
+
sample_rate_str
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
self
.
cfg_path
=
os
.
path
.
join
(
res_path
,
self
.
cfg_path
=
os
.
path
.
join
(
res_path
,
pretrained_models
[
tag
][
'cfg_path'
])
pretrained_models
[
tag
][
'cfg_path'
])
self
.
ckpt_path
=
os
.
path
.
join
(
res_path
,
self
.
ckpt_path
=
os
.
path
.
join
(
res_path
,
pretrained_models
[
tag
][
'ckpt_path'
])
pretrained_models
[
tag
][
'ckpt_path'
]
+
".pdparams"
)
logger
.
info
(
res_path
)
logger
.
info
(
res_path
)
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
ckpt_path
)
logger
.
info
(
self
.
ckpt_path
)
else
:
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
ckpt_path
=
os
.
path
.
abspath
(
ckpt_path
)
self
.
ckpt_path
=
os
.
path
.
abspath
(
ckpt_path
+
".pdparams"
)
res_path
=
os
.
path
.
dirname
(
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
...
@@ -182,8 +185,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -182,8 +185,7 @@ class ASRExecutor(BaseExecutor):
self
.
model
.
eval
()
self
.
model
.
eval
()
# load model
# load model
params_path
=
self
.
ckpt_path
+
".pdparams"
model_dict
=
paddle
.
load
(
self
.
ckpt_path
)
model_dict
=
paddle
.
load
(
params_path
)
self
.
model
.
set_state_dict
(
model_dict
)
self
.
model
.
set_state_dict
(
model_dict
)
def
preprocess
(
self
,
model_type
:
str
,
input
:
Union
[
str
,
os
.
PathLike
]):
def
preprocess
(
self
,
model_type
:
str
,
input
:
Union
[
str
,
os
.
PathLike
]):
...
@@ -195,8 +197,6 @@ class ASRExecutor(BaseExecutor):
...
@@ -195,8 +197,6 @@ class ASRExecutor(BaseExecutor):
audio_file
=
input
audio_file
=
input
logger
.
info
(
"Preprocess audio_file:"
+
audio_file
)
logger
.
info
(
"Preprocess audio_file:"
+
audio_file
)
config_target_sample_rate
=
self
.
config
.
collator
.
target_sample_rate
# Get the object for feature extraction
# Get the object for feature extraction
if
model_type
==
"ds2_online"
or
model_type
==
"ds2_offline"
:
if
model_type
==
"ds2_online"
or
model_type
==
"ds2_offline"
:
audio
,
_
=
self
.
collate_fn_test
.
process_utterance
(
audio
,
_
=
self
.
collate_fn_test
.
process_utterance
(
...
@@ -220,7 +220,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -220,7 +220,7 @@ class ASRExecutor(BaseExecutor):
preprocess_args
=
{
"train"
:
False
}
preprocess_args
=
{
"train"
:
False
}
preprocessing
=
Transformation
(
preprocess_conf
)
preprocessing
=
Transformation
(
preprocess_conf
)
logger
.
info
(
"read the audio file"
)
logger
.
info
(
"read the audio file"
)
audio
,
sample_rate
=
soundfile
.
read
(
audio
,
audio_
sample_rate
=
soundfile
.
read
(
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
if
self
.
change_format
:
if
self
.
change_format
:
...
@@ -229,17 +229,13 @@ class ASRExecutor(BaseExecutor):
...
@@ -229,17 +229,13 @@ class ASRExecutor(BaseExecutor):
else
:
else
:
audio
=
audio
[:,
0
]
audio
=
audio
[:,
0
]
audio
=
audio
.
astype
(
"float32"
)
audio
=
audio
.
astype
(
"float32"
)
audio
=
librosa
.
resample
(
audio
,
sample_rate
,
audio
=
librosa
.
resample
(
audio
,
audio_
sample_rate
,
self
.
target_
sample_rate
)
self
.
sample_rate
)
sample_rate
=
self
.
target_
sample_rate
audio_sample_rate
=
self
.
sample_rate
audio
=
audio
.
astype
(
"int16"
)
audio
=
np
.
round
(
audio
)
.
astype
(
"int16"
)
else
:
else
:
audio
=
audio
[:,
0
]
audio
=
audio
[:,
0
]
if
sample_rate
!=
config_target_sample_rate
:
logger
.
error
(
f
"sample rate error:
{
sample_rate
}
, need
{
self
.
sr
}
"
)
sys
.
exit
(
-
1
)
logger
.
info
(
f
"audio shape:
{
audio
.
shape
}
"
)
logger
.
info
(
f
"audio shape:
{
audio
.
shape
}
"
)
# fbank
# fbank
audio
=
preprocessing
(
audio
,
**
preprocess_args
)
audio
=
preprocessing
(
audio
,
**
preprocess_args
)
...
@@ -311,11 +307,11 @@ class ASRExecutor(BaseExecutor):
...
@@ -311,11 +307,11 @@ class ASRExecutor(BaseExecutor):
"""
"""
return
self
.
_outputs
[
"result"
]
return
self
.
_outputs
[
"result"
]
def
_check
(
self
,
audio_file
:
str
,
model_
sample_rate
:
int
):
def
_check
(
self
,
audio_file
:
str
,
sample_rate
:
int
):
self
.
target_sample_rate
=
model_
sample_rate
self
.
sample_rate
=
sample_rate
if
self
.
target_sample_rate
!=
16000
and
self
.
target_
sample_rate
!=
8000
:
if
self
.
sample_rate
!=
16000
and
self
.
sample_rate
!=
8000
:
logger
.
error
(
logger
.
error
(
"please input --
model_sample_rate 8000 or --model_sample_rate
16000"
"please input --
sr 8000 or --sr
16000"
)
)
raise
Exception
(
"invalid sample rate"
)
raise
Exception
(
"invalid sample rate"
)
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
...
@@ -326,7 +322,7 @@ class ASRExecutor(BaseExecutor):
...
@@ -326,7 +322,7 @@ class ASRExecutor(BaseExecutor):
logger
.
info
(
"checking the audio file format......"
)
logger
.
info
(
"checking the audio file format......"
)
try
:
try
:
sig
,
sample_rate
=
soundfile
.
read
(
audio
,
audio_
sample_rate
=
soundfile
.
read
(
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
str
(
e
))
logger
.
error
(
str
(
e
))
...
@@ -340,15 +336,15 @@ class ASRExecutor(BaseExecutor):
...
@@ -340,15 +336,15 @@ class ASRExecutor(BaseExecutor):
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav
\n
\
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav
\n
\
"
)
"
)
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
logger
.
info
(
"The sample rate is %d"
%
sample_rate
)
logger
.
info
(
"The sample rate is %d"
%
audio_
sample_rate
)
if
sample_rate
!=
self
.
target_
sample_rate
:
if
audio_sample_rate
!=
self
.
sample_rate
:
logger
.
warning
(
logger
.
warning
(
"The sample rate of the input file is not {}.
\n
\
"The sample rate of the input file is not {}.
\n
\
The program will resample the wav file to {}.
\n
\
The program will resample the wav file to {}.
\n
\
If the result does not meet your expectations,
\n
\
If the result does not meet your expectations,
\n
\
Please input the 16k 16bit 1 channel wav file.
\
Please input the 16k 16bit 1 channel wav file.
\
"
"
.
format
(
self
.
target_sample_rate
,
self
.
target_
sample_rate
))
.
format
(
self
.
sample_rate
,
self
.
sample_rate
))
while
(
True
):
while
(
True
):
logger
.
info
(
logger
.
info
(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
...
@@ -379,14 +375,14 @@ class ASRExecutor(BaseExecutor):
...
@@ -379,14 +375,14 @@ class ASRExecutor(BaseExecutor):
model
=
parser_args
.
model
model
=
parser_args
.
model
lang
=
parser_args
.
lang
lang
=
parser_args
.
lang
model_sample_rate
=
parser_args
.
model_sample_rate
sample_rate
=
parser_args
.
sr
config
=
parser_args
.
config
config
=
parser_args
.
config
ckpt_path
=
parser_args
.
ckpt_path
ckpt_path
=
parser_args
.
ckpt_path
audio_file
=
parser_args
.
input
audio_file
=
parser_args
.
input
device
=
parser_args
.
device
device
=
parser_args
.
device
try
:
try
:
res
=
self
(
model
,
lang
,
model_
sample_rate
,
config
,
ckpt_path
,
res
=
self
(
model
,
lang
,
sample_rate
,
config
,
ckpt_path
,
audio_file
,
device
)
audio_file
,
device
)
logger
.
info
(
'ASR Result: {}'
.
format
(
res
))
logger
.
info
(
'ASR Result: {}'
.
format
(
res
))
return
True
return
True
...
@@ -394,16 +390,15 @@ class ASRExecutor(BaseExecutor):
...
@@ -394,16 +390,15 @@ class ASRExecutor(BaseExecutor):
print
(
e
)
print
(
e
)
return
False
return
False
def
__call__
(
self
,
model
,
lang
,
model_
sample_rate
,
config
,
ckpt_path
,
def
__call__
(
self
,
model
,
lang
,
sample_rate
,
config
,
ckpt_path
,
audio_file
,
device
):
audio_file
,
device
):
"""
"""
Python API to call an executor.
Python API to call an executor.
"""
"""
audio_file
=
os
.
path
.
abspath
(
audio_file
)
audio_file
=
os
.
path
.
abspath
(
audio_file
)
self
.
_check
(
audio_file
,
model_sample_rate
)
self
.
_check
(
audio_file
,
sample_rate
)
paddle
.
set_device
(
device
)
paddle
.
set_device
(
device
)
self
.
_init_from_path
(
model
,
lang
,
model_
sample_rate
,
config
,
ckpt_path
)
self
.
_init_from_path
(
model
,
lang
,
sample_rate
,
config
,
ckpt_path
)
self
.
preprocess
(
model
,
audio_file
)
self
.
preprocess
(
model
,
audio_file
)
self
.
infer
(
model
)
self
.
infer
(
model
)
res
=
self
.
postprocess
()
# Retrieve result of asr.
res
=
self
.
postprocess
()
# Retrieve result of asr.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录