Commit 42a5bdf8 authored by Hui Zhang

encoder offline export

Parent 1bc4acfd
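The encoder is traced with paddle.jit.to_static and written out with paddle.jit.save, so the offline encoder can be loaded without the training code. Below is a minimal sketch of that offline usage, assuming an illustrative export path, the 80-dim feature input declared in the InputSpec in this commit, and that the loaded layer takes just the two tensor inputs (the chunk arguments are fixed at export time); none of this sketch is part of the commit itself.

import numpy as np
import paddle

# "exp/conformer/checkpoints/encoder" is a hypothetical export_path.
encoder = paddle.jit.load("exp/conformer/checkpoints/encoder")

# One utterance: 120 frames of 80-dim features, all frames valid.
speech = paddle.randn([1, 120, 80], dtype='float32')           # [B, T, D]
speech_mask = paddle.to_tensor(np.ones((1, 120), dtype=bool))  # [B, T]

# Returns the encoder hiddens and the subsampled mask kept for cross attention.
encoder_out, encoder_mask = encoder(speech, speech_mask)
print(encoder_out.shape, encoder_mask.shape)  # roughly [1, T', 256] and [1, 1, T']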
@@ -602,24 +602,61 @@ class U2Tester(U2Trainer):
        infer_model.eval()
        #static_model = paddle.jit.to_static(infer_model., input_spec=input_spec)
        decoder_max_time = 100
        encoder_max_time = None
        encoder_model_size = 256
        logger.info(f"subsampling_rate: {infer_model.subsampling_rate}")
        logger.info(f"right_context: {infer_model.right_context}")
        logger.info(f"sos_symbol: {infer_model.sos_symbol}")
        logger.info(f"eos_symbol: {infer_model.eos_symbol}")
        logger.info(f"model_size: {infer_model.encoder.output_size}")
        encoder_model_size = infer_model.encoder.output_size
        # export encoder
        # speech (paddle.Tensor): [B, Tmax, D]
        # speech_lengths (paddle.Tensor): [B]
        # decoding_chunk_size (int, optional): chunk size. Defaults to -1.
        # num_decoding_left_chunks (int, optional): number of left chunks. Defaults to -1.
        # simulate_streaming (bool, optional): streaming or not. Defaults to False.
        # encoder hiddens (B, Tmax, D),
        # encoder hiddens mask (B, 1, Tmax).
        encoder_input_dim = 80
        static_model = paddle.jit.to_static(
            infer_model.forward_attention_decoder,
            infer_model.encoder.forward_export,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[1, decoder_max_time], dtype='int32'),  # tgt
                paddle.static.InputSpec(
                    shape=[1, decoder_max_time], dtype='bool'),  # tgt_mask
                paddle.static.InputSpec(
                    shape=[1, encoder_max_time, encoder_model_size],
                    dtype='float32'),  # encoder_out
                    shape=[1, None, encoder_input_dim],
                    dtype='float32'),  # speech, [B, T, D]
                paddle.static.InputSpec(shape=[1, None],
                                        dtype='bool'),  # speech_mask, [B, T]
                -1,  # decoding_chunk_size: <0 -> full-context (offline) decoding
                -1   # num_decoding_left_chunks: <0 -> use all left chunks
            ])
        logger.info(f"Export code: {static_model.main_program}")
        logger.debug(f"Export Code: {dir(static_model)}")
        logger.debug(f"Export Program: {static_model.main_program}")
        paddle.jit.save(static_model, self.args.export_path)
        # # export decoder
        # decoder_max_time = 100
        # encoder_max_time = None
        # static_model = paddle.jit.to_static(
        #     infer_model.forward_attention_decoder,
        #     input_spec=[
        #         paddle.static.InputSpec(
        #             shape=[1, decoder_max_time], dtype='int32'),  # tgt, [B, U]
        #         paddle.static.InputSpec(
        #             shape=[1, decoder_max_time],
        #             dtype='bool'),  # tgt_mask, [B, U]
        #         paddle.static.InputSpec(
        #             shape=[1, encoder_max_time, encoder_model_size],
        #             dtype='float32'),  # encoder_out, [B, T, D]
        #     ])
        # logger.debug(f"Export Code: {static_model.code}")
        # logger.debug(f"Export Program: {static_model.main_program}")
        # paddle.jit.save(static_model, self.args.export_path)

    def run_export(self):
        try:
            self.export()
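The exported encoder takes a boolean speech_mask [B, T] that marks the valid frames of each (possibly padded) utterance. A sketch of deriving that mask from frame lengths; make_speech_mask and the concrete lengths here are illustrative and not code from this repo.

import paddle

def make_speech_mask(speech_lengths: paddle.Tensor, max_len: int) -> paddle.Tensor:
    """Build a [B, T] bool mask that is True for valid frames."""
    frame_idx = paddle.arange(max_len).unsqueeze(0)  # [1, T]
    return frame_idx < speech_lengths.unsqueeze(1)   # broadcast to [B, T]

# e.g. one utterance of 97 frames padded to 120
speech_lengths = paddle.to_tensor([97], dtype='int64')
speech_mask = make_speech_mask(speech_lengths, max_len=120)  # [1, 120] bool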
@@ -451,3 +451,46 @@ class ConformerEncoder(BaseEncoder):
                normalize_before=normalize_before,
                concat_after=concat_after) for _ in range(num_blocks)
        ])

    def forward_export(
            self,
            xs: paddle.Tensor,
            masks: paddle.Tensor,
            decoding_chunk_size: int=0,
            num_decoding_left_chunks: int=-1,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Encode padded input features for static-graph (offline) export.
        Args:
            xs: padded input tensor (B, L, D)
            masks: padding mask (B, L), True for valid frames
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor (B, L', D) and output mask (B, 1, L')
        """
        masks = masks.unsqueeze(1)  # (B, 1, L)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice does not support bool tensor
        xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
        #TODO(Hui Zhang): remove masks.astype, stride_slice does not support bool tensor
        masks = masks.astype(paddle.bool)
        #TODO(Hui Zhang): mask_pad = ~masks
        mask_pad = masks.logical_not()
        # Offline export: attend over the full utterance. decoding_chunk_size and
        # num_decoding_left_chunks are kept only for signature compatibility and
        # are not used on this path.
        chunk_masks = masks
        for layer in self.encoders:
            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks
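The masks returned here feed the attention decoder, matching the commented-out decoder export above (tgt [B, U] int32, tgt_mask [B, U] bool, encoder_out [B, T, D] float32). A dygraph sketch of that hand-off, assuming infer_model, speech and speech_mask as in the export code above; the token ids are made up for illustration.

import paddle

# Encode the utterance offline with full context.
encoder_out, encoder_mask = infer_model.encoder.forward_export(speech, speech_mask)

# One partial hypothesis: <sos> followed by two made-up token ids.
tgt = paddle.to_tensor([[infer_model.sos_symbol, 10, 23]], dtype='int32')  # [B, U]
tgt_mask = paddle.ones_like(tgt).astype('bool')                            # [B, U]

# Score the hypothesis against the encoder hiddens.
decoder_out = infer_model.forward_attention_decoder(tgt, tgt_mask, encoder_out)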