From fff5cc27d1a17f5e39c11b6e5836674aea168806 Mon Sep 17 00:00:00 2001
From: xiaoting <31891223+tink2123@users.noreply.github.com>
Date: Mon, 17 Feb 2020 11:52:55 +0800
Subject: [PATCH] change reader to DataLoader (#4293)

* fix load for yolo_dy
---
Review note: in dygraph/yolov3/train.py the CPU fallback was assigned to a
misspelled name (`palce`), so `fluid.dygraph.guard(place)` still received the
CUDA place when `cfg.use_gpu` was false. The hunk below assigns to `place`.
Line breaks and indentation were restored from a whitespace-collapsed copy of
this patch; the `index` hashes are left exactly as they appeared there.

 dygraph/yolov3/reader.py | 49 +++++++++++++++++-----------------------
 dygraph/yolov3/train.py  |  8 +++----
 2 files changed, 25 insertions(+), 32 deletions(-)

diff --git a/dygraph/yolov3/reader.py b/dygraph/yolov3/reader.py
index 92a7ac1a..c74a938e 100644
--- a/dygraph/yolov3/reader.py
+++ b/dygraph/yolov3/reader.py
@@ -126,7 +126,6 @@ class DataSetReader(object):
             for k in ['date_captured', 'url', 'license', 'file_name']:
                 if k in img:
                     del img[k]
-
             if is_train:
                 self._parse_gt_annotations(img)
 
@@ -147,6 +146,7 @@ class DataSetReader(object):
                    shuffle=False,
                    shuffle_seed=None,
                    mixup_iter=0,
+                   max_iter=0,
                    random_sizes=[],
                    image=None):
         assert mode in ['train', 'test', 'infer'], "Unknow mode type!"
@@ -248,6 +248,8 @@ class DataSetReader(object):
                         yield batch_out
                         batch_out = []
                         total_iter += 1
+                        if total_iter >= max_iter:
+                            return
                         img_size = get_img_size(size, random_sizes)
 
             elif mode == 'test':
@@ -296,10 +298,10 @@ def train(size=416,
           random_sizes=[],
           num_workers=8,
           max_queue=32,
-          use_multiprocess_reader=True):
+          use_multiprocess_reader=True,
+          use_gpu=True):
     generator = dsr.get_reader('train', size, batch_size, shuffle, shuffle_seed,
-                               int(mixup_iter / num_workers), random_sizes)
-
+                               int(mixup_iter / num_workers), total_iter, random_sizes)
     if not use_multiprocess_reader:
         return generator
     else:
@@ -317,30 +319,21 @@ def train(size=416,
 
         def reader():
             cnt = 0
-            try:
-                enqueuer = GeneratorEnqueuer(
-                    infinite_reader(), use_multiprocessing=use_multiprocess_reader)
-                enqueuer.start(max_queue_size=max_queue, workers=num_workers)
-                generator_out = None
-                while True:
-                    while enqueuer.is_running():
-                        if not enqueuer.queue.empty():
-                            generator_out = enqueuer.queue.get()
-                            break
-                        else:
-                            time.sleep(0.02)
-                    yield generator_out
-                    cnt += 1
-                    if cnt >= total_iter:
-                        enqueuer.stop()
-                        return
-                    generator_out = None
-            except Exception as e:
-                print("Exception occured in reader: {}".format(str(e)))
-            finally:
-                if enqueuer:
-                    enqueuer.stop()
-
+            data_loader = fluid.io.DataLoader.from_generator(capacity=64,use_multiprocess=True,iterable=True)
+            if use_gpu:
+                place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0)
+            else:
+                place = fluid.CPUPlace()
+            data_loader.set_sample_list_generator(infinite_reader,places=place)
+            generator_out = []
+            for data in data_loader():
+                for i in data:
+                    generator_out.append(i.numpy()[0])
+                yield [generator_out]
+                generator_out = []
+                cnt += 1
+                if cnt >= total_iter:
+                    return
         return reader
 
 
diff --git a/dygraph/yolov3/train.py b/dygraph/yolov3/train.py
index 7c1548e6..7c742fd3 100755
--- a/dygraph/yolov3/train.py
+++ b/dygraph/yolov3/train.py
@@ -75,7 +75,8 @@ def train():
 
     gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
     place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0)
-
+    if not cfg.use_gpu:
+        place = fluid.CPUPlace()
     with fluid.dygraph.guard(place):
         if args.use_data_parallel:
             strategy = fluid.dygraph.parallel.prepare_context()
@@ -141,8 +142,8 @@ def train():
             mixup_iter=mixup_iter * devices_num,
             random_sizes=random_sizes,
             use_multiprocess_reader=cfg.use_multiprocess_reader,
-            num_workers=cfg.worker_num)
-
+            num_workers=cfg.worker_num,
+            use_gpu=cfg.use_gpu)
         if args.use_data_parallel:
             train_reader = fluid.contrib.reader.distributed_batch_reader(train_reader)
         smoothed_loss = SmoothedValue()
@@ -150,7 +151,6 @@ def train():
         for iter_id, data in enumerate(train_reader()):
             prev_start_time = start_time
             start_time = time.time()
-
             img = np.array([x[0] for x in data]).astype('float32')
             img = to_variable(img)
 
-- 
GitLab