未验证 提交 fff5cc27 编写于 作者: X xiaoting 提交者: GitHub

change reader to DataLoader (#4293)

* fix load for yolo_dy
上级 c3829289
...@@ -126,7 +126,6 @@ class DataSetReader(object): ...@@ -126,7 +126,6 @@ class DataSetReader(object):
for k in ['date_captured', 'url', 'license', 'file_name']: for k in ['date_captured', 'url', 'license', 'file_name']:
if k in img: if k in img:
del img[k] del img[k]
if is_train: if is_train:
self._parse_gt_annotations(img) self._parse_gt_annotations(img)
...@@ -147,6 +146,7 @@ class DataSetReader(object): ...@@ -147,6 +146,7 @@ class DataSetReader(object):
shuffle=False, shuffle=False,
shuffle_seed=None, shuffle_seed=None,
mixup_iter=0, mixup_iter=0,
max_iter=0,
random_sizes=[], random_sizes=[],
image=None): image=None):
assert mode in ['train', 'test', 'infer'], "Unknow mode type!" assert mode in ['train', 'test', 'infer'], "Unknow mode type!"
...@@ -248,6 +248,8 @@ class DataSetReader(object): ...@@ -248,6 +248,8 @@ class DataSetReader(object):
yield batch_out yield batch_out
batch_out = [] batch_out = []
total_iter += 1 total_iter += 1
if total_iter >= max_iter:
return
img_size = get_img_size(size, random_sizes) img_size = get_img_size(size, random_sizes)
elif mode == 'test': elif mode == 'test':
...@@ -296,10 +298,10 @@ def train(size=416, ...@@ -296,10 +298,10 @@ def train(size=416,
random_sizes=[], random_sizes=[],
num_workers=8, num_workers=8,
max_queue=32, max_queue=32,
use_multiprocess_reader=True): use_multiprocess_reader=True,
use_gpu=True):
generator = dsr.get_reader('train', size, batch_size, shuffle, shuffle_seed, generator = dsr.get_reader('train', size, batch_size, shuffle, shuffle_seed,
int(mixup_iter / num_workers), random_sizes) int(mixup_iter / num_workers), total_iter, random_sizes)
if not use_multiprocess_reader: if not use_multiprocess_reader:
return generator return generator
else: else:
...@@ -317,30 +319,21 @@ def train(size=416, ...@@ -317,30 +319,21 @@ def train(size=416,
def reader(): def reader():
cnt = 0 cnt = 0
try: data_loader = fluid.io.DataLoader.from_generator(capacity=64,use_multiprocess=True,iterable=True)
enqueuer = GeneratorEnqueuer( if use_gpu:
infinite_reader(), use_multiprocessing=use_multiprocess_reader) place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0)
enqueuer.start(max_queue_size=max_queue, workers=num_workers) else:
generator_out = None place = fluid.CPUPlace()
while True: data_loader.set_sample_list_generator(infinite_reader,places=place)
while enqueuer.is_running(): generator_out = []
if not enqueuer.queue.empty(): for data in data_loader():
generator_out = enqueuer.queue.get() for i in data:
break generator_out.append(i.numpy()[0])
else: yield [generator_out]
time.sleep(0.02) generator_out = []
yield generator_out cnt += 1
cnt += 1 if cnt >= total_iter:
if cnt >= total_iter: return
enqueuer.stop()
return
generator_out = None
except Exception as e:
print("Exception occured in reader: {}".format(str(e)))
finally:
if enqueuer:
enqueuer.stop()
return reader return reader
......
...@@ -75,7 +75,8 @@ def train(): ...@@ -75,7 +75,8 @@ def train():
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0) place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0)
if not cfg.use_gpu:
place = fluid.CPUPlace()
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
if args.use_data_parallel: if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context() strategy = fluid.dygraph.parallel.prepare_context()
...@@ -141,8 +142,8 @@ def train(): ...@@ -141,8 +142,8 @@ def train():
mixup_iter=mixup_iter * devices_num, mixup_iter=mixup_iter * devices_num,
random_sizes=random_sizes, random_sizes=random_sizes,
use_multiprocess_reader=cfg.use_multiprocess_reader, use_multiprocess_reader=cfg.use_multiprocess_reader,
num_workers=cfg.worker_num) num_workers=cfg.worker_num,
use_gpu=cfg.use_gpu)
if args.use_data_parallel: if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(train_reader) train_reader = fluid.contrib.reader.distributed_batch_reader(train_reader)
smoothed_loss = SmoothedValue() smoothed_loss = SmoothedValue()
...@@ -150,7 +151,6 @@ def train(): ...@@ -150,7 +151,6 @@ def train():
for iter_id, data in enumerate(train_reader()): for iter_id, data in enumerate(train_reader()):
prev_start_time = start_time prev_start_time = start_time
start_time = time.time() start_time = time.time()
img = np.array([x[0] for x in data]).astype('float32') img = np.array([x[0] for x in data]).astype('float32')
img = to_variable(img) img = to_variable(img)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册