From b1fcba33b01acfa6dd1e50996e976a23f8ae25d7 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Thu, 21 Jan 2021 21:34:40 +0800 Subject: [PATCH] refine cpu (#5214) * refine transformer cpu * delete run_pretrain.sh --- PaddleNLP/benchmark/transformer/dygraph/train.py | 1 + PaddleNLP/benchmark/transformer/static/run_pretrain.sh | 4 ---- PaddleNLP/benchmark/transformer/static/train.py | 7 +++++-- .../examples/machine_translation/transformer/train.py | 1 + 4 files changed, 7 insertions(+), 6 deletions(-) delete mode 100644 PaddleNLP/benchmark/transformer/static/run_pretrain.sh diff --git a/PaddleNLP/benchmark/transformer/dygraph/train.py b/PaddleNLP/benchmark/transformer/dygraph/train.py index 58424f06..69687e2e 100644 --- a/PaddleNLP/benchmark/transformer/dygraph/train.py +++ b/PaddleNLP/benchmark/transformer/dygraph/train.py @@ -41,6 +41,7 @@ def do_train(args): else: rank = 0 trainer_count = 1 + paddle.set_device("cpu") if trainer_count > 1: dist.init_parallel_env() diff --git a/PaddleNLP/benchmark/transformer/static/run_pretrain.sh b/PaddleNLP/benchmark/transformer/static/run_pretrain.sh deleted file mode 100644 index 8136b0a2..00000000 --- a/PaddleNLP/benchmark/transformer/static/run_pretrain.sh +++ /dev/null @@ -1,4 +0,0 @@ - -python -m paddle.distributed.launch \ - --gpus="0,1" \ - train.py diff --git a/PaddleNLP/benchmark/transformer/static/train.py b/PaddleNLP/benchmark/transformer/static/train.py index 45904f7d..b99c056d 100644 --- a/PaddleNLP/benchmark/transformer/static/train.py +++ b/PaddleNLP/benchmark/transformer/static/train.py @@ -44,8 +44,11 @@ def do_train(args): gpu_id) if args.use_gpu else paddle.static.cpu_places() trainer_count = 1 if args.use_gpu else len(places) else: - places = paddle.static.cuda_places( - ) if args.use_gpu else paddle.static.cpu_places() + if args.use_gpu: + places = paddle.static.cuda_places() + else: + places = paddle.static.cpu_places() + paddle.set_device("cpu") trainer_count = len(places) # Set seed for CE diff --git a/PaddleNLP/examples/machine_translation/transformer/train.py b/PaddleNLP/examples/machine_translation/transformer/train.py index 0137968f..41512579 100644 --- a/PaddleNLP/examples/machine_translation/transformer/train.py +++ b/PaddleNLP/examples/machine_translation/transformer/train.py @@ -33,6 +33,7 @@ def do_train(args): else: rank = 0 trainer_count = 1 + paddle.set_device("cpu") if trainer_count > 1: dist.init_parallel_env() -- GitLab