未验证 提交 fff5cc27 编写于 作者: X xiaoting 提交者: GitHub

change reader to DataLoader (#4293)

* fix load for yolo_dy
上级 c3829289
......@@ -126,7 +126,6 @@ class DataSetReader(object):
for k in ['date_captured', 'url', 'license', 'file_name']:
if k in img:
del img[k]
if is_train:
self._parse_gt_annotations(img)
......@@ -147,6 +146,7 @@ class DataSetReader(object):
shuffle=False,
shuffle_seed=None,
mixup_iter=0,
max_iter=0,
random_sizes=[],
image=None):
assert mode in ['train', 'test', 'infer'], "Unknow mode type!"
......@@ -248,6 +248,8 @@ class DataSetReader(object):
yield batch_out
batch_out = []
total_iter += 1
if total_iter >= max_iter:
return
img_size = get_img_size(size, random_sizes)
elif mode == 'test':
......@@ -296,10 +298,10 @@ def train(size=416,
random_sizes=[],
num_workers=8,
max_queue=32,
use_multiprocess_reader=True):
use_multiprocess_reader=True,
use_gpu=True):
generator = dsr.get_reader('train', size, batch_size, shuffle, shuffle_seed,
int(mixup_iter / num_workers), random_sizes)
int(mixup_iter / num_workers), total_iter, random_sizes)
if not use_multiprocess_reader:
return generator
else:
......@@ -317,30 +319,21 @@ def train(size=416,
def reader():
cnt = 0
try:
enqueuer = GeneratorEnqueuer(
infinite_reader(), use_multiprocessing=use_multiprocess_reader)
enqueuer.start(max_queue_size=max_queue, workers=num_workers)
generator_out = None
while True:
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_out = enqueuer.queue.get()
break
data_loader = fluid.io.DataLoader.from_generator(capacity=64,use_multiprocess=True,iterable=True)
if use_gpu:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0)
else:
time.sleep(0.02)
yield generator_out
place = fluid.CPUPlace()
data_loader.set_sample_list_generator(infinite_reader,places=place)
generator_out = []
for data in data_loader():
for i in data:
generator_out.append(i.numpy()[0])
yield [generator_out]
generator_out = []
cnt += 1
if cnt >= total_iter:
enqueuer.stop()
return
generator_out = None
except Exception as e:
print("Exception occured in reader: {}".format(str(e)))
finally:
if enqueuer:
enqueuer.stop()
return reader
......
......@@ -75,7 +75,8 @@ def train():
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0)
if not cfg.use_gpu:
palce = fluid.CPUPlace()
with fluid.dygraph.guard(place):
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
......@@ -141,8 +142,8 @@ def train():
mixup_iter=mixup_iter * devices_num,
random_sizes=random_sizes,
use_multiprocess_reader=cfg.use_multiprocess_reader,
num_workers=cfg.worker_num)
num_workers=cfg.worker_num,
use_gpu=cfg.use_gpu)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(train_reader)
smoothed_loss = SmoothedValue()
......@@ -150,7 +151,6 @@ def train():
for iter_id, data in enumerate(train_reader()):
prev_start_time = start_time
start_time = time.time()
img = np.array([x[0] for x in data]).astype('float32')
img = to_variable(img)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册