未验证 提交 be413666 编写于 作者: T topduke 提交者: GitHub

modify dtype='int64' to ='int32' for windows train

上级 512f6768
......@@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype='int64'))
tgt.shape, dtype='int32'))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册