未验证 提交 6445362f 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #7139 from andyjpaddle/fix_amp_re

[TIPC] Fix amp train for re
......@@ -154,13 +154,14 @@ def check_xpu(use_xpu):
except Exception as e:
pass
def to_float32(preds):
if isinstance(preds, dict):
for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = preds[k].astype(paddle.float32)
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
elif isinstance(preds, list):
for k in range(len(preds)):
if isinstance(preds[k], dict):
......@@ -168,11 +169,12 @@ def to_float32(preds):
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = preds[k].astype(paddle.float32)
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
else:
preds = preds.astype(paddle.float32)
preds = paddle.to_tensor(preds, dtype='float32')
return preds
def train(config,
train_dataloader,
valid_dataloader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册