提交 d1c74b39 编写于 作者: L LielinJiang 提交者: Lv Mengsi

Fix get cuda device bug when only use cpu (#4066)

* fix get cuda device bug when only use cpu

* fix typo
上级 0dfd43a2
...@@ -164,7 +164,7 @@ total_step = args.total_step ...@@ -164,7 +164,7 @@ total_step = args.total_step
with fluid.program_guard(tp, sp): with fluid.program_guard(tp, sp):
if args.use_py_reader: if args.use_py_reader:
batch_size_each = batch_size // fluid.core.get_cuda_device_count() batch_size_each = batch_size // utility.get_device_count()
py_reader = fluid.layers.py_reader(capacity=64, py_reader = fluid.layers.py_reader(capacity=64,
shapes=[[batch_size_each, 3] + image_shape, [batch_size_each] + image_shape], shapes=[[batch_size_each, 3] + image_shape, [batch_size_each] + image_shape],
dtypes=['float32', 'int32']) dtypes=['float32', 'int32'])
...@@ -197,7 +197,7 @@ with fluid.program_guard(tp, sp): ...@@ -197,7 +197,7 @@ with fluid.program_guard(tp, sp):
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = fluid.core.get_cuda_device_count() exec_strategy.num_threads = utility.get_device_count()
exec_strategy.num_iteration_per_drop_scope = 100 exec_strategy.num_iteration_per_drop_scope = 100
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
if args.memory_optimize: if args.memory_optimize:
...@@ -225,11 +225,11 @@ else: ...@@ -225,11 +225,11 @@ else:
binary = fluid.compiler.CompiledProgram(tp) binary = fluid.compiler.CompiledProgram(tp)
if args.use_py_reader: if args.use_py_reader:
assert(batch_size % fluid.core.get_cuda_device_count() == 0) assert(batch_size % utility.get_device_count() == 0)
def data_gen(): def data_gen():
batches = dataset.get_batch_generator( batches = dataset.get_batch_generator(
batch_size // fluid.core.get_cuda_device_count(), batch_size // utility.get_device_count(),
total_step * fluid.core.get_cuda_device_count(), total_step * utility.get_device_count(),
use_multiprocessing=args.use_multiprocessing, num_workers=args.num_workers) use_multiprocessing=args.use_multiprocessing, num_workers=args.num_workers)
for b in batches: for b in batches:
yield b[0], b[1] yield b[0], b[1]
...@@ -266,7 +266,7 @@ print("Training done. Model is saved to", args.save_weights_path) ...@@ -266,7 +266,7 @@ print("Training done. Model is saved to", args.save_weights_path)
save_model() save_model()
if args.enable_ce: if args.enable_ce:
gpu_num = fluid.core.get_cuda_device_count() gpu_num = utility.get_device_count()
print("kpis\teach_pass_duration_card%s\t%s" % print("kpis\teach_pass_duration_card%s\t%s" %
(gpu_num, total_time / epoch_idx)) (gpu_num, total_time / epoch_idx))
print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, train_loss)) print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, train_loss))
...@@ -78,3 +78,12 @@ def check_gpu(use_gpu): ...@@ -78,3 +78,12 @@ def check_gpu(use_gpu):
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
pass pass
def get_device_count():
    """Return the number of compute devices to use, falling back to 1.

    Queries ``fluid.core.get_cuda_device_count()``; on CPU-only installs
    that call may return 0 or raise, and in either case we report 1 so
    that batch-size division and KPI arithmetic elsewhere stays valid.

    Returns:
        int: number of CUDA devices, or 1 when none are usable.
    """
    try:
        # max(..., 1) guards against a zero count on CPU-only builds.
        device_num = max(fluid.core.get_cuda_device_count(), 1)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed; any failure querying CUDA means CPU-only.
        device_num = 1
    return device_num
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册