提交 6345e607 编写于 作者: R ruri 提交者: GitHub

fix multi-cpu bugs (#4250)

上级 fdac2d0b
......@@ -122,7 +122,8 @@ def eval(args):
exe.run(fluid.default_startup_program())
if args.use_gpu:
places = fluid.framework.cuda_places()
else:
places = fluid.framework.cpu_places()
compiled_program = fluid.compiler.CompiledProgram(
test_program).with_data_parallel(places=places)
......@@ -137,7 +138,8 @@ def eval(args):
cnt = 0
parallel_data = []
parallel_id = []
place_num = paddle.fluid.core.get_cuda_device_count()
place_num = paddle.fluid.core.get_cuda_device_count(
) if args.use_gpu else int(os.environ.get('CPU_NUM', 1))
real_iter = 0
info_dict = {}
......
......@@ -106,9 +106,10 @@ def infer(args):
place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
places = place
if args.use_gpu:
places = fluid.framework.cuda_places()
else:
places = fluid.framework.cpu_places()
compiled_program = fluid.compiler.CompiledProgram(
test_program).with_data_parallel(places=places)
......@@ -147,7 +148,8 @@ def infer(args):
info = {}
parallel_data = []
parallel_id = []
place_num = paddle.fluid.core.get_cuda_device_count() if args.use_gpu else 1
place_num = paddle.fluid.core.get_cuda_device_count(
) if args.use_gpu else int(os.environ.get('CPU_NUM', 1))
if os.path.exists(args.save_json_path):
logger.warning("path: {} Already exists! will recover it\n".format(
args.save_json_path))
......
......@@ -203,9 +203,17 @@ def train(args):
else:
imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None)
train_reader = imagenet_reader.train(settings=args)
places = place
if num_trainers <= 1 and args.use_gpu:
if args.use_gpu:
if num_trainers <= 1:
places = fluid.framework.cuda_places()
else:
places = place
else:
if num_trainers <= 1:
places = fluid.framework.cpu_places()
else:
places = place
train_data_loader.set_sample_list_generator(train_reader, places)
if args.validate:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册