未验证 提交 7710ee04 编写于 作者: U user3984 提交者: GitHub

fix data type error when training with fp16 and DynamicToStatic (#9696)

Co-authored-by: 文幕地方's avatarWenmuZhou <572459439@qq.com>
上级 2a98d40b
......@@ -276,7 +276,9 @@ class ParallelSARDecoder(BaseDecoder):
hf_c = holistic_feat.shape[-1]
holistic_feat = paddle.expand(
holistic_feat, shape=[bsz, seq_len, hf_c])
y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
y = self.prediction(
paddle.concat((y, attn_feat.astype(y.dtype),
holistic_feat.astype(y.dtype)), 2))
else:
y = self.prediction(attn_feat)
# bsz * (seq_len + 1) * num_classes
......@@ -298,7 +300,7 @@ class ParallelSARDecoder(BaseDecoder):
lab_embedding = self.embedding(label)
# bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1)
out_enc = out_enc.unsqueeze(1).astype(lab_embedding.dtype)
# bsz * 1 * emb_dim
in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
# bsz * (seq_len + 1) * C
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册