From 6345e607d2bb606a4a9436fcf2e09751407df69b Mon Sep 17 00:00:00 2001 From: ruri Date: Mon, 10 Feb 2020 16:57:38 +0800 Subject: [PATCH] fix multi-cpu bugs (#4250) --- PaddleCV/image_classification/eval.py | 6 ++++-- PaddleCV/image_classification/infer.py | 6 ++++-- PaddleCV/image_classification/train.py | 14 +++++++++++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/PaddleCV/image_classification/eval.py b/PaddleCV/image_classification/eval.py index 8d592006..44ea4736 100644 --- a/PaddleCV/image_classification/eval.py +++ b/PaddleCV/image_classification/eval.py @@ -122,7 +122,8 @@ def eval(args): exe.run(fluid.default_startup_program()) if args.use_gpu: places = fluid.framework.cuda_places() - + else: + places = fluid.framework.cpu_places() compiled_program = fluid.compiler.CompiledProgram( test_program).with_data_parallel(places=places) @@ -137,7 +138,8 @@ def eval(args): cnt = 0 parallel_data = [] parallel_id = [] - place_num = paddle.fluid.core.get_cuda_device_count() + place_num = paddle.fluid.core.get_cuda_device_count( + ) if args.use_gpu else int(os.environ.get('CPU_NUM', 1)) real_iter = 0 info_dict = {} diff --git a/PaddleCV/image_classification/infer.py b/PaddleCV/image_classification/infer.py index 1b1368ae..092ba636 100644 --- a/PaddleCV/image_classification/infer.py +++ b/PaddleCV/image_classification/infer.py @@ -106,9 +106,10 @@ def infer(args): place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) - places = place if args.use_gpu: places = fluid.framework.cuda_places() + else: + places = fluid.framework.cpu_places() compiled_program = fluid.compiler.CompiledProgram( test_program).with_data_parallel(places=places) @@ -147,7 +148,8 @@ def infer(args): info = {} parallel_data = [] parallel_id = [] - place_num = paddle.fluid.core.get_cuda_device_count() if args.use_gpu else 1 + place_num = paddle.fluid.core.get_cuda_device_count( + ) if args.use_gpu else int(os.environ.get('CPU_NUM', 1)) if os.path.exists(args.save_json_path): logger.warning("path: {} Already exists! will recover it\n".format( args.save_json_path)) diff --git a/PaddleCV/image_classification/train.py b/PaddleCV/image_classification/train.py index 74358784..e61333fd 100755 --- a/PaddleCV/image_classification/train.py +++ b/PaddleCV/image_classification/train.py @@ -203,9 +203,17 @@ def train(args): else: imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None) train_reader = imagenet_reader.train(settings=args) - places = place - if num_trainers <= 1 and args.use_gpu: - places = fluid.framework.cuda_places() + if args.use_gpu: + if num_trainers <= 1: + places = fluid.framework.cuda_places() + else: + places = place + else: + if num_trainers <= 1: + places = fluid.framework.cpu_places() + else: + places = place + train_data_loader.set_sample_list_generator(train_reader, places) if args.validate: -- GitLab