提交 2112ee1e 编写于 作者: H Hui Zhang

fix decoder mask

上级 1ff97432
...@@ -601,17 +601,22 @@ class U2Tester(U2Trainer): ...@@ -601,17 +601,22 @@ class U2Tester(U2Trainer):
# assert isinstance(input_spec, list), type(input_spec) # assert isinstance(input_spec, list), type(input_spec)
infer_model.eval() infer_model.eval()
#static_model = paddle.jit.to_static(infer_model., input_spec=input_spec) #static_model = paddle.jit.to_static(infer_model., input_spec=input_spec)
decoder_max_time = 100
encoder_max_time = None
encoder_model_size = 256
static_model = paddle.jit.to_static( static_model = paddle.jit.to_static(
infer_model.forward_attention_decoder, infer_model.forward_attention_decoder,
input_spec=[ input_spec=[
paddle.static.InputSpec(shape=[1, None],dtype='int32'), paddle.static.InputSpec(
paddle.static.InputSpec(shape=[1],dtype='int32'), shape=[1, decoder_max_time], dtype='int32'), # tgt
paddle.static.InputSpec(shape=[1, None, 256],dtype='int32'), paddle.static.InputSpec(shape=[1], dtype='int32'), # tgt_len
] paddle.static.InputSpec(
) shape=[1, encoder_max_time, encoder_model_size],
dtype='int32'), # encoder_out
])
logger.info(f"Export code: {static_model}") logger.info(f"Export code: {static_model}")
paddle.jit.save(static_model, self.args.export_path) paddle.jit.save(static_model, self.args.export_path)
def run_export(self): def run_export(self):
......
...@@ -120,7 +120,7 @@ class TransformerDecoder(nn.Module): ...@@ -120,7 +120,7 @@ class TransformerDecoder(nn.Module):
""" """
tgt = ys_in_pad tgt = ys_in_pad
# tgt_mask: (B, 1, L) # tgt_mask: (B, 1, L)
tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1)) tgt_mask = make_non_pad_mask(ys_in_lens).unsqueeze(1)
# m: (1, L, L) # m: (1, L, L)
m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0) m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
# tgt_mask: (B, L, L) # tgt_mask: (B, L, L)
...@@ -164,19 +164,24 @@ class TransformerDecoder(nn.Module): ...@@ -164,19 +164,24 @@ class TransformerDecoder(nn.Module):
y.shape` is (batch, token) y.shape` is (batch, token)
""" """
x, _ = self.embed(tgt) x, _ = self.embed(tgt)
new_cache = [] new_cache = []
for i, decoder in enumerate(self.decoders): for i, decoder in enumerate(self.decoders):
if cache is None: if cache is None:
c = None c = None
else: else:
c = cache[i] c = cache[i]
x, tgt_mask, memory, memory_mask = decoder( x, tgt_mask, memory, memory_mask = decoder(
x, tgt_mask, memory, memory_mask, cache=c) x, tgt_mask, memory, memory_mask, cache=c)
new_cache.append(x) new_cache.append(x)
if self.normalize_before: if self.normalize_before:
y = self.after_norm(x[:, -1]) y = self.after_norm(x[:, -1])
else: else:
y = x[:, -1] y = x[:, -1]
if self.use_output_layer: if self.use_output_layer:
y = paddle.log_softmax(self.output_layer(y), axis=-1) y = paddle.log_softmax(self.output_layer(y), axis=-1)
return y, new_cache return y, new_cache
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册