提交 f49ec1d7 编写于 作者: K Kaipeng Deng 提交者: qingqing01

Fix process exit error and shuffle gt error (#1911)

* fix shuffle gt bug.
* fix process exit bug.
上级 4dd4632b
...@@ -124,7 +124,7 @@ class GeneratorEnqueuer(object): ...@@ -124,7 +124,7 @@ class GeneratorEnqueuer(object):
for thread in self._threads: for thread in self._threads:
if self._use_multiprocessing: if self._use_multiprocessing:
if thread.is_alive(): if thread.is_alive():
thread.terminate() thread.join(timeout)
else: else:
thread.join(timeout) thread.join(timeout)
if self._manager: if self._manager:
......
...@@ -154,7 +154,7 @@ def random_expand(img, gtboxes, max_ratio=4., fill=None, keep_ratio=True, thresh ...@@ -154,7 +154,7 @@ def random_expand(img, gtboxes, max_ratio=4., fill=None, keep_ratio=True, thresh
def shuffle_gtbox(gtbox, gtlabel, gtscore): def shuffle_gtbox(gtbox, gtlabel, gtscore):
gt = np.concatenate([gtbox, gtlabel[:, np.newaxis], gtscore[:, np.newaxis]], axis=1) gt = np.concatenate([gtbox, gtlabel[:, np.newaxis], gtscore[:, np.newaxis]], axis=1)
idx = np.arange(gt.shape[1]) idx = np.arange(gt.shape[0])
np.random.shuffle(idx) np.random.shuffle(idx)
gt = gt[idx, :] gt = gt[idx, :]
return gt[:, :4], gt[:, 4], gt[:, 5] return gt[:, :4], gt[:, 4], gt[:, 5]
......
...@@ -134,7 +134,6 @@ class YOLOv3(object): ...@@ -134,7 +134,6 @@ class YOLOv3(object):
anchor_mask = cfg.anchor_masks[i] anchor_mask = cfg.anchor_masks[i]
if self.is_train: if self.is_train:
ignore_thresh = float(self.ignore_thresh)
loss = fluid.layers.yolov3_loss( loss = fluid.layers.yolov3_loss(
x=out, x=out,
gtbox=self.gtbox, gtbox=self.gtbox,
......
...@@ -250,10 +250,11 @@ dsr = DataSetReader() ...@@ -250,10 +250,11 @@ dsr = DataSetReader()
def train(size=416, def train(size=416,
batch_size=64, batch_size=64,
shuffle=True, shuffle=True,
total_iter=0,
mixup_iter=0, mixup_iter=0,
random_sizes=[], random_sizes=[],
num_workers=12, num_workers=8,
max_queue=36, max_queue=32,
use_multiprocessing=True): use_multiprocessing=True):
generator = dsr.get_reader('train', size, batch_size, shuffle, int(mixup_iter/num_workers), random_sizes) generator = dsr.get_reader('train', size, batch_size, shuffle, int(mixup_iter/num_workers), random_sizes)
...@@ -266,6 +267,7 @@ def train(size=416, ...@@ -266,6 +267,7 @@ def train(size=416,
yield data yield data
def reader(): def reader():
cnt = 0
try: try:
enqueuer = GeneratorEnqueuer( enqueuer = GeneratorEnqueuer(
infinite_reader(), use_multiprocessing=use_multiprocessing) infinite_reader(), use_multiprocessing=use_multiprocessing)
...@@ -279,6 +281,10 @@ def train(size=416, ...@@ -279,6 +281,10 @@ def train(size=416,
else: else:
time.sleep(0.02) time.sleep(0.02)
yield generator_out yield generator_out
cnt += 1
if cnt >= total_iter:
enqueuer.stop()
return
generator_out = None generator_out = None
finally: finally:
if enqueuer is not None: if enqueuer is not None:
......
...@@ -86,8 +86,9 @@ def train(): ...@@ -86,8 +86,9 @@ def train():
if cfg.random_shape: if cfg.random_shape:
random_sizes = [32 * i for i in range(10, 20)] random_sizes = [32 * i for i in range(10, 20)]
mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter total_iter = cfg.max_iter - cfg.start_iter
train_reader = reader.train(input_size, batch_size=cfg.batch_size, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess) mixup_iter = total_iter - cfg.no_mixup_iter
train_reader = reader.train(input_size, batch_size=cfg.batch_size, shuffle=True, total_iter=total_iter*devices_num, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess)
py_reader = model.py_reader py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader) py_reader.decorate_paddle_reader(train_reader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册