提交 3a8869fb 编写于 作者: H Hui Zhang

rm to_static decorator; configure jit save for ctc_activation

上级 1c9f238b
......@@ -513,9 +513,9 @@ class U2Tester(U2Trainer):
infer_model.forward_attention_decoder, input_spec=input_spec)
######################### infer_model.ctc_activation ########################
# TODO: 512(encoder_output) be configable
input_spec = [
paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')
# encoder_out, (B,T,D)
paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32')
]
infer_model.ctc_activation = paddle.jit.to_static(
infer_model.ctc_activation, input_spec=input_spec)
......
......@@ -599,12 +599,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
"""
return self.eos
# @jit.to_static(input_spec=[
# paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
# paddle.static.InputSpec(shape=[1], dtype='int32'),
# -1,
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')])
# @jit.to_static
def forward_encoder_chunk(
self,
xs: paddle.Tensor,
......@@ -658,10 +653,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
"""
return self.ctc.log_softmax(xs)
# @jit.to_static(input_spec=[
# paddle.static.InputSpec(shape=[None, None], dtype='int64'),
# paddle.static.InputSpec(shape=[None], dtype='int64'),
# paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')])
# @jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册