diff --git a/PaddleNLP/examples/language_model/transformer-xl/eval.py b/PaddleNLP/examples/language_model/transformer-xl/eval.py index e518823967e222e3c3824f71353d63159bd00278..d2817d1636f72d381a15eed2ad2edb4ca13322ad 100644 --- a/PaddleNLP/examples/language_model/transformer-xl/eval.py +++ b/PaddleNLP/examples/language_model/transformer-xl/eval.py @@ -55,6 +55,9 @@ def do_eval(args): (loss, np.exp(loss)) return logger_info + if not args.use_gpu: + paddle.set_device("cpu") + vocab = get_lm_vocab(args) eval_loader = get_lm_data_loader(args, vocab, "valid") test_loader = get_lm_data_loader(args, vocab, "test") diff --git a/PaddleNLP/examples/language_model/transformer-xl/train.py b/PaddleNLP/examples/language_model/transformer-xl/train.py index 912972121a24e4ca87d9ee4b127e133cc19d5dcf..78f371a1614fc1cd6b7a67e23b102ce753d5ccef 100644 --- a/PaddleNLP/examples/language_model/transformer-xl/train.py +++ b/PaddleNLP/examples/language_model/transformer-xl/train.py @@ -37,6 +37,7 @@ def do_train(args): else: rank = 0 trainer_count = 1 + paddle.set_device("cpu") if trainer_count > 1: dist.init_parallel_env()