diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index 0569653a251135bf0f7a20381b7429deeba4ca80..863a933f2a7446b920228fe2f5fa6e0294b50d5d 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -187,13 +187,7 @@ class ASRExecutor(BaseExecutor):
                     vocab=self.config.vocab_filepath,
                     spm_model_prefix=self.config.spm_model_prefix)
                 self.config.decode.decoding_method = decode_method
-                self.max_len = 5000
-                if self.config.encoder_conf.get("max_len", None):
-                    self.max_len = self.config.encoder_conf.max_len
-                logger.info(f"max len: {self.max_len}")
-                # we assumen that the subsample rate is 4 and every frame step is 40ms
-                self.max_len = 40 * self.max_len / 1000

             else:
                 raise Exception("wrong type")
             model_name = model_type[:model_type.rindex(
@@ -208,6 +202,21 @@ class ASRExecutor(BaseExecutor):
         model_dict = paddle.load(self.ckpt_path)
         self.model.set_state_dict(model_dict)

+        # compute the max len limit
+        if "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
+            # in transformer like model, we may use the subsample rate cnn network
+            subsample_rate = self.model.subsampling_rate()
+            frame_shift_ms = self.config.preprocess_config.process[0][
+                'n_shift'] / self.config.preprocess_config.process[0]['fs']
+            max_len = self.model.encoder.embed.pos_enc.max_len
+
+            if self.config.encoder_conf.get("max_len", None):
+                max_len = self.config.encoder_conf.max_len
+
+            self.max_len = frame_shift_ms * max_len * subsample_rate
+            logger.info(
+                f"The asr server limit max duration len: {self.max_len}")
+
     def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
         """
         Input preprocess and return paddle.Tensor stored in self.input.
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index 8266a2bc5a7da016363cdc07f95639ddb37606a2..669a12d656947f0446eba3d228832964e8c1d7b0 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -332,7 +332,7 @@ class BaseEncoder(nn.Layer):
         # fake mask, just for jit script and compatibility with `forward` api
         masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
         masks = masks.unsqueeze(1)
-        return ys, masks, offset
+        return ys, masks


 class TransformerEncoder(BaseEncoder):
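
Note: the patch replaces the old hard-coded limit (40 ms per subsampled frame, positional-encoding length 5000) with a value derived from the model config and the encoder's actual subsampling rate. A minimal sketch of the arithmetic follows; the concrete values (n_shift=160, fs=16000, 4x subsampling, pos_enc max_len=5000) are assumed for illustration only and are not taken from this patch:

    # hypothetical values, normally read from preprocess_config and the encoder
    n_shift = 160            # assumed feature hop size in samples
    fs = 16000               # assumed sample rate in Hz
    subsample_rate = 4       # assumed CNN subsampling factor of the encoder
    pos_enc_max_len = 5000   # assumed positional-encoding max length

    frame_shift_s = n_shift / fs                                  # 0.01 s per feature frame
    max_duration_s = frame_shift_s * pos_enc_max_len * subsample_rate
    print(max_duration_s)    # 200.0 -> inputs longer than ~200 s exceed the limit

With these assumed values the result matches the old hard-coded 40 * 5000 / 1000 = 200 s, but the new code stays correct when the feature hop size or the encoder's subsampling factor differs from those defaults.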