Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
f29ae92a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
f29ae92a
编写于
2月 25, 2022
作者:
H
huangyuxin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add unit test for deepspeech2online inference
上级
a9422260
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
84 addition
and
1 deletion
+84
-1
tests/unit/asr/deepspeech2_online_model_test.py
tests/unit/asr/deepspeech2_online_model_test.py
+84
-1
tests/unit/asr/test_data/static_ds2online_inputs.pickle
tests/unit/asr/test_data/static_ds2online_inputs.pickle
+0
-0
未找到文件。
tests/unit/asr/deepspeech2_online_model_test.py
浏览文件 @
f29ae92a
...
@@ -15,9 +15,12 @@ import unittest
...
@@ -15,9 +15,12 @@ import unittest
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
pickle
import
os
from
paddle
import
inference
from
paddlespeech.s2t.models.ds2_online
import
DeepSpeech2ModelOnline
from
paddlespeech.s2t.models.ds2_online
import
DeepSpeech2ModelOnline
from
paddlespeech.s2t.models.ds2_online
import
DeepSpeech2InferModelOnline
class
TestDeepSpeech2ModelOnline
(
unittest
.
TestCase
):
class
TestDeepSpeech2ModelOnline
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -182,5 +185,85 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
...
@@ -182,5 +185,85 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
paddle
.
allclose
(
final_state_c_box
,
final_state_c_box_chk
),
True
)
paddle
.
allclose
(
final_state_c_box
,
final_state_c_box_chk
),
True
)
class
TestDeepSpeech2StaticModelOnline
(
unittest
.
TestCase
):
def
setUp
(
self
):
export_prefix
=
"exp/deepspeech2_online/checkpoints/test_export"
os
.
makedirs
(
os
.
path
.
dirname
(
export_prefix
),
mode
=
0o755
)
infer_model
=
DeepSpeech2InferModelOnline
(
feat_size
=
161
,
dict_size
=
4233
,
num_conv_layers
=
2
,
num_rnn_layers
=
5
,
rnn_size
=
1024
,
num_fc_layers
=
0
,
fc_layers_size_list
=
[
-
1
],
use_gru
=
False
)
static_model
=
infer_model
.
export
()
paddle
.
jit
.
save
(
static_model
,
export_prefix
)
with
open
(
"test_data/static_ds2online_inputs.pickle"
,
"rb"
)
as
f
:
self
.
data_dict
=
pickle
.
load
(
f
)
self
.
setup_model
(
export_prefix
)
def
setup_model
(
self
,
export_prefix
):
deepspeech_config
=
inference
.
Config
(
export_prefix
+
".pdmodel"
,
export_prefix
+
".pdiparams"
)
if
(
'CUDA_VISIBLE_DEVICES'
in
os
.
environ
.
keys
()
and
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
].
strip
()
!=
''
):
deepspeech_config
.
enable_use_gpu
(
100
,
0
)
deepspeech_config
.
enable_memory_optim
()
deepspeech_predictor
=
inference
.
create_predictor
(
deepspeech_config
)
self
.
predictor
=
deepspeech_predictor
def
test_unit
(
self
):
input_names
=
self
.
predictor
.
get_input_names
()
audio_handle
=
self
.
predictor
.
get_input_handle
(
input_names
[
0
])
audio_len_handle
=
self
.
predictor
.
get_input_handle
(
input_names
[
1
])
h_box_handle
=
self
.
predictor
.
get_input_handle
(
input_names
[
2
])
c_box_handle
=
self
.
predictor
.
get_input_handle
(
input_names
[
3
])
x_chunk
=
self
.
data_dict
[
"audio_chunk"
]
x_chunk_lens
=
self
.
data_dict
[
"audio_chunk_lens"
]
chunk_state_h_box
=
self
.
data_dict
[
"chunk_state_h_box"
]
chunk_state_c_box
=
self
.
data_dict
[
"chunk_state_c_bos"
]
audio_handle
.
reshape
(
x_chunk
.
shape
)
audio_handle
.
copy_from_cpu
(
x_chunk
)
audio_len_handle
.
reshape
(
x_chunk_lens
.
shape
)
audio_len_handle
.
copy_from_cpu
(
x_chunk_lens
)
h_box_handle
.
reshape
(
chunk_state_h_box
.
shape
)
h_box_handle
.
copy_from_cpu
(
chunk_state_h_box
)
c_box_handle
.
reshape
(
chunk_state_c_box
.
shape
)
c_box_handle
.
copy_from_cpu
(
chunk_state_c_box
)
output_names
=
self
.
predictor
.
get_output_names
()
output_handle
=
self
.
predictor
.
get_output_handle
(
output_names
[
0
])
output_lens_handle
=
self
.
predictor
.
get_output_handle
(
output_names
[
1
])
output_state_h_handle
=
self
.
predictor
.
get_output_handle
(
output_names
[
2
])
output_state_c_handle
=
self
.
predictor
.
get_output_handle
(
output_names
[
3
])
self
.
predictor
.
run
()
output_chunk_probs
=
output_handle
.
copy_to_cpu
()
output_chunk_lens
=
output_lens_handle
.
copy_to_cpu
()
chunk_state_h_box
=
output_state_h_handle
.
copy_to_cpu
()
chunk_state_c_box
=
output_state_c_handle
.
copy_to_cpu
()
return
True
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
tests/unit/asr/test_data/static_ds2online_inputs.pickle
0 → 100644
浏览文件 @
f29ae92a
文件已添加
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录