提交 0ea39f83 编写于 作者: X xiongxinlei

add asr time limt configuration, test=doc

上级 bd66c7a8
...@@ -187,6 +187,13 @@ class ASRExecutor(BaseExecutor): ...@@ -187,6 +187,13 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method self.config.decode.decoding_method = decode_method
self.max_len = 5000
if self.config.encoder_conf.get("max_len", None):
self.max_len = self.config.encoder_conf.max_len
logger.info(f"max len: {self.max_len}")
# we assumen that the subsample rate is 4 and every frame step is 40ms
self.max_len = 40 * self.max_len / 1000
else: else:
raise Exception("wrong type") raise Exception("wrong type")
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
...@@ -352,9 +359,10 @@ class ASRExecutor(BaseExecutor): ...@@ -352,9 +359,10 @@ class ASRExecutor(BaseExecutor):
audio, audio_sample_rate = soundfile.read( audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True) audio_file, dtype="int16", always_2d=True)
audio_duration = audio.shape[0] / audio_sample_rate audio_duration = audio.shape[0] / audio_sample_rate
max_duration = 50.0 if audio_duration > self.max_len:
if audio_duration >= max_duration: logger.error(
logger.error("Please input audio file less then 50 seconds.\n") f"Please input audio file less then {self.max_len} seconds.\n"
)
return False return False
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
......
...@@ -62,21 +62,21 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ...@@ -62,21 +62,21 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
False: x -> x + att(x) False: x -> x + att(x)
""" """
def __init__( def __init__(self,
self, vocab_size: int,
vocab_size: int, encoder_output_size: int,
encoder_output_size: int, attention_heads: int=4,
attention_heads: int=4, linear_units: int=2048,
linear_units: int=2048, num_blocks: int=6,
num_blocks: int=6, dropout_rate: float=0.1,
dropout_rate: float=0.1, positional_dropout_rate: float=0.1,
positional_dropout_rate: float=0.1, self_attention_dropout_rate: float=0.0,
self_attention_dropout_rate: float=0.0, src_attention_dropout_rate: float=0.0,
src_attention_dropout_rate: float=0.0, input_layer: str="embed",
input_layer: str="embed", use_output_layer: bool=True,
use_output_layer: bool=True, normalize_before: bool=True,
normalize_before: bool=True, concat_after: bool=False,
concat_after: bool=False, ): max_len: int=5000):
assert check_argument_types() assert check_argument_types()
...@@ -87,7 +87,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ...@@ -87,7 +87,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
if input_layer == "embed": if input_layer == "embed":
self.embed = nn.Sequential( self.embed = nn.Sequential(
Embedding(vocab_size, attention_dim), Embedding(vocab_size, attention_dim),
PositionalEncoding(attention_dim, positional_dropout_rate), ) PositionalEncoding(
attention_dim, positional_dropout_rate, max_len=max_len), )
else: else:
raise ValueError(f"only 'embed' is supported: {input_layer}") raise ValueError(f"only 'embed' is supported: {input_layer}")
......
...@@ -112,7 +112,9 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface): ...@@ -112,7 +112,9 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...) paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
""" """
T = x.shape[1] T = x.shape[1]
assert offset + x.shape[1] < self.max_len assert offset + x.shape[
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len)
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T] pos_emb = self.pe[:, offset:offset + T]
x = x * self.xscale + pos_emb x = x * self.xscale + pos_emb
...@@ -148,6 +150,7 @@ class RelPositionalEncoding(PositionalEncoding): ...@@ -148,6 +150,7 @@ class RelPositionalEncoding(PositionalEncoding):
max_len (int, optional): [Maximum input length.]. Defaults to 5000. max_len (int, optional): [Maximum input length.]. Defaults to 5000.
""" """
super().__init__(d_model, dropout_rate, max_len, reverse=True) super().__init__(d_model, dropout_rate, max_len, reverse=True)
logger.info(f"max len: {max_len}")
def forward(self, x: paddle.Tensor, def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]: offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
...@@ -158,7 +161,9 @@ class RelPositionalEncoding(PositionalEncoding): ...@@ -158,7 +161,9 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`). paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`). paddle.Tensor: Positional embedding tensor (1, time, `*`).
""" """
assert offset + x.shape[1] < self.max_len assert offset + x.shape[
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len)
x = x * self.xscale x = x * self.xscale
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]] pos_emb = self.pe[:, offset:offset + x.shape[1]]
......
...@@ -47,24 +47,24 @@ __all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder"] ...@@ -47,24 +47,24 @@ __all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder"]
class BaseEncoder(nn.Layer): class BaseEncoder(nn.Layer):
def __init__( def __init__(self,
self, input_size: int,
input_size: int, output_size: int=256,
output_size: int=256, attention_heads: int=4,
attention_heads: int=4, linear_units: int=2048,
linear_units: int=2048, num_blocks: int=6,
num_blocks: int=6, dropout_rate: float=0.1,
dropout_rate: float=0.1, positional_dropout_rate: float=0.1,
positional_dropout_rate: float=0.1, attention_dropout_rate: float=0.0,
attention_dropout_rate: float=0.0, input_layer: str="conv2d",
input_layer: str="conv2d", pos_enc_layer_type: str="abs_pos",
pos_enc_layer_type: str="abs_pos", normalize_before: bool=True,
normalize_before: bool=True, concat_after: bool=False,
concat_after: bool=False, static_chunk_size: int=0,
static_chunk_size: int=0, use_dynamic_chunk: bool=False,
use_dynamic_chunk: bool=False, global_cmvn: paddle.nn.Layer=None,
global_cmvn: paddle.nn.Layer=None, use_dynamic_left_chunk: bool=False,
use_dynamic_left_chunk: bool=False, ): max_len: int=5000):
""" """
Args: Args:
input_size (int): input dim, d_feature input_size (int): input dim, d_feature
...@@ -127,7 +127,9 @@ class BaseEncoder(nn.Layer): ...@@ -127,7 +127,9 @@ class BaseEncoder(nn.Layer):
odim=output_size, odim=output_size,
dropout_rate=dropout_rate, dropout_rate=dropout_rate,
pos_enc_class=pos_enc_class( pos_enc_class=pos_enc_class(
d_model=output_size, dropout_rate=positional_dropout_rate), ) d_model=output_size,
dropout_rate=positional_dropout_rate,
max_len=max_len), )
self.normalize_before = normalize_before self.normalize_before = normalize_before
self.after_norm = LayerNorm(output_size, epsilon=1e-12) self.after_norm = LayerNorm(output_size, epsilon=1e-12)
...@@ -330,7 +332,7 @@ class BaseEncoder(nn.Layer): ...@@ -330,7 +332,7 @@ class BaseEncoder(nn.Layer):
# fake mask, just for jit script and compatibility with `forward` api # fake mask, just for jit script and compatibility with `forward` api
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1) masks = masks.unsqueeze(1)
return ys, masks return ys, masks, offset
class TransformerEncoder(BaseEncoder): class TransformerEncoder(BaseEncoder):
...@@ -415,32 +417,32 @@ class TransformerEncoder(BaseEncoder): ...@@ -415,32 +417,32 @@ class TransformerEncoder(BaseEncoder):
class ConformerEncoder(BaseEncoder): class ConformerEncoder(BaseEncoder):
"""Conformer encoder module.""" """Conformer encoder module."""
def __init__( def __init__(self,
self, input_size: int,
input_size: int, output_size: int=256,
output_size: int=256, attention_heads: int=4,
attention_heads: int=4, linear_units: int=2048,
linear_units: int=2048, num_blocks: int=6,
num_blocks: int=6, dropout_rate: float=0.1,
dropout_rate: float=0.1, positional_dropout_rate: float=0.1,
positional_dropout_rate: float=0.1, attention_dropout_rate: float=0.0,
attention_dropout_rate: float=0.0, input_layer: str="conv2d",
input_layer: str="conv2d", pos_enc_layer_type: str="rel_pos",
pos_enc_layer_type: str="rel_pos", normalize_before: bool=True,
normalize_before: bool=True, concat_after: bool=False,
concat_after: bool=False, static_chunk_size: int=0,
static_chunk_size: int=0, use_dynamic_chunk: bool=False,
use_dynamic_chunk: bool=False, global_cmvn: nn.Layer=None,
global_cmvn: nn.Layer=None, use_dynamic_left_chunk: bool=False,
use_dynamic_left_chunk: bool=False, positionwise_conv_kernel_size: int=1,
positionwise_conv_kernel_size: int=1, macaron_style: bool=True,
macaron_style: bool=True, selfattention_layer_type: str="rel_selfattn",
selfattention_layer_type: str="rel_selfattn", activation_type: str="swish",
activation_type: str="swish", use_cnn_module: bool=True,
use_cnn_module: bool=True, cnn_module_kernel: int=15,
cnn_module_kernel: int=15, causal: bool=False,
causal: bool=False, cnn_module_norm: str="batch_norm",
cnn_module_norm: str="batch_norm", ): max_len: int=5000):
"""Construct ConformerEncoder """Construct ConformerEncoder
Args: Args:
input_size to use_dynamic_chunk, see in BaseEncoder input_size to use_dynamic_chunk, see in BaseEncoder
...@@ -464,7 +466,7 @@ class ConformerEncoder(BaseEncoder): ...@@ -464,7 +466,7 @@ class ConformerEncoder(BaseEncoder):
attention_dropout_rate, input_layer, attention_dropout_rate, input_layer,
pos_enc_layer_type, normalize_before, concat_after, pos_enc_layer_type, normalize_before, concat_after,
static_chunk_size, use_dynamic_chunk, global_cmvn, static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk) use_dynamic_left_chunk, max_len)
activation = get_activation(activation_type) activation = get_activation(activation_type)
# self-attention module definition # self-attention module definition
......
...@@ -78,21 +78,26 @@ class ASREngine(BaseEngine): ...@@ -78,21 +78,26 @@ class ASREngine(BaseEngine):
Args: Args:
audio_data (bytes): base64.b64decode audio_data (bytes): base64.b64decode
""" """
if self.executor._check( try:
io.BytesIO(audio_data), self.config.sample_rate, if self.executor._check(
self.config.force_yes): io.BytesIO(audio_data), self.config.sample_rate,
logger.info("start run asr engine") self.config.force_yes):
self.executor.preprocess(self.config.model, io.BytesIO(audio_data)) logger.info("start run asr engine")
st = time.time() self.executor.preprocess(self.config.model,
self.executor.infer(self.config.model) io.BytesIO(audio_data))
infer_time = time.time() - st st = time.time()
self.output = self.executor.postprocess() # Retrieve result of asr. self.executor.infer(self.config.model)
else: infer_time = time.time() - st
logger.info("file check failed!") self.output = self.executor.postprocess(
self.output = None ) # Retrieve result of asr.
else:
logger.info("inference time: {}".format(infer_time)) logger.info("file check failed!")
logger.info("asr engine type: python") self.output = None
logger.info("inference time: {}".format(infer_time))
logger.info("asr engine type: python")
except Exception as e:
logger.info(e)
def postprocess(self): def postprocess(self):
"""postprocess """postprocess
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册