From 653f51939c0894524018b2a435b6b28390178c78 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Wed, 20 Jan 2021 11:12:41 +0800 Subject: [PATCH] fix cpu (#5212) --- PaddleNLP/examples/language_model/transformer-xl/eval.py | 3 +++ PaddleNLP/examples/language_model/transformer-xl/train.py | 1 + 2 files changed, 4 insertions(+) diff --git a/PaddleNLP/examples/language_model/transformer-xl/eval.py b/PaddleNLP/examples/language_model/transformer-xl/eval.py index e5188239..d2817d16 100644 --- a/PaddleNLP/examples/language_model/transformer-xl/eval.py +++ b/PaddleNLP/examples/language_model/transformer-xl/eval.py @@ -55,6 +55,9 @@ def do_eval(args): (loss, np.exp(loss)) return logger_info + if not args.use_gpu: + paddle.set_device("cpu") + vocab = get_lm_vocab(args) eval_loader = get_lm_data_loader(args, vocab, "valid") test_loader = get_lm_data_loader(args, vocab, "test") diff --git a/PaddleNLP/examples/language_model/transformer-xl/train.py b/PaddleNLP/examples/language_model/transformer-xl/train.py index 91297212..78f371a1 100644 --- a/PaddleNLP/examples/language_model/transformer-xl/train.py +++ b/PaddleNLP/examples/language_model/transformer-xl/train.py @@ -37,6 +37,7 @@ def do_train(args): else: rank = 0 trainer_count = 1 + paddle.set_device("cpu") if trainer_count > 1: dist.init_parallel_env() -- GitLab