提交 f29ae92a 编写于 作者: H huangyuxin

add unit test for deepspeech2online inference

上级 a9422260
...@@ -15,9 +15,12 @@ import unittest ...@@ -15,9 +15,12 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
import pickle
import os
from paddle import inference
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline
class TestDeepSpeech2ModelOnline(unittest.TestCase): class TestDeepSpeech2ModelOnline(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -182,5 +185,85 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): ...@@ -182,5 +185,85 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
paddle.allclose(final_state_c_box, final_state_c_box_chk), True) paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
def setUp(self):
export_prefix = "exp/deepspeech2_online/checkpoints/test_export"
os.makedirs( os.path.dirname(export_prefix), mode=0o755)
infer_model = DeepSpeech2InferModelOnline(
feat_size=161,
dict_size=4233,
num_conv_layers=2,
num_rnn_layers=5,
rnn_size=1024,
num_fc_layers=0,
fc_layers_size_list=[-1],
use_gru=False)
static_model = infer_model.export()
paddle.jit.save(static_model, export_prefix)
with open("test_data/static_ds2online_inputs.pickle", "rb") as f:
self.data_dict = pickle.load(f)
self.setup_model(export_prefix)
def setup_model(self, export_prefix):
deepspeech_config = inference.Config(
export_prefix + ".pdmodel",
export_prefix + ".pdiparams")
if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
deepspeech_config.enable_use_gpu(100, 0)
deepspeech_config.enable_memory_optim()
deepspeech_predictor = inference.create_predictor(deepspeech_config)
self.predictor = deepspeech_predictor
def test_unit(self):
input_names = self.predictor.get_input_names()
audio_handle = self.predictor.get_input_handle(input_names[0])
audio_len_handle = self.predictor.get_input_handle(input_names[1])
h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3])
x_chunk = self.data_dict["audio_chunk"]
x_chunk_lens = self.data_dict["audio_chunk_lens"]
chunk_state_h_box = self.data_dict["chunk_state_h_box"]
chunk_state_c_box = self.data_dict["chunk_state_c_bos"]
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(chunk_state_h_box)
c_box_handle.reshape(chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(chunk_state_c_box)
output_names = self.predictor.get_output_names()
output_handle = self.predictor.get_output_handle(
output_names[0])
output_lens_handle = self.predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.predictor.get_output_handle(
output_names[3])
self.predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
chunk_state_h_box = output_state_h_handle.copy_to_cpu()
chunk_state_c_box = output_state_c_handle.copy_to_cpu()
return True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册