Commit d1c74b39 authored by LielinJiang, committed by Lv Mengsi

Fix get cuda device bug when only use cpu (#4066)

* fix get cuda device bug when only use cpu

* fix typo
Parent 0dfd43a2
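The change below replaces the direct fluid.core.get_cuda_device_count() calls in the training script with a new utility.get_device_count() helper (added to the same utility module that holds check_gpu), so that a CPU-only run, where the CUDA device count is 0 or the query fails outright, still reports at least one device. A minimal sketch of the failure mode and of the guard, assuming a PaddlePaddle 1.x install; batch_size is illustrative, and the except clause here is slightly narrower than the bare except used in the diff:

import paddle.fluid as fluid

def get_device_count():
    # fluid.core.get_cuda_device_count() can report 0 devices (no visible
    # GPU) or raise (CPU-only build), so clamp the result to at least 1.
    try:
        device_num = max(fluid.core.get_cuda_device_count(), 1)
    except Exception:
        device_num = 1
    return device_num

# Before the fix, a CPU-only machine hit ZeroDivisionError on splits like:
#   batch_size_each = batch_size // fluid.core.get_cuda_device_count()
batch_size = 8
batch_size_each = batch_size // get_device_count()  # divisor is guaranteed >= 1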
@@ -164,7 +164,7 @@ total_step = args.total_step
 with fluid.program_guard(tp, sp):
     if args.use_py_reader:
-        batch_size_each = batch_size // fluid.core.get_cuda_device_count()
+        batch_size_each = batch_size // utility.get_device_count()
         py_reader = fluid.layers.py_reader(capacity=64,
                                            shapes=[[batch_size_each, 3] + image_shape, [batch_size_each] + image_shape],
                                            dtypes=['float32', 'int32'])
@@ -197,7 +197,7 @@ with fluid.program_guard(tp, sp):
 exec_strategy = fluid.ExecutionStrategy()
-exec_strategy.num_threads = fluid.core.get_cuda_device_count()
+exec_strategy.num_threads = utility.get_device_count()
 exec_strategy.num_iteration_per_drop_scope = 100
 build_strategy = fluid.BuildStrategy()
 if args.memory_optimize:
@@ -225,11 +225,11 @@ else:
     binary = fluid.compiler.CompiledProgram(tp)
 
 if args.use_py_reader:
-    assert(batch_size % fluid.core.get_cuda_device_count() == 0)
+    assert(batch_size % utility.get_device_count() == 0)
     def data_gen():
         batches = dataset.get_batch_generator(
-            batch_size // fluid.core.get_cuda_device_count(),
-            total_step * fluid.core.get_cuda_device_count(),
+            batch_size // utility.get_device_count(),
+            total_step * utility.get_device_count(),
             use_multiprocessing=args.use_multiprocessing, num_workers=args.num_workers)
         for b in batches:
             yield b[0], b[1]
@@ -266,7 +266,7 @@ print("Training done. Model is saved to", args.save_weights_path)
 save_model()
 
 if args.enable_ce:
-    gpu_num = fluid.core.get_cuda_device_count()
+    gpu_num = utility.get_device_count()
     print("kpis\teach_pass_duration_card%s\t%s" %
           (gpu_num, total_time / epoch_idx))
     print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, train_loss))
@@ -78,3 +78,12 @@ def check_gpu(use_gpu):
             sys.exit(1)
     except Exception as e:
         pass
+
+
+def get_device_count():
+    try:
+        device_num = max(fluid.core.get_cuda_device_count(), 1)
+    except:
+        device_num = 1
+    return device_num
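Taken together, the call sites in the training script now route every device-count query through this helper; a condensed usage sketch (not a verbatim excerpt of train.py), where batch_size is a stand-in for the script's argument-driven batch size:

import paddle.fluid as fluid
import utility

device_count = utility.get_device_count()  # >= 1 even on CPU-only hosts

# Per-device batch split for the py_reader input path.
batch_size = 8  # stand-in for the script's batch size argument
assert batch_size % device_count == 0
batch_size_each = batch_size // device_count

# One executor thread per device, as in the exec_strategy hunk above.
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = device_count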