Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
789471bf
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
789471bf
编写于
11月 25, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test wav for u2
上级
f598df0c
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
193 addition
and
0 deletion
+193
-0
examples/wenetspeech/asr1/local/test_wav.sh
examples/wenetspeech/asr1/local/test_wav.sh
+45
-0
paddlespeech/s2t/exps/u2/bin/test_wav.py
paddlespeech/s2t/exps/u2/bin/test_wav.py
+148
-0
未找到文件。
examples/wenetspeech/asr1/local/test_wav.sh
0 → 100755
浏览文件 @
789471bf
#!/bin/bash
if
[
$#
!=
3
]
;
then
echo
"usage:
${
0
}
config_path ckpt_path_prefix audio_file"
exit
-1
fi
ngpu
=
$(
echo
$CUDA_VISIBLE_DEVICES
|
awk
-F
","
'{print NF}'
)
echo
"using
$ngpu
gpus..."
config_path
=
$1
ckpt_prefix
=
$2
audio_file
=
$3
chunk_mode
=
false
if
[[
${
config_path
}
=
~ ^.
*
chunk_.
*
yaml
$
]]
;
then
chunk_mode
=
true
fi
# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
for
type
in
attention_rescoring
;
do
echo
"decoding
${
type
}
"
batch_size
=
1
output_dir
=
${
ckpt_prefix
}
mkdir
-p
${
output_dir
}
python3
-u
${
BIN_DIR
}
/test_wav.py
\
--nproc
${
ngpu
}
\
--config
${
config_path
}
\
--result_file
${
output_dir
}
/
${
type
}
.rsl
\
--checkpoint_path
${
ckpt_prefix
}
\
--opts
decoding.decoding_method
${
type
}
\
--opts
decoding.batch_size
${
batch_size
}
\
--audio_file
${
audio_file
}
if
[
$?
-ne
0
]
;
then
echo
"Failed in evaluation!"
exit
1
fi
done
exit
0
paddlespeech/s2t/exps/u2/bin/test_
hub
.py
→
paddlespeech/s2t/exps/u2/bin/test_
wav
.py
浏览文件 @
789471bf
...
@@ -12,125 +12,107 @@
...
@@ -12,125 +12,107 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Evaluation for U2 model."""
"""Evaluation for U2 model."""
import
cProfile
import
os
import
os
import
sys
import
sys
from
pathlib
import
Path
import
paddle
import
paddle
import
soundfile
import
soundfile
from
paddlespeech.s2t.exps.u2.config
import
get_cfg_defaults
from
paddlespeech.s2t.exps.u2.config
import
get_cfg_defaults
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.frontend.featurizer.text_featurizer
import
TextFeaturizer
from
paddlespeech.s2t.io.collator
import
SpeechCollator
from
paddlespeech.s2t.models.u2
import
U2Model
from
paddlespeech.s2t.models.u2
import
U2Model
from
paddlespeech.s2t.training.cli
import
default_argument_parser
from
paddlespeech.s2t.training.cli
import
default_argument_parser
from
paddlespeech.s2t.training.trainer
import
Trainer
from
paddlespeech.s2t.transform.transformation
import
Transformation
from
paddlespeech.s2t.utils
import
layer_tools
from
paddlespeech.s2t.utils
import
mp_tools
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.utility
import
print_arguments
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
# TODO(hui zhang): dynamic load
# TODO(hui zhang): dynamic load
class
U2
Tester_Hub
(
Trainer
):
class
U2
Infer
(
):
def
__init__
(
self
,
config
,
args
):
def
__init__
(
self
,
config
,
args
):
# super().__init__(config, args)
self
.
args
=
args
self
.
args
=
args
self
.
config
=
config
self
.
config
=
config
self
.
audio_file
=
args
.
audio_file
self
.
audio_file
=
args
.
audio_file
self
.
collate_fn_test
=
SpeechCollator
.
from_config
(
config
)
self
.
sr
=
config
.
collator
.
target_sample_rate
self
.
_text_featurizer
=
TextFeaturizer
(
self
.
preprocess_conf
=
config
.
collator
.
augmentation_config
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
config
.
collator
.
unit_type
,
unit_type
=
config
.
collator
.
unit_type
,
vocab_filepath
=
None
,
vocab_filepath
=
config
.
collator
.
vocab_filepath
,
spm_model_prefix
=
config
.
collator
.
spm_model_prefix
)
spm_model_prefix
=
config
.
collator
.
spm_model_prefix
)
def
setup_model
(
self
):
paddle
.
set_device
(
'gpu'
if
self
.
args
.
nprocs
>
0
else
'cpu'
)
config
=
self
.
config
model_conf
=
config
.
model
# model
model_conf
=
config
.
model
with
UpdateConfig
(
model_conf
):
with
UpdateConfig
(
model_conf
):
model_conf
.
input_dim
=
self
.
collate_fn_test
.
feature_size
model_conf
.
input_dim
=
config
.
collator
.
feat_dim
model_conf
.
output_dim
=
self
.
collate_fn_test
.
vocab_size
model_conf
.
output_dim
=
self
.
text_feature
.
vocab_size
model
=
U2Model
.
from_config
(
model_conf
)
model
=
U2Model
.
from_config
(
model_conf
)
if
self
.
parallel
:
model
=
paddle
.
DataParallel
(
model
)
logger
.
info
(
f
"
{
model
}
"
)
layer_tools
.
print_params
(
model
,
logger
.
info
)
self
.
model
=
model
self
.
model
=
model
logger
.
info
(
"Setup model"
)
@
mp_tools
.
rank_zero_only
@
paddle
.
no_grad
()
def
test
(
self
):
self
.
model
.
eval
()
self
.
model
.
eval
()
cfg
=
self
.
config
.
decoding
audio_file
=
self
.
audio_file
collate_fn_test
=
self
.
collate_fn_test
audio
,
_
=
collate_fn_test
.
process_utterance
(
audio_file
=
audio_file
,
transcript
=
"Hello"
)
audio_len
=
audio
.
shape
[
0
]
audio
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
)
audio_len
=
paddle
.
to_tensor
(
audio_len
)
audio
=
paddle
.
unsqueeze
(
audio
,
axis
=
0
)
vocab_list
=
collate_fn_test
.
vocab_list
text_feature
=
self
.
collate_fn_test
.
text_feature
result_transcripts
=
self
.
model
.
decode
(
audio
,
audio_len
,
text_feature
=
text_feature
,
decoding_method
=
cfg
.
decoding_method
,
lang_model_path
=
cfg
.
lang_model_path
,
beam_alpha
=
cfg
.
alpha
,
beam_beta
=
cfg
.
beta
,
beam_size
=
cfg
.
beam_size
,
cutoff_prob
=
cfg
.
cutoff_prob
,
cutoff_top_n
=
cfg
.
cutoff_top_n
,
num_processes
=
cfg
.
num_proc_bsearch
,
ctc_weight
=
cfg
.
ctc_weight
,
decoding_chunk_size
=
cfg
.
decoding_chunk_size
,
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
,
simulate_streaming
=
cfg
.
simulate_streaming
)
logger
.
info
(
"The result_transcripts: "
+
result_transcripts
[
0
][
0
])
def
run_test
(
self
):
self
.
resume
()
try
:
self
.
test
()
except
KeyboardInterrupt
:
sys
.
exit
(
-
1
)
def
setup
(
self
):
"""Setup the experiment.
"""
paddle
.
set_device
(
'gpu'
if
self
.
args
.
nprocs
>
0
else
'cpu'
)
#self.setup_output_dir()
# load model
#self.setup_checkpointer()
#self.setup_dataloader()
self
.
setup_model
()
self
.
iteration
=
0
self
.
epoch
=
0
def
resume
(
self
):
"""Resume from the checkpoint at checkpoints in the output
directory or load a specified checkpoint.
"""
params_path
=
self
.
args
.
checkpoint_path
+
".pdparams"
params_path
=
self
.
args
.
checkpoint_path
+
".pdparams"
model_dict
=
paddle
.
load
(
params_path
)
model_dict
=
paddle
.
load
(
params_path
)
self
.
model
.
set_state_dict
(
model_dict
)
self
.
model
.
set_state_dict
(
model_dict
)
def
run
(
self
):
check
(
args
.
audio_file
)
with
paddle
.
no_grad
():
# read
audio
,
sample_rate
=
soundfile
.
read
(
self
.
audio_file
,
dtype
=
"int16"
,
always_2d
=
True
)
if
sample_rate
!=
self
.
sr
:
logger
.
error
(
f
"sample rate error:
{
sample_rate
}
, need
{
self
.
sr
}
"
)
sys
.
exit
(
-
1
)
audio
=
audio
[:,
0
]
logger
.
info
(
f
"audio shape:
{
audio
.
shape
}
"
)
# fbank
feat
=
self
.
preprocessing
(
audio
,
**
self
.
preprocess_args
)
logger
.
info
(
f
"feat shape:
{
feat
.
shape
}
"
)
ilen
=
paddle
.
to_tensor
(
feat
.
shape
[
0
])
xs
=
paddle
.
to_tensor
(
feat
,
dtype
=
'float32'
).
unsqueeze
(
axis
=
0
)
cfg
=
self
.
config
.
decoding
result_transcripts
=
self
.
model
.
decode
(
xs
,
ilen
,
text_feature
=
self
.
text_feature
,
decoding_method
=
cfg
.
decoding_method
,
lang_model_path
=
cfg
.
lang_model_path
,
beam_alpha
=
cfg
.
alpha
,
beam_beta
=
cfg
.
beta
,
beam_size
=
cfg
.
beam_size
,
cutoff_prob
=
cfg
.
cutoff_prob
,
cutoff_top_n
=
cfg
.
cutoff_top_n
,
num_processes
=
cfg
.
num_proc_bsearch
,
ctc_weight
=
cfg
.
ctc_weight
,
decoding_chunk_size
=
cfg
.
decoding_chunk_size
,
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
,
simulate_streaming
=
cfg
.
simulate_streaming
)
rsl
=
result_transcripts
[
0
][
0
]
utt
=
Path
(
self
.
audio_file
).
name
logger
.
info
(
f
"hyp:
{
utt
}
{
result_transcripts
[
0
][
0
]
}
"
)
return
rsl
def
check
(
audio_file
):
def
check
(
audio_file
):
if
not
os
.
path
.
isfile
(
audio_file
):
print
(
"Please input the right audio file path"
)
sys
.
exit
(
-
1
)
logger
.
info
(
"checking the audio file format......"
)
logger
.
info
(
"checking the audio file format......"
)
try
:
try
:
sig
,
sample_rate
=
soundfile
.
read
(
audio_file
)
sig
,
sample_rate
=
soundfile
.
read
(
audio_file
)
...
@@ -144,15 +126,8 @@ def check(audio_file):
...
@@ -144,15 +126,8 @@ def check(audio_file):
logger
.
info
(
"The audio file format is right"
)
logger
.
info
(
"The audio file format is right"
)
def
main_sp
(
config
,
args
):
exp
=
U2Tester_Hub
(
config
,
args
)
with
exp
.
eval
():
exp
.
setup
()
exp
.
run_test
()
def
main
(
config
,
args
):
def
main
(
config
,
args
):
main_sp
(
config
,
args
)
U2Infer
(
config
,
args
).
run
(
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -163,25 +138,11 @@ if __name__ == "__main__":
...
@@ -163,25 +138,11 @@ if __name__ == "__main__":
parser
.
add_argument
(
parser
.
add_argument
(
"--audio_file"
,
type
=
str
,
help
=
"path of the input audio file"
)
"--audio_file"
,
type
=
str
,
help
=
"path of the input audio file"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print_arguments
(
args
,
globals
())
if
not
os
.
path
.
isfile
(
args
.
audio_file
):
print
(
"Please input the right audio file path"
)
sys
.
exit
(
-
1
)
check
(
args
.
audio_file
)
# https://yaml.org/type/float.html
config
=
get_cfg_defaults
()
config
=
get_cfg_defaults
()
if
args
.
config
:
if
args
.
config
:
config
.
merge_from_file
(
args
.
config
)
config
.
merge_from_file
(
args
.
config
)
if
args
.
opts
:
if
args
.
opts
:
config
.
merge_from_list
(
args
.
opts
)
config
.
merge_from_list
(
args
.
opts
)
config
.
freeze
()
config
.
freeze
()
print
(
config
)
main
(
config
,
args
)
if
args
.
dump_config
:
with
open
(
args
.
dump_config
,
'w'
)
as
f
:
print
(
config
,
file
=
f
)
# Setting for profiling
pr
=
cProfile
.
Profile
()
pr
.
runcall
(
main
,
config
,
args
)
pr
.
dump_stats
(
'test.profile'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录