diff --git a/paddlespeech/s2t/io/sampler.py b/paddlespeech/s2t/io/sampler.py index 89752bb9fdb98faecc0ccc5b8f59ea1f09efc8b6..ac55af1236f11d175e9e7717220980cf95c7d79b 100644 --- a/paddlespeech/s2t/io/sampler.py +++ b/paddlespeech/s2t/io/sampler.py @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): """ rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) + batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert clipped is False diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index f7b05714ef6e9961a1bff79027015889815d5811..999723e5100309976c1b89cbf256ac106d8829e6 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -33,8 +33,6 @@ from paddlespeech.s2t.modules.decoder import TransformerDecoder from paddlespeech.s2t.modules.encoder import ConformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder from paddlespeech.s2t.modules.loss import LabelSmoothingLoss -from paddlespeech.s2t.modules.mask import mask_finished_preds -from paddlespeech.s2t.modules.mask import mask_finished_scores from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools @@ -291,7 +289,7 @@ class U2STBaseModel(nn.Layer): device = speech.place # Let's assume B = batch_size and N = beam_size - # 1. Encoder and init hypothesis + # 1. Encoder and init hypothesis encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, diff --git a/paddlespeech/t2s/modules/transformer/repeat.py b/paddlespeech/t2s/modules/transformer/repeat.py index 2073a78b9330201dba15b42badf77cee0caceab1..1e946adf7e469fd6c05c2a8c8d9e6f16f638524e 100644 --- a/paddlespeech/t2s/modules/transformer/repeat.py +++ b/paddlespeech/t2s/modules/transformer/repeat.py @@ -36,4 +36,4 @@ def repeat(N, fn): Returns: MultiSequential: Repeated model instance. """ - return MultiSequential(*[fn(n) for n in range(N)]) + return MultiSequential(* [fn(n) for n in range(N)]) diff --git a/tests/unit/asr/deepspeech2_online_model_test.py b/tests/unit/asr/deepspeech2_online_model_test.py index d26e5b1532f5d66cb27f4520d56f742052f49306..f23c49263ec033280dc9b1ed0ad1b74b68d714c1 100644 --- a/tests/unit/asr/deepspeech2_online_model_test.py +++ b/tests/unit/asr/deepspeech2_online_model_test.py @@ -11,16 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import pickle import unittest import numpy as np import paddle -import pickle -import os from paddle import inference -from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline +from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline + class TestDeepSpeech2ModelOnline(unittest.TestCase): def setUp(self): @@ -185,15 +186,12 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): paddle.allclose(final_state_c_box, final_state_c_box_chk), True) - - class TestDeepSpeech2StaticModelOnline(unittest.TestCase): - def setUp(self): export_prefix = "exp/deepspeech2_online/checkpoints/test_export" if not os.path.exists(os.path.dirname(export_prefix)): os.makedirs(os.path.dirname(export_prefix), mode=0o755) - infer_model = DeepSpeech2InferModelOnline( + infer_model = DeepSpeech2InferModelOnline( feat_size=161, dict_size=4233, num_conv_layers=2, @@ -207,27 +205,25 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase): with open("test_data/static_ds2online_inputs.pickle", "rb") as f: self.data_dict = pickle.load(f) - + self.setup_model(export_prefix) - def setup_model(self, export_prefix): - deepspeech_config = inference.Config( - export_prefix + ".pdmodel", - export_prefix + ".pdiparams") - if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''): + deepspeech_config = inference.Config(export_prefix + ".pdmodel", + export_prefix + ".pdiparams") + if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and + os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''): deepspeech_config.enable_use_gpu(100, 0) deepspeech_config.enable_memory_optim() deepspeech_predictor = inference.create_predictor(deepspeech_config) self.predictor = deepspeech_predictor - + def test_unit(self): input_names = self.predictor.get_input_names() audio_handle = self.predictor.get_input_handle(input_names[0]) audio_len_handle = self.predictor.get_input_handle(input_names[1]) h_box_handle = self.predictor.get_input_handle(input_names[2]) c_box_handle = self.predictor.get_input_handle(input_names[3]) - x_chunk = self.data_dict["audio_chunk"] x_chunk_lens = self.data_dict["audio_chunk_lens"] @@ -246,13 +242,9 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase): c_box_handle.reshape(chunk_state_c_box.shape) c_box_handle.copy_from_cpu(chunk_state_c_box) - - output_names = self.predictor.get_output_names() - output_handle = self.predictor.get_output_handle( - output_names[0]) - output_lens_handle = self.predictor.get_output_handle( - output_names[1]) + output_handle = self.predictor.get_output_handle(output_names[0]) + output_lens_handle = self.predictor.get_output_handle(output_names[1]) output_state_h_handle = self.predictor.get_output_handle( output_names[2]) output_state_c_handle = self.predictor.get_output_handle( @@ -264,7 +256,7 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase): chunk_state_h_box = output_state_h_handle.copy_to_cpu() chunk_state_c_box = output_state_c_handle.copy_to_cpu() return True - + if __name__ == '__main__': unittest.main()