From 1c9f238ba09e55b26b3b0c46033436ed27eb9613 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 9 Sep 2022 15:45:26 +0000 Subject: [PATCH] configurable export --- paddlespeech/s2t/exps/u2/model.py | 37 +++++++++++++++++++------------ 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 1d813761..45fbcb40 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -462,31 +462,37 @@ class U2Tester(U2Trainer): infer_model = U2InferModel.from_pretrained(self.test_loader, self.config.clone(), self.args.checkpoint_path) + + batch_size = 1 feat_dim = self.test_loader.feat_dim - input_spec = [ - paddle.static.InputSpec(shape=[1, None, feat_dim], - dtype='float32'), # audio, [B,T,D] - paddle.static.InputSpec(shape=[1], - dtype='int64'), # audio_length, [B] - ] - return infer_model, input_spec + model_size = 512 + num_left_chunks = -1 + + return infer_model, (batch_size, feat_dim, model_size, num_left_chunks) @paddle.no_grad() def export(self): infer_model, input_spec = self.load_inferspec() - assert isinstance(input_spec, list), type(input_spec) - del input_spec infer_model.eval() - ######################### infer_model.forward_encoder_chunk zero Tensor online ############ + assert isinstance(input_spec, list), type(input_spec) + batch_size, feat_dim, model_size, num_left_chunks = input_spec + + + ######################### infer_model.forward_encoder_chunk zero tensor online ############ # TODO: 80(feature dim) be configable input_spec = [ - paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'), + # xs, (B, T, D) + paddle.static.InputSpec(shape=[batch_size, None, feat_dim], dtype='float32'), + # offset, int, but need be tensor paddle.static.InputSpec(shape=[1], dtype='int32'), - -1, + # required_cache_size, int + num_left_chunks, + # att_cache paddle.static.InputSpec( shape=[None, None, None, None], dtype='float32'), + # cnn_cache paddle.static.InputSpec( shape=[None, None, None, None], dtype='float32') ] @@ -496,9 +502,12 @@ class U2Tester(U2Trainer): ######################### infer_model.forward_attention_decoder ######################## # TODO: 512(encoder_output) be configable. 1 for BatchSize input_spec = [ + # hyps, (B, U) paddle.static.InputSpec(shape=[None, None], dtype='int64'), + # hyps_lens, (B,) paddle.static.InputSpec(shape=[None], dtype='int64'), - paddle.static.InputSpec(shape=[1, None, 512], dtype='float32') + # encoder_out, (B,T,D) + paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32') ] infer_model.forward_attention_decoder = paddle.jit.to_static( infer_model.forward_attention_decoder, input_spec=input_spec) @@ -529,7 +538,7 @@ class U2Tester(U2Trainer): xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32') offset = paddle.to_tensor([0], dtype='int32') - required_cache_size = -16 + required_cache_size = num_left_chunks att_cache = paddle.zeros([0, 0, 0, 0]) cnn_cache = paddle.zeros([0, 0, 0, 0]) -- GitLab