Commit 2112ee1e authored by Hui Zhang

fix decoder mask

Parent 1ff97432
@@ -601,17 +601,22 @@ class U2Tester(U2Trainer):
         # assert isinstance(input_spec, list), type(input_spec)
         infer_model.eval()
         #static_model = paddle.jit.to_static(infer_model., input_spec=input_spec)
+        decoder_max_time = 100
+        encoder_max_time = None
+        encoder_model_size = 256
         static_model = paddle.jit.to_static(
             infer_model.forward_attention_decoder,
             input_spec=[
-                paddle.static.InputSpec(shape=[1, None],dtype='int32'),
-                paddle.static.InputSpec(shape=[1],dtype='int32'),
-                paddle.static.InputSpec(shape=[1, None, 256],dtype='int32'),
-            ]
-        )
+                paddle.static.InputSpec(
+                    shape=[1, decoder_max_time], dtype='int32'),  # tgt
+                paddle.static.InputSpec(shape=[1], dtype='int32'),  # tgt_len
+                paddle.static.InputSpec(
+                    shape=[1, encoder_max_time, encoder_model_size],
+                    dtype='int32'),  # encoder_out
+            ])
         logger.info(f"Export code: {static_model}")
         paddle.jit.save(static_model, self.args.export_path)

     def run_export(self):
......
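For context: paddle.jit.to_static traces forward_attention_decoder against the declared InputSpec shapes, so the change above pins the target-side length to decoder_max_time while leaving the encoder time axis dynamic (None). Below is a minimal sketch of the same export pattern; ToyDecoder and the output path are made up stand-ins, not the repo's model.

import paddle
from paddle.static import InputSpec

class ToyDecoder(paddle.nn.Layer):
    # Stand-in for the attention decoder; only the export pattern matters.
    def __init__(self, d_model=256, vocab_size=1000):
        super().__init__()
        self.proj = paddle.nn.Linear(d_model, vocab_size)

    def forward(self, tgt, tgt_len, encoder_out):
        # tgt: (1, T_dec) token ids, tgt_len: (1,), encoder_out: (1, T_enc, d_model)
        return paddle.nn.functional.log_softmax(self.proj(encoder_out), axis=-1)

model = ToyDecoder()
model.eval()
static_fn = paddle.jit.to_static(
    model.forward,
    input_spec=[
        InputSpec(shape=[1, 100], dtype='int64'),          # tgt: fixed max length
        InputSpec(shape=[1], dtype='int64'),               # tgt_len
        InputSpec(shape=[1, None, 256], dtype='float32'),  # encoder_out: dynamic time axis
    ])
paddle.jit.save(static_fn, 'exported/toy_decoder')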
@@ -120,7 +120,7 @@ class TransformerDecoder(nn.Module):
         """
         tgt = ys_in_pad
         # tgt_mask: (B, 1, L)
-        tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1))
+        tgt_mask = make_non_pad_mask(ys_in_lens).unsqueeze(1)
         # m: (1, L, L)
         m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
         # tgt_mask: (B, L, L)
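For context: the two masks built above are typically combined with a logical AND into the (B, L, L) target mask, so a position is attendable only if it is both a real (non-padded) token and not in the future. Below is a self-contained sketch with simplified stand-ins for the helpers; the repo's real make_non_pad_mask and subsequent_mask live in its own mask module.

import paddle

def make_non_pad_mask(lengths):
    # (B, L) bool: True for real tokens, False for padding.
    max_len = int(lengths.max())
    seq = paddle.arange(max_len, dtype='int64').unsqueeze(0)  # (1, L)
    return seq < lengths.unsqueeze(-1)                        # (B, L)

def subsequent_mask(size):
    # (L, L) bool: lower triangular, position i may attend to positions <= i.
    return paddle.tril(paddle.ones([size, size], dtype='int32')).astype('bool')

ys_in_lens = paddle.to_tensor([3, 2], dtype='int64')   # batch of two, padded to L=3
tgt_mask = make_non_pad_mask(ys_in_lens).unsqueeze(1)  # (B, 1, L)
m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)   # (1, L, L)
tgt_mask = paddle.logical_and(tgt_mask, m)             # (B, L, L)
print(tgt_mask.numpy().astype(int))
# Second sequence (length 2) cannot attend to its padded third position.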
@@ -164,19 +164,24 @@ class TransformerDecoder(nn.Module):
             y.shape` is (batch, token)
         """
         x, _ = self.embed(tgt)
         new_cache = []
         for i, decoder in enumerate(self.decoders):
             if cache is None:
                 c = None
             else:
                 c = cache[i]
             x, tgt_mask, memory, memory_mask = decoder(
                 x, tgt_mask, memory, memory_mask, cache=c)
             new_cache.append(x)
         if self.normalize_before:
             y = self.after_norm(x[:, -1])
         else:
             y = x[:, -1]
         if self.use_output_layer:
             y = paddle.log_softmax(self.output_layer(y), axis=-1)
         return y, new_cache
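For context: new_cache collects one tensor per decoder layer and is fed back as cache on the next decoding step, while only the last position (x[:, -1]) is projected to token scores. Below is a toy sketch of that calling pattern; ToyStepDecoder is made up and does not implement real cached attention, it only shows how the cache is threaded through a greedy loop.

import paddle

class ToyStepDecoder(paddle.nn.Layer):
    def __init__(self, vocab=10, d=8, n_layers=2):
        super().__init__()
        self.embed = paddle.nn.Embedding(vocab, d)
        self.layers = paddle.nn.LayerList(
            [paddle.nn.Linear(d, d) for _ in range(n_layers)])
        self.output_layer = paddle.nn.Linear(d, vocab)

    def forward_one_step(self, tgt, cache=None):
        x = self.embed(tgt)  # (1, L, d)
        new_cache = []
        for i, layer in enumerate(self.layers):
            # A real decoder layer would reuse cache[i] here to avoid
            # recomputing attention over already-decoded positions.
            x = layer(x)
            new_cache.append(x)
        y = paddle.nn.functional.log_softmax(
            self.output_layer(x[:, -1]), axis=-1)  # next-token scores
        return y, new_cache

decoder = ToyStepDecoder()
hyp, cache = [1], None
for _ in range(4):
    tgt = paddle.to_tensor([hyp], dtype='int64')  # (1, len(hyp))
    y, cache = decoder.forward_one_step(tgt, cache)
    hyp.append(int(paddle.argmax(y, axis=-1)))
print(hyp)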