提交 309c8d70 编写于 作者: H Hui Zhang

add reverse weight

上级 9b66680e
......@@ -520,6 +520,7 @@ class U2Tester(U2Trainer):
infer_model.ctc_activation, input_spec=input_spec)
######################### infer_model.forward_attention_decoder ########################
reverse_weight = 0.3
input_spec = [
# hyps, (B, U)
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
......@@ -527,7 +528,8 @@ class U2Tester(U2Trainer):
paddle.static.InputSpec(shape=[None], dtype='int64'),
# encoder_out, (B,T,D)
paddle.static.InputSpec(
shape=[batch_size, None, model_size], dtype='float32')
shape=[batch_size, None, model_size], dtype='float32'),
reverse_weight
]
infer_model.forward_attention_decoder = paddle.jit.to_static(
infer_model.forward_attention_decoder, input_spec=input_spec)
......
......@@ -706,7 +706,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
hyps: paddle.Tensor,
hyps_lens: paddle.Tensor,
encoder_out: paddle.Tensor,
reverse_weight: float=0.0, ) -> paddle.Tensor:
reverse_weight: float=0.0) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output
Args:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册