未验证 提交 b1fcba33 编写于 作者: L liu zhengxi 提交者: GitHub

refine cpu (#5214)

* refine transformer cpu

* delete run_pretrain.sh
上级 84d366d4
...@@ -41,6 +41,7 @@ def do_train(args): ...@@ -41,6 +41,7 @@ def do_train(args):
else: else:
rank = 0 rank = 0
trainer_count = 1 trainer_count = 1
paddle.set_device("cpu")
if trainer_count > 1: if trainer_count > 1:
dist.init_parallel_env() dist.init_parallel_env()
......
python -m paddle.distributed.launch \
--gpus="0,1" \
train.py
...@@ -44,8 +44,11 @@ def do_train(args): ...@@ -44,8 +44,11 @@ def do_train(args):
gpu_id) if args.use_gpu else paddle.static.cpu_places() gpu_id) if args.use_gpu else paddle.static.cpu_places()
trainer_count = 1 if args.use_gpu else len(places) trainer_count = 1 if args.use_gpu else len(places)
else: else:
places = paddle.static.cuda_places( if args.use_gpu:
) if args.use_gpu else paddle.static.cpu_places() places = paddle.static.cuda_places()
else:
places = paddle.static.cpu_places()
paddle.set_device("cpu")
trainer_count = len(places) trainer_count = len(places)
# Set seed for CE # Set seed for CE
......
...@@ -33,6 +33,7 @@ def do_train(args): ...@@ -33,6 +33,7 @@ def do_train(args):
else: else:
rank = 0 rank = 0
trainer_count = 1 trainer_count = 1
paddle.set_device("cpu")
if trainer_count > 1: if trainer_count > 1:
dist.init_parallel_env() dist.init_parallel_env()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册