diff --git a/tools/program.py b/tools/program.py index 0fa0e609bd14d07cc593786b3a3f760cb9b98500..53994b8560f826c43cd50919237fab2b7fe1c64d 100755 --- a/tools/program.py +++ b/tools/program.py @@ -154,13 +154,14 @@ def check_xpu(use_xpu): except Exception as e: pass + def to_float32(preds): if isinstance(preds, dict): for k in preds: if isinstance(preds[k], dict) or isinstance(preds[k], list): preds[k] = to_float32(preds[k]) else: - preds[k] = preds[k].astype(paddle.float32) + preds[k] = paddle.to_tensor(preds[k], dtype='float32') elif isinstance(preds, list): for k in range(len(preds)): if isinstance(preds[k], dict): @@ -168,11 +169,12 @@ def to_float32(preds): elif isinstance(preds[k], list): preds[k] = to_float32(preds[k]) else: - preds[k] = preds[k].astype(paddle.float32) + preds[k] = paddle.to_tensor(preds[k], dtype='float32') else: - preds = preds.astype(paddle.float32) + preds = paddle.to_tensor(preds, dtype='float32') return preds + def train(config, train_dataloader, valid_dataloader,