提交 43122c20 编写于 作者: D Dang Qingqing

Fix reader bug.

上级 1847c180
...@@ -174,6 +174,9 @@ def parallel_exe(args, ...@@ -174,6 +174,9 @@ def parallel_exe(args,
elif data_args.dataset == 'pascalvoc': elif data_args.dataset == 'pascalvoc':
num_classes = 21 num_classes = 21
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
gt_box = fluid.layers.data( gt_box = fluid.layers.data(
name='gt_box', shape=[4], dtype='float32', lod_level=1) name='gt_box', shape=[4], dtype='float32', lod_level=1)
...@@ -253,6 +256,7 @@ def parallel_exe(args, ...@@ -253,6 +256,7 @@ def parallel_exe(args,
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
prev_start_time = start_time prev_start_time = start_time
start_time = time.time() start_time = time.time()
if len(data) < devices_num: continue
loss_v, = train_exe.run(fetch_list=[loss.name], loss_v, = train_exe.run(fetch_list=[loss.name],
feed_dict=feeder.feed(data)) feed_dict=feeder.feed(data))
end_time = time.time() end_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册