diff --git a/PaddleNLP/benchmark/transformer/dygraph/train.py b/PaddleNLP/benchmark/transformer/dygraph/train.py index 58424f063317f191985fcc877521cbd941a1a32a..69687e2e51b180b7dc3fe4882abf629e63be5ca9 100644 --- a/PaddleNLP/benchmark/transformer/dygraph/train.py +++ b/PaddleNLP/benchmark/transformer/dygraph/train.py @@ -41,6 +41,7 @@ def do_train(args): else: rank = 0 trainer_count = 1 + paddle.set_device("cpu") if trainer_count > 1: dist.init_parallel_env() diff --git a/PaddleNLP/benchmark/transformer/static/run_pretrain.sh b/PaddleNLP/benchmark/transformer/static/run_pretrain.sh deleted file mode 100644 index 8136b0a2abc5818cdd10f51927a23ae98363e702..0000000000000000000000000000000000000000 --- a/PaddleNLP/benchmark/transformer/static/run_pretrain.sh +++ /dev/null @@ -1,4 +0,0 @@ - -python -m paddle.distributed.launch \ - --gpus="0,1" \ - train.py diff --git a/PaddleNLP/benchmark/transformer/static/train.py b/PaddleNLP/benchmark/transformer/static/train.py index 45904f7de60052518b150943cc5820218e2abc6c..b99c056dc87e403251e83ab7e86bae1b4156d83a 100644 --- a/PaddleNLP/benchmark/transformer/static/train.py +++ b/PaddleNLP/benchmark/transformer/static/train.py @@ -44,8 +44,11 @@ def do_train(args): gpu_id) if args.use_gpu else paddle.static.cpu_places() trainer_count = 1 if args.use_gpu else len(places) else: - places = paddle.static.cuda_places( - ) if args.use_gpu else paddle.static.cpu_places() + if args.use_gpu: + places = paddle.static.cuda_places() + else: + places = paddle.static.cpu_places() + paddle.set_device("cpu") trainer_count = len(places) # Set seed for CE diff --git a/PaddleNLP/examples/machine_translation/transformer/train.py b/PaddleNLP/examples/machine_translation/transformer/train.py index 0137968ff6cd53da7af7dcd8dbdd7f0c4b8000c8..415125797049d72d540e2d2e529be261e60c147e 100644 --- a/PaddleNLP/examples/machine_translation/transformer/train.py +++ b/PaddleNLP/examples/machine_translation/transformer/train.py @@ -33,6 +33,7 @@ def do_train(args): else: rank = 0 trainer_count = 1 + paddle.set_device("cpu") if trainer_count > 1: dist.init_parallel_env()