From b1ef434983e874130d5ff8be3ec4f37740b2ba18 Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Tue, 17 May 2022 17:38:56 +0800
Subject: [PATCH] update the max len compute method, test=doc

---
 paddlespeech/cli/asr/infer.py       | 21 +++++++++++++++------
 paddlespeech/s2t/modules/encoder.py |  2 +-
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index 0569653a..863a933f 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -187,13 +187,7 @@ class ASRExecutor(BaseExecutor):
                 vocab=self.config.vocab_filepath,
                 spm_model_prefix=self.config.spm_model_prefix)
             self.config.decode.decoding_method = decode_method
-            self.max_len = 5000
-            if self.config.encoder_conf.get("max_len", None):
-                self.max_len = self.config.encoder_conf.max_len
-            logger.info(f"max len: {self.max_len}")
-            # we assumen that the subsample rate is 4 and every frame step is 40ms
-            self.max_len = 40 * self.max_len / 1000

         else:
             raise Exception("wrong type")
         model_name = model_type[:model_type.rindex(
@@ -208,6 +202,21 @@ class ASRExecutor(BaseExecutor):
             model_dict = paddle.load(self.ckpt_path)
             self.model.set_state_dict(model_dict)

+        # compute the max len limit
+        if "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+            # in transformer like model, we may use the subsample rate cnn network
+            subsample_rate = self.model.subsampling_rate()
+            frame_shift_ms = self.config.preprocess_config.process[0][
+                'n_shift'] / self.config.preprocess_config.process[0]['fs']
+            max_len = self.model.encoder.embed.pos_enc.max_len
+
+            if self.config.encoder_conf.get("max_len", None):
+                max_len = self.config.encoder_conf.max_len
+
+            self.max_len = frame_shift_ms * max_len * subsample_rate
+            logger.info(
+                f"The asr server limit max duration len: {self.max_len}")
+
     def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
         """
         Input preprocess and return paddle.Tensor stored in self.input.
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index 8266a2bc..669a12d6 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -332,7 +332,7 @@ class BaseEncoder(nn.Layer):
         # fake mask, just for jit script and compatibility with `forward` api
         masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
         masks = masks.unsqueeze(1)
-        return ys, masks, offset
+        return ys, masks


 class TransformerEncoder(BaseEncoder):
--
GitLab
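Note on the patch: instead of hard-coding a 5000-frame ceiling and assuming a 4x subsample rate with a 40 ms frame step, the new code derives the maximum accepted audio duration from the model itself: the feature hop length (n_shift / fs, in seconds), the positional encoding's max_len, and the encoder's subsampling_rate(). A minimal sketch of that arithmetic follows; the concrete numbers (a 160-sample hop at 16 kHz, max_len of 5000, 4x subsampling) are illustrative assumptions for the example, not values read from any shipped config.

    # A minimal sketch of the duration-limit arithmetic this patch introduces.
    # All concrete numbers below are illustrative assumptions.
    def max_audio_duration(n_shift: int, fs: int, pos_enc_max_len: int,
                           subsampling_rate: int) -> float:
        """Longest input, in seconds, that the encoder can position-encode.

        Each encoder step covers `subsampling_rate` feature frames, and each
        feature frame advances `n_shift / fs` seconds, so `max_len` encoder
        positions map back to this many seconds of audio.
        """
        frame_shift_s = n_shift / fs  # hop length in seconds ("frame_shift_ms" in the patch)
        return frame_shift_s * pos_enc_max_len * subsampling_rate

    # e.g. a 10 ms hop (160 samples at 16 kHz), max_len=5000, 4x subsampling:
    print(max_audio_duration(160, 16000, 5000, 4))  # 200.0 seconds

Under those example numbers the limit works out to the same 200 seconds the removed hard-coded path produced (40 * 5000 / 1000), but it now tracks the actual preprocess config and model architecture rather than baked-in constants.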