diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 54810f22f92df20ab95e751e143b146be111ff5a..64b6c8df6defa47276fafb3287c862f365058a4f 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -520,6 +520,7 @@ class U2Tester(U2Trainer): infer_model.ctc_activation, input_spec=input_spec) ######################### infer_model.forward_attention_decoder ######################## + reverse_weight = 0.3 input_spec = [ # hyps, (B, U) paddle.static.InputSpec(shape=[None, None], dtype='int64'), @@ -527,7 +528,8 @@ class U2Tester(U2Trainer): paddle.static.InputSpec(shape=[None], dtype='int64'), # encoder_out, (B,T,D) paddle.static.InputSpec( - shape=[batch_size, None, model_size], dtype='float32') + shape=[batch_size, None, model_size], dtype='float32'), + reverse_weight ] infer_model.forward_attention_decoder = paddle.jit.to_static( infer_model.forward_attention_decoder, input_spec=input_spec) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index d699b684bf2782ce561811482c9fd3b7e416f613..1681bf1d967c5f02a69b5c2c00053cde5ea77f71 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -706,7 +706,7 @@ class U2BaseModel(ASRInterface, nn.Layer): hyps: paddle.Tensor, hyps_lens: paddle.Tensor, encoder_out: paddle.Tensor, - reverse_weight: float=0.0, ) -> paddle.Tensor: + reverse_weight: float=0.0) -> paddle.Tensor: """ Export interface for c++ call, forward decoder with multiple hypothesis from ctc prefix beam search and one encoder output Args: