diff --git a/tests/unit/asr/deepspeech2_online_model_test.py b/tests/unit/asr/deepspeech2_online_model_test.py index f623c5acd5066795cfa1cae43c622254a5ac88e0..d26e5b1532f5d66cb27f4520d56f742052f49306 100644 --- a/tests/unit/asr/deepspeech2_online_model_test.py +++ b/tests/unit/asr/deepspeech2_online_model_test.py @@ -15,9 +15,12 @@ import unittest import numpy as np import paddle +import pickle +import os +from paddle import inference from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline - +from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline class TestDeepSpeech2ModelOnline(unittest.TestCase): def setUp(self): @@ -182,5 +185,86 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): paddle.allclose(final_state_c_box, final_state_c_box_chk), True) + + +class TestDeepSpeech2StaticModelOnline(unittest.TestCase): + + def setUp(self): + export_prefix = "exp/deepspeech2_online/checkpoints/test_export" + if not os.path.exists(os.path.dirname(export_prefix)): + os.makedirs(os.path.dirname(export_prefix), mode=0o755) + infer_model = DeepSpeech2InferModelOnline( + feat_size=161, + dict_size=4233, + num_conv_layers=2, + num_rnn_layers=5, + rnn_size=1024, + num_fc_layers=0, + fc_layers_size_list=[-1], + use_gru=False) + static_model = infer_model.export() + paddle.jit.save(static_model, export_prefix) + + with open("test_data/static_ds2online_inputs.pickle", "rb") as f: + self.data_dict = pickle.load(f) + + self.setup_model(export_prefix) + + + def setup_model(self, export_prefix): + deepspeech_config = inference.Config( + export_prefix + ".pdmodel", + export_prefix + ".pdiparams") + if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''): + deepspeech_config.enable_use_gpu(100, 0) + deepspeech_config.enable_memory_optim() + deepspeech_predictor = inference.create_predictor(deepspeech_config) + self.predictor = deepspeech_predictor + + def test_unit(self): + input_names = self.predictor.get_input_names() + audio_handle = self.predictor.get_input_handle(input_names[0]) + audio_len_handle = self.predictor.get_input_handle(input_names[1]) + h_box_handle = self.predictor.get_input_handle(input_names[2]) + c_box_handle = self.predictor.get_input_handle(input_names[3]) + + + x_chunk = self.data_dict["audio_chunk"] + x_chunk_lens = self.data_dict["audio_chunk_lens"] + chunk_state_h_box = self.data_dict["chunk_state_h_box"] + chunk_state_c_box = self.data_dict["chunk_state_c_bos"] + + audio_handle.reshape(x_chunk.shape) + audio_handle.copy_from_cpu(x_chunk) + + audio_len_handle.reshape(x_chunk_lens.shape) + audio_len_handle.copy_from_cpu(x_chunk_lens) + + h_box_handle.reshape(chunk_state_h_box.shape) + h_box_handle.copy_from_cpu(chunk_state_h_box) + + c_box_handle.reshape(chunk_state_c_box.shape) + c_box_handle.copy_from_cpu(chunk_state_c_box) + + + + output_names = self.predictor.get_output_names() + output_handle = self.predictor.get_output_handle( + output_names[0]) + output_lens_handle = self.predictor.get_output_handle( + output_names[1]) + output_state_h_handle = self.predictor.get_output_handle( + output_names[2]) + output_state_c_handle = self.predictor.get_output_handle( + output_names[3]) + self.predictor.run() + + output_chunk_probs = output_handle.copy_to_cpu() + output_chunk_lens = output_lens_handle.copy_to_cpu() + chunk_state_h_box = output_state_h_handle.copy_to_cpu() + chunk_state_c_box = output_state_c_handle.copy_to_cpu() + return True + + if __name__ == '__main__': unittest.main() diff --git a/tests/unit/asr/deepspeech2_online_model_test.sh b/tests/unit/asr/deepspeech2_online_model_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..629238fd04716a5844156898c8697e9b0e158c9f --- /dev/null +++ b/tests/unit/asr/deepspeech2_online_model_test.sh @@ -0,0 +1,3 @@ +mkdir -p ./test_data +wget -P ./test_data https://paddlespeech.bj.bcebos.com/datasets/unit_test/asr/static_ds2online_inputs.pickle +python deepspeech2_online_model_test.py