提交 8e3aa3a0 编写于 作者: A andyjpaddle

fix amp train for re

上级 1477b89a
......@@ -171,7 +171,7 @@ def to_float32(preds):
else:
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
else:
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
preds = paddle.to_tensor(preds, dtype='float32')
return preds
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册