提交 b1ef4349 编写于 作者: X xiongxinlei

update the max len compute method, test=doc

上级 0ea39f83
......@@ -187,13 +187,7 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method
self.max_len = 5000
if self.config.encoder_conf.get("max_len", None):
self.max_len = self.config.encoder_conf.max_len
logger.info(f"max len: {self.max_len}")
# we assumen that the subsample rate is 4 and every frame step is 40ms
self.max_len = 40 * self.max_len / 1000
else:
raise Exception("wrong type")
model_name = model_type[:model_type.rindex(
......@@ -208,6 +202,21 @@ class ASRExecutor(BaseExecutor):
model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict)
# compute the max len limit
if "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
# in transformer like model, we may use the subsample rate cnn network
subsample_rate = self.model.subsampling_rate()
frame_shift_ms = self.config.preprocess_config.process[0][
'n_shift'] / self.config.preprocess_config.process[0]['fs']
max_len = self.model.encoder.embed.pos_enc.max_len
if self.config.encoder_conf.get("max_len", None):
max_len = self.config.encoder_conf.max_len
self.max_len = frame_shift_ms * max_len * subsample_rate
logger.info(
f"The asr server limit max duration len: {self.max_len}")
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
"""
Input preprocess and return paddle.Tensor stored in self.input.
......
......@@ -332,7 +332,7 @@ class BaseEncoder(nn.Layer):
# fake mask, just for jit script and compatibility with `forward` api
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
return ys, masks, offset
return ys, masks
class TransformerEncoder(BaseEncoder):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册