未验证 提交 653f5193 编写于 作者: L liu zhengxi 提交者: GitHub

fix cpu (#5212)

上级 05cd6ce3
...@@ -55,6 +55,9 @@ def do_eval(args): ...@@ -55,6 +55,9 @@ def do_eval(args):
(loss, np.exp(loss)) (loss, np.exp(loss))
return logger_info return logger_info
if not args.use_gpu:
paddle.set_device("cpu")
vocab = get_lm_vocab(args) vocab = get_lm_vocab(args)
eval_loader = get_lm_data_loader(args, vocab, "valid") eval_loader = get_lm_data_loader(args, vocab, "valid")
test_loader = get_lm_data_loader(args, vocab, "test") test_loader = get_lm_data_loader(args, vocab, "test")
......
...@@ -37,6 +37,7 @@ def do_train(args): ...@@ -37,6 +37,7 @@ def do_train(args):
else: else:
rank = 0 rank = 0
trainer_count = 1 trainer_count = 1
paddle.set_device("cpu")
if trainer_count > 1: if trainer_count > 1:
dist.init_parallel_env() dist.init_parallel_env()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册