From 0ea39f837b2a8adfdd68ff09df5e87caebcc583c Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 17 May 2022 01:00:46 +0800 Subject: [PATCH] add asr time limt configuration, test=doc --- paddlespeech/cli/asr/infer.py | 14 ++- paddlespeech/s2t/modules/decoder.py | 33 +++---- paddlespeech/s2t/modules/embedding.py | 9 +- paddlespeech/s2t/modules/encoder.py | 96 ++++++++++--------- .../server/engine/asr/python/asr_engine.py | 35 ++++--- 5 files changed, 104 insertions(+), 83 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 23029cfb..0569653a 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -187,6 +187,13 @@ class ASRExecutor(BaseExecutor): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.config.decode.decoding_method = decode_method + self.max_len = 5000 + if self.config.encoder_conf.get("max_len", None): + self.max_len = self.config.encoder_conf.max_len + + logger.info(f"max len: {self.max_len}") + # we assumen that the subsample rate is 4 and every frame step is 40ms + self.max_len = 40 * self.max_len / 1000 else: raise Exception("wrong type") model_name = model_type[:model_type.rindex( @@ -352,9 +359,10 @@ class ASRExecutor(BaseExecutor): audio, audio_sample_rate = soundfile.read( audio_file, dtype="int16", always_2d=True) audio_duration = audio.shape[0] / audio_sample_rate - max_duration = 50.0 - if audio_duration >= max_duration: - logger.error("Please input audio file less then 50 seconds.\n") + if audio_duration > self.max_len: + logger.error( + f"Please input audio file less then {self.max_len} seconds.\n" + ) return False except Exception as e: logger.exception(e) diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index 3a851ec6..42ac119b 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -62,21 +62,21 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): False: x -> x + att(x) """ - def __init__( - self, - vocab_size: int, - encoder_output_size: int, - attention_heads: int=4, - linear_units: int=2048, - num_blocks: int=6, - dropout_rate: float=0.1, - positional_dropout_rate: float=0.1, - self_attention_dropout_rate: float=0.0, - src_attention_dropout_rate: float=0.0, - input_layer: str="embed", - use_output_layer: bool=True, - normalize_before: bool=True, - concat_after: bool=False, ): + def __init__(self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int=4, + linear_units: int=2048, + num_blocks: int=6, + dropout_rate: float=0.1, + positional_dropout_rate: float=0.1, + self_attention_dropout_rate: float=0.0, + src_attention_dropout_rate: float=0.0, + input_layer: str="embed", + use_output_layer: bool=True, + normalize_before: bool=True, + concat_after: bool=False, + max_len: int=5000): assert check_argument_types() @@ -87,7 +87,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): if input_layer == "embed": self.embed = nn.Sequential( Embedding(vocab_size, attention_dim), - PositionalEncoding(attention_dim, positional_dropout_rate), ) + PositionalEncoding( + attention_dim, positional_dropout_rate, max_len=max_len), ) else: raise ValueError(f"only 'embed' is supported: {input_layer}") diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py index 5d4e9175..596f61b7 100644 --- a/paddlespeech/s2t/modules/embedding.py +++ b/paddlespeech/s2t/modules/embedding.py @@ -112,7 +112,9 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface): paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...) """ T = x.shape[1] - assert offset + x.shape[1] < self.max_len + assert offset + x.shape[ + 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format( + offset, x.shape[1], self.max_len) #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + T] x = x * self.xscale + pos_emb @@ -148,6 +150,7 @@ class RelPositionalEncoding(PositionalEncoding): max_len (int, optional): [Maximum input length.]. Defaults to 5000. """ super().__init__(d_model, dropout_rate, max_len, reverse=True) + logger.info(f"max len: {max_len}") def forward(self, x: paddle.Tensor, offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]: @@ -158,7 +161,9 @@ class RelPositionalEncoding(PositionalEncoding): paddle.Tensor: Encoded tensor (batch, time, `*`). paddle.Tensor: Positional embedding tensor (1, time, `*`). """ - assert offset + x.shape[1] < self.max_len + assert offset + x.shape[ + 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format( + offset, x.shape[1], self.max_len) x = x * self.xscale #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + x.shape[1]] diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index c843c0e2..8266a2bc 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -47,24 +47,24 @@ __all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder"] class BaseEncoder(nn.Layer): - def __init__( - self, - input_size: int, - output_size: int=256, - attention_heads: int=4, - linear_units: int=2048, - num_blocks: int=6, - dropout_rate: float=0.1, - positional_dropout_rate: float=0.1, - attention_dropout_rate: float=0.0, - input_layer: str="conv2d", - pos_enc_layer_type: str="abs_pos", - normalize_before: bool=True, - concat_after: bool=False, - static_chunk_size: int=0, - use_dynamic_chunk: bool=False, - global_cmvn: paddle.nn.Layer=None, - use_dynamic_left_chunk: bool=False, ): + def __init__(self, + input_size: int, + output_size: int=256, + attention_heads: int=4, + linear_units: int=2048, + num_blocks: int=6, + dropout_rate: float=0.1, + positional_dropout_rate: float=0.1, + attention_dropout_rate: float=0.0, + input_layer: str="conv2d", + pos_enc_layer_type: str="abs_pos", + normalize_before: bool=True, + concat_after: bool=False, + static_chunk_size: int=0, + use_dynamic_chunk: bool=False, + global_cmvn: paddle.nn.Layer=None, + use_dynamic_left_chunk: bool=False, + max_len: int=5000): """ Args: input_size (int): input dim, d_feature @@ -127,7 +127,9 @@ class BaseEncoder(nn.Layer): odim=output_size, dropout_rate=dropout_rate, pos_enc_class=pos_enc_class( - d_model=output_size, dropout_rate=positional_dropout_rate), ) + d_model=output_size, + dropout_rate=positional_dropout_rate, + max_len=max_len), ) self.normalize_before = normalize_before self.after_norm = LayerNorm(output_size, epsilon=1e-12) @@ -330,7 +332,7 @@ class BaseEncoder(nn.Layer): # fake mask, just for jit script and compatibility with `forward` api masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) masks = masks.unsqueeze(1) - return ys, masks + return ys, masks, offset class TransformerEncoder(BaseEncoder): @@ -415,32 +417,32 @@ class TransformerEncoder(BaseEncoder): class ConformerEncoder(BaseEncoder): """Conformer encoder module.""" - def __init__( - self, - input_size: int, - output_size: int=256, - attention_heads: int=4, - linear_units: int=2048, - num_blocks: int=6, - dropout_rate: float=0.1, - positional_dropout_rate: float=0.1, - attention_dropout_rate: float=0.0, - input_layer: str="conv2d", - pos_enc_layer_type: str="rel_pos", - normalize_before: bool=True, - concat_after: bool=False, - static_chunk_size: int=0, - use_dynamic_chunk: bool=False, - global_cmvn: nn.Layer=None, - use_dynamic_left_chunk: bool=False, - positionwise_conv_kernel_size: int=1, - macaron_style: bool=True, - selfattention_layer_type: str="rel_selfattn", - activation_type: str="swish", - use_cnn_module: bool=True, - cnn_module_kernel: int=15, - causal: bool=False, - cnn_module_norm: str="batch_norm", ): + def __init__(self, + input_size: int, + output_size: int=256, + attention_heads: int=4, + linear_units: int=2048, + num_blocks: int=6, + dropout_rate: float=0.1, + positional_dropout_rate: float=0.1, + attention_dropout_rate: float=0.0, + input_layer: str="conv2d", + pos_enc_layer_type: str="rel_pos", + normalize_before: bool=True, + concat_after: bool=False, + static_chunk_size: int=0, + use_dynamic_chunk: bool=False, + global_cmvn: nn.Layer=None, + use_dynamic_left_chunk: bool=False, + positionwise_conv_kernel_size: int=1, + macaron_style: bool=True, + selfattention_layer_type: str="rel_selfattn", + activation_type: str="swish", + use_cnn_module: bool=True, + cnn_module_kernel: int=15, + causal: bool=False, + cnn_module_norm: str="batch_norm", + max_len: int=5000): """Construct ConformerEncoder Args: input_size to use_dynamic_chunk, see in BaseEncoder @@ -464,7 +466,7 @@ class ConformerEncoder(BaseEncoder): attention_dropout_rate, input_layer, pos_enc_layer_type, normalize_before, concat_after, static_chunk_size, use_dynamic_chunk, global_cmvn, - use_dynamic_left_chunk) + use_dynamic_left_chunk, max_len) activation = get_activation(activation_type) # self-attention module definition diff --git a/paddlespeech/server/engine/asr/python/asr_engine.py b/paddlespeech/server/engine/asr/python/asr_engine.py index e76c49a7..d60a5fea 100644 --- a/paddlespeech/server/engine/asr/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/python/asr_engine.py @@ -78,21 +78,26 @@ class ASREngine(BaseEngine): Args: audio_data (bytes): base64.b64decode """ - if self.executor._check( - io.BytesIO(audio_data), self.config.sample_rate, - self.config.force_yes): - logger.info("start run asr engine") - self.executor.preprocess(self.config.model, io.BytesIO(audio_data)) - st = time.time() - self.executor.infer(self.config.model) - infer_time = time.time() - st - self.output = self.executor.postprocess() # Retrieve result of asr. - else: - logger.info("file check failed!") - self.output = None - - logger.info("inference time: {}".format(infer_time)) - logger.info("asr engine type: python") + try: + if self.executor._check( + io.BytesIO(audio_data), self.config.sample_rate, + self.config.force_yes): + logger.info("start run asr engine") + self.executor.preprocess(self.config.model, + io.BytesIO(audio_data)) + st = time.time() + self.executor.infer(self.config.model) + infer_time = time.time() - st + self.output = self.executor.postprocess( + ) # Retrieve result of asr. + else: + logger.info("file check failed!") + self.output = None + + logger.info("inference time: {}".format(infer_time)) + logger.info("asr engine type: python") + except Exception as e: + logger.info(e) def postprocess(self): """postprocess -- GitLab