提交 6345e607 编写于 作者: R ruri 提交者: GitHub

fix multi-cpu bugs (#4250)

上级 fdac2d0b
...@@ -122,7 +122,8 @@ def eval(args): ...@@ -122,7 +122,8 @@ def eval(args):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
if args.use_gpu: if args.use_gpu:
places = fluid.framework.cuda_places() places = fluid.framework.cuda_places()
else:
places = fluid.framework.cpu_places()
compiled_program = fluid.compiler.CompiledProgram( compiled_program = fluid.compiler.CompiledProgram(
test_program).with_data_parallel(places=places) test_program).with_data_parallel(places=places)
...@@ -137,7 +138,8 @@ def eval(args): ...@@ -137,7 +138,8 @@ def eval(args):
cnt = 0 cnt = 0
parallel_data = [] parallel_data = []
parallel_id = [] parallel_id = []
place_num = paddle.fluid.core.get_cuda_device_count() place_num = paddle.fluid.core.get_cuda_device_count(
) if args.use_gpu else int(os.environ.get('CPU_NUM', 1))
real_iter = 0 real_iter = 0
info_dict = {} info_dict = {}
......
...@@ -106,9 +106,10 @@ def infer(args): ...@@ -106,9 +106,10 @@ def infer(args):
place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
places = place
if args.use_gpu: if args.use_gpu:
places = fluid.framework.cuda_places() places = fluid.framework.cuda_places()
else:
places = fluid.framework.cpu_places()
compiled_program = fluid.compiler.CompiledProgram( compiled_program = fluid.compiler.CompiledProgram(
test_program).with_data_parallel(places=places) test_program).with_data_parallel(places=places)
...@@ -147,7 +148,8 @@ def infer(args): ...@@ -147,7 +148,8 @@ def infer(args):
info = {} info = {}
parallel_data = [] parallel_data = []
parallel_id = [] parallel_id = []
place_num = paddle.fluid.core.get_cuda_device_count() if args.use_gpu else 1 place_num = paddle.fluid.core.get_cuda_device_count(
) if args.use_gpu else int(os.environ.get('CPU_NUM', 1))
if os.path.exists(args.save_json_path): if os.path.exists(args.save_json_path):
logger.warning("path: {} Already exists! will recover it\n".format( logger.warning("path: {} Already exists! will recover it\n".format(
args.save_json_path)) args.save_json_path))
......
...@@ -203,9 +203,17 @@ def train(args): ...@@ -203,9 +203,17 @@ def train(args):
else: else:
imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None) imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None)
train_reader = imagenet_reader.train(settings=args) train_reader = imagenet_reader.train(settings=args)
places = place if args.use_gpu:
if num_trainers <= 1 and args.use_gpu: if num_trainers <= 1:
places = fluid.framework.cuda_places() places = fluid.framework.cuda_places()
else:
places = place
else:
if num_trainers <= 1:
places = fluid.framework.cpu_places()
else:
places = place
train_data_loader.set_sample_list_generator(train_reader, places) train_data_loader.set_sample_list_generator(train_reader, places)
if args.validate: if args.validate:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册