提交 3be84530 作者: dengkaipeng

add random_shape interval

上级 71d7986a
...@@ -6,11 +6,14 @@ import time ...@@ -6,11 +6,14 @@ import time
import numpy as np import numpy as np
import threading import threading
import multiprocessing import multiprocessing
import cv2
try: try:
import queue import queue
except ImportError: except ImportError:
import Queue as queue import Queue as queue
import image_utils
class GeneratorEnqueuer(object): class GeneratorEnqueuer(object):
""" """
...@@ -35,11 +38,11 @@ class GeneratorEnqueuer(object): ...@@ -35,11 +38,11 @@ class GeneratorEnqueuer(object):
self._use_multiprocessing = use_multiprocessing self._use_multiprocessing = use_multiprocessing
self._threads = [] self._threads = []
self._stop_event = None self._stop_event = None
self.queue = None self.queues = []
self._manager = None self._manager = None
self.seed = random_seed self.seed = random_seed
def start(self, workers=1, max_queue_size=10): def start(self, workers=1, max_queue_size=10, random_sizes=[608]):
""" """
Start worker threads which add data from the generator into the queue. Start worker threads which add data from the generator into the queue.
...@@ -49,16 +52,37 @@ class GeneratorEnqueuer(object): ...@@ -49,16 +52,37 @@ class GeneratorEnqueuer(object):
(when full, threads could block on `put()`) (when full, threads could block on `put()`)
""" """
self.random_sizes = random_sizes
self.size_num = len(random_sizes)
def data_generator_task(): def data_generator_task():
""" """
Data generator task. Data generator task.
""" """
def task(): def task():
if (self.queue is not None and if len(self.queues) > 0:
self.queue.qsize() < max_queue_size):
generator_output = next(self._generator) generator_output = next(self._generator)
self.queue.put((generator_output)) queue_idx = 0
while(True):
if self.queues[queue_idx].full():
queue_idx = (queue_idx + 1) % self.size_num
continue
else:
size = self.random_sizes[queue_idx]
for g in generator_output:
g[0] = g[0].transpose((1, 2, 0))
g[0] = image_utils.random_interp(g[0], size, cv2.INTER_LINEAR)
g[0] = g[0].transpose((2, 0, 1))
try:
self.queues[queue_idx].put_nowait(generator_output)
except:
continue
else:
break
# self.queue.put((generator_output))
else: else:
time.sleep(self.wait_time) time.sleep(self.wait_time)
...@@ -81,11 +105,14 @@ class GeneratorEnqueuer(object): ...@@ -81,11 +105,14 @@ class GeneratorEnqueuer(object):
try: try:
if self._use_multiprocessing: if self._use_multiprocessing:
self._manager = multiprocessing.Manager() self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size) for i in range(self.size_num):
self.queues.append(self._manager.Queue(maxsize=max_queue_size))
self._stop_event = multiprocessing.Event() self._stop_event = multiprocessing.Event()
else: else:
self.genlock = threading.Lock() self.genlock = threading.Lock()
self.queue = queue.Queue() # self.queue = queue.Queue()
for i in range(self.size_num):
self.queues.append(queue.Queue())
self._stop_event = threading.Event() self._stop_event = threading.Event()
for _ in range(workers): for _ in range(workers):
if self._use_multiprocessing: if self._use_multiprocessing:
...@@ -134,7 +161,7 @@ class GeneratorEnqueuer(object): ...@@ -134,7 +161,7 @@ class GeneratorEnqueuer(object):
self._stop_event = None self._stop_event = None
self.queue = None self.queue = None
def get(self): def get(self, queue_idx):
""" """
Creates a generator to extract data from the queue. Creates a generator to extract data from the queue.
Skip the data if it is `None`. Skip the data if it is `None`.
...@@ -143,8 +170,8 @@ class GeneratorEnqueuer(object): ...@@ -143,8 +170,8 @@ class GeneratorEnqueuer(object):
tuple of data in the queue. tuple of data in the queue.
""" """
while self.is_running(): while self.is_running():
if not self.queue.empty(): if not self.queues[queue_idx].empty():
inputs = self.queue.get() inputs = self.queues[queue_idx].get()
if inputs is not None: if inputs is not None:
yield inputs yield inputs
else: else:
......
...@@ -209,7 +209,7 @@ def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None): ...@@ -209,7 +209,7 @@ def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None):
img = random_distort(img) img = random_distort(img)
img, gtboxes = random_expand(img, gtboxes, fill=means) img, gtboxes = random_expand(img, gtboxes, fill=means)
img, gtboxes, gtlabels, gtscores = random_crop(img, gtboxes, gtlabels, gtscores) img, gtboxes, gtlabels, gtscores = random_crop(img, gtboxes, gtlabels, gtscores)
img = random_interp(img, size, cv2.INTER_LINEAR) # img = random_interp(img, size, cv2.INTER_LINEAR)
img, gtboxes = random_flip(img, gtboxes) img, gtboxes = random_flip(img, gtboxes)
gtboxes, gtlabels, gtscores = shuffle_gtbox(gtboxes, gtlabels, gtscores) gtboxes, gtlabels, gtscores = shuffle_gtbox(gtboxes, gtlabels, gtscores)
......
...@@ -41,7 +41,7 @@ class DataSetReader(object): ...@@ -41,7 +41,7 @@ class DataSetReader(object):
# cfg.data_dir = "dataset/coco" # cfg.data_dir = "dataset/coco"
# cfg.train_file_list = 'annotations/instances_val2017.json' # cfg.train_file_list = 'annotations/instances_val2017.json'
# cfg.train_data_dir = 'val2017' # cfg.train_data_dir = 'val2017'
# cfg.dataset = "coco2017" cfg.dataset = "coco2017"
if 'coco2014' in cfg.dataset: if 'coco2014' in cfg.dataset:
cfg.train_file_list = 'annotations/instances_train2014.json' cfg.train_file_list = 'annotations/instances_train2014.json'
cfg.train_data_dir = 'train2014' cfg.train_data_dir = 'train2014'
...@@ -224,14 +224,15 @@ class DataSetReader(object): ...@@ -224,14 +224,15 @@ class DataSetReader(object):
if read_cnt % len(imgs) == 0 and shuffle: if read_cnt % len(imgs) == 0 and shuffle:
np.random.shuffle(imgs) np.random.shuffle(imgs)
im, gt_boxes, gt_labels, gt_scores = img_reader_with_augment(img, img_size, cfg.pixel_means, cfg.pixel_stds, mixup_img) im, gt_boxes, gt_labels, gt_scores = img_reader_with_augment(img, img_size, cfg.pixel_means, cfg.pixel_stds, mixup_img)
batch_out.append((im, gt_boxes, gt_labels, gt_scores)) batch_out.append([im, gt_boxes, gt_labels, gt_scores])
# img_ids.append((img['id'], mixup_img['id'] if mixup_img else -1)) # img_ids.append((img['id'], mixup_img['id'] if mixup_img else -1))
if len(batch_out) == batch_size: if len(batch_out) == batch_size:
# print("img_ids: ", img_ids) # print("img_ids: ", img_ids)
yield batch_out yield batch_out
batch_out = [] batch_out = []
img_size = get_img_size(size, random_sizes) if total_read_cnt % 10 == 0:
img_size = get_img_size(size, random_sizes)
# img_ids = [] # img_ids = []
elif mode == 'test': elif mode == 'test':
...@@ -263,6 +264,8 @@ def train(size=416, ...@@ -263,6 +264,8 @@ def train(size=416,
shuffle=True, shuffle=True,
mixup_iter=0, mixup_iter=0,
random_sizes=[], random_sizes=[],
interval=10,
pyreader_num=1,
use_multiprocessing=True, use_multiprocessing=True,
num_workers=8, num_workers=8,
max_queue=24): max_queue=24):
...@@ -280,17 +283,25 @@ def train(size=416, ...@@ -280,17 +283,25 @@ def train(size=416,
try: try:
enqueuer = GeneratorEnqueuer( enqueuer = GeneratorEnqueuer(
infinite_reader(), use_multiprocessing=use_multiprocessing) infinite_reader(), use_multiprocessing=use_multiprocessing)
enqueuer.start(max_queue_size=max_queue, workers=num_workers) enqueuer.start(max_queue_size=max_queue, workers=num_workers, random_sizes=random_sizes)
generator_out = None generator_out = None
np.random.seed(1000)
intervals = pyreader_num * interval
cnt = 0
idx = np.random.randint(len(random_sizes))
while True: while True:
while enqueuer.is_running(): while enqueuer.is_running():
if not enqueuer.queue.empty(): if not enqueuer.queues[idx].empty():
generator_out = enqueuer.queue.get() generator_out = enqueuer.queues[idx].get()
break break
else: else:
print(idx," empty")
time.sleep(0.02) time.sleep(0.02)
yield generator_out yield generator_out
generator_out = None generator_out = None
cnt += 1
if cnt % intervals == 0:
idx = np.random.randint(len(random_sizes))
finally: finally:
if enqueuer is not None: if enqueuer is not None:
enqueuer.stop() enqueuer.stop()
......
...@@ -96,11 +96,11 @@ def train(): ...@@ -96,11 +96,11 @@ def train():
mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter
if cfg.use_pyreader: if cfg.use_pyreader:
train_reader = reader.train(input_size, batch_size=int(hyperparams['batch'])/devices_num, shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess) train_reader = reader.train(input_size, batch_size=int(hyperparams['batch'])/devices_num, shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes, interval=10, pyreader_num=devices, use_multiprocessing=cfg.use_multiprocess)
py_reader = model.py_reader py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader) py_reader.decorate_paddle_reader(train_reader)
else: else:
train_reader = reader.train(input_size, batch_size=int(hyperparams['batch']), shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess) train_reader = reader.train(input_size, batch_size=int(hyperparams['batch']), shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes, interval=10, use_multiprocessing=cfg.use_multiprocess)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds()) feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
def save_model(postfix): def save_model(postfix):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册