提交 0ea39f83 编写于 作者: X xiongxinlei

add asr time limt configuration, test=doc

上级 bd66c7a8
......@@ -187,6 +187,13 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method
self.max_len = 5000
if self.config.encoder_conf.get("max_len", None):
self.max_len = self.config.encoder_conf.max_len
logger.info(f"max len: {self.max_len}")
# we assumen that the subsample rate is 4 and every frame step is 40ms
self.max_len = 40 * self.max_len / 1000
else:
raise Exception("wrong type")
model_name = model_type[:model_type.rindex(
......@@ -352,9 +359,10 @@ class ASRExecutor(BaseExecutor):
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
audio_duration = audio.shape[0] / audio_sample_rate
max_duration = 50.0
if audio_duration >= max_duration:
logger.error("Please input audio file less then 50 seconds.\n")
if audio_duration > self.max_len:
logger.error(
f"Please input audio file less then {self.max_len} seconds.\n"
)
return False
except Exception as e:
logger.exception(e)
......
......@@ -62,21 +62,21 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
False: x -> x + att(x)
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
self_attention_dropout_rate: float=0.0,
src_attention_dropout_rate: float=0.0,
input_layer: str="embed",
use_output_layer: bool=True,
normalize_before: bool=True,
concat_after: bool=False, ):
def __init__(self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
self_attention_dropout_rate: float=0.0,
src_attention_dropout_rate: float=0.0,
input_layer: str="embed",
use_output_layer: bool=True,
normalize_before: bool=True,
concat_after: bool=False,
max_len: int=5000):
assert check_argument_types()
......@@ -87,7 +87,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
if input_layer == "embed":
self.embed = nn.Sequential(
Embedding(vocab_size, attention_dim),
PositionalEncoding(attention_dim, positional_dropout_rate), )
PositionalEncoding(
attention_dim, positional_dropout_rate, max_len=max_len), )
else:
raise ValueError(f"only 'embed' is supported: {input_layer}")
......
......@@ -112,7 +112,9 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
"""
T = x.shape[1]
assert offset + x.shape[1] < self.max_len
assert offset + x.shape[
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len)
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T]
x = x * self.xscale + pos_emb
......@@ -148,6 +150,7 @@ class RelPositionalEncoding(PositionalEncoding):
max_len (int, optional): [Maximum input length.]. Defaults to 5000.
"""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
logger.info(f"max len: {max_len}")
def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
......@@ -158,7 +161,9 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
assert offset + x.shape[1] < self.max_len
assert offset + x.shape[
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len)
x = x * self.xscale
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]]
......
......@@ -47,24 +47,24 @@ __all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder"]
class BaseEncoder(nn.Layer):
def __init__(
self,
input_size: int,
output_size: int=256,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0,
input_layer: str="conv2d",
pos_enc_layer_type: str="abs_pos",
normalize_before: bool=True,
concat_after: bool=False,
static_chunk_size: int=0,
use_dynamic_chunk: bool=False,
global_cmvn: paddle.nn.Layer=None,
use_dynamic_left_chunk: bool=False, ):
def __init__(self,
input_size: int,
output_size: int=256,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0,
input_layer: str="conv2d",
pos_enc_layer_type: str="abs_pos",
normalize_before: bool=True,
concat_after: bool=False,
static_chunk_size: int=0,
use_dynamic_chunk: bool=False,
global_cmvn: paddle.nn.Layer=None,
use_dynamic_left_chunk: bool=False,
max_len: int=5000):
"""
Args:
input_size (int): input dim, d_feature
......@@ -127,7 +127,9 @@ class BaseEncoder(nn.Layer):
odim=output_size,
dropout_rate=dropout_rate,
pos_enc_class=pos_enc_class(
d_model=output_size, dropout_rate=positional_dropout_rate), )
d_model=output_size,
dropout_rate=positional_dropout_rate,
max_len=max_len), )
self.normalize_before = normalize_before
self.after_norm = LayerNorm(output_size, epsilon=1e-12)
......@@ -330,7 +332,7 @@ class BaseEncoder(nn.Layer):
# fake mask, just for jit script and compatibility with `forward` api
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
return ys, masks
return ys, masks, offset
class TransformerEncoder(BaseEncoder):
......@@ -415,32 +417,32 @@ class TransformerEncoder(BaseEncoder):
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int=256,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0,
input_layer: str="conv2d",
pos_enc_layer_type: str="rel_pos",
normalize_before: bool=True,
concat_after: bool=False,
static_chunk_size: int=0,
use_dynamic_chunk: bool=False,
global_cmvn: nn.Layer=None,
use_dynamic_left_chunk: bool=False,
positionwise_conv_kernel_size: int=1,
macaron_style: bool=True,
selfattention_layer_type: str="rel_selfattn",
activation_type: str="swish",
use_cnn_module: bool=True,
cnn_module_kernel: int=15,
causal: bool=False,
cnn_module_norm: str="batch_norm", ):
def __init__(self,
input_size: int,
output_size: int=256,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0,
input_layer: str="conv2d",
pos_enc_layer_type: str="rel_pos",
normalize_before: bool=True,
concat_after: bool=False,
static_chunk_size: int=0,
use_dynamic_chunk: bool=False,
global_cmvn: nn.Layer=None,
use_dynamic_left_chunk: bool=False,
positionwise_conv_kernel_size: int=1,
macaron_style: bool=True,
selfattention_layer_type: str="rel_selfattn",
activation_type: str="swish",
use_cnn_module: bool=True,
cnn_module_kernel: int=15,
causal: bool=False,
cnn_module_norm: str="batch_norm",
max_len: int=5000):
"""Construct ConformerEncoder
Args:
input_size to use_dynamic_chunk, see in BaseEncoder
......@@ -464,7 +466,7 @@ class ConformerEncoder(BaseEncoder):
attention_dropout_rate, input_layer,
pos_enc_layer_type, normalize_before, concat_after,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk)
use_dynamic_left_chunk, max_len)
activation = get_activation(activation_type)
# self-attention module definition
......
......@@ -78,21 +78,26 @@ class ASREngine(BaseEngine):
Args:
audio_data (bytes): base64.b64decode
"""
if self.executor._check(
io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes):
logger.info("start run asr engine")
self.executor.preprocess(self.config.model, io.BytesIO(audio_data))
st = time.time()
self.executor.infer(self.config.model)
infer_time = time.time() - st
self.output = self.executor.postprocess() # Retrieve result of asr.
else:
logger.info("file check failed!")
self.output = None
logger.info("inference time: {}".format(infer_time))
logger.info("asr engine type: python")
try:
if self.executor._check(
io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes):
logger.info("start run asr engine")
self.executor.preprocess(self.config.model,
io.BytesIO(audio_data))
st = time.time()
self.executor.infer(self.config.model)
infer_time = time.time() - st
self.output = self.executor.postprocess(
) # Retrieve result of asr.
else:
logger.info("file check failed!")
self.output = None
logger.info("inference time: {}".format(infer_time))
logger.info("asr engine type: python")
except Exception as e:
logger.info(e)
def postprocess(self):
"""postprocess
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册