未验证 提交 653f5193 编写于 作者: L liu zhengxi 提交者: GitHub

fix cpu (#5212)

上级 05cd6ce3
......@@ -55,6 +55,9 @@ def do_eval(args):
(loss, np.exp(loss))
return logger_info
if not args.use_gpu:
paddle.set_device("cpu")
vocab = get_lm_vocab(args)
eval_loader = get_lm_data_loader(args, vocab, "valid")
test_loader = get_lm_data_loader(args, vocab, "test")
......
......@@ -37,6 +37,7 @@ def do_train(args):
else:
rank = 0
trainer_count = 1
paddle.set_device("cpu")
if trainer_count > 1:
dist.init_parallel_env()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册