未验证 提交 6445362f 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #7139 from andyjpaddle/fix_amp_re

[TIPC] Fix amp train for re
...@@ -154,13 +154,14 @@ def check_xpu(use_xpu): ...@@ -154,13 +154,14 @@ def check_xpu(use_xpu):
except Exception as e: except Exception as e:
pass pass
def to_float32(preds): def to_float32(preds):
if isinstance(preds, dict): if isinstance(preds, dict):
for k in preds: for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list): if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k]) preds[k] = to_float32(preds[k])
else: else:
preds[k] = preds[k].astype(paddle.float32) preds[k] = paddle.to_tensor(preds[k], dtype='float32')
elif isinstance(preds, list): elif isinstance(preds, list):
for k in range(len(preds)): for k in range(len(preds)):
if isinstance(preds[k], dict): if isinstance(preds[k], dict):
...@@ -168,11 +169,12 @@ def to_float32(preds): ...@@ -168,11 +169,12 @@ def to_float32(preds):
elif isinstance(preds[k], list): elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k]) preds[k] = to_float32(preds[k])
else: else:
preds[k] = preds[k].astype(paddle.float32) preds[k] = paddle.to_tensor(preds[k], dtype='float32')
else: else:
preds = preds.astype(paddle.float32) preds = paddle.to_tensor(preds, dtype='float32')
return preds return preds
def train(config, def train(config,
train_dataloader, train_dataloader,
valid_dataloader, valid_dataloader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册