提交 83987d01 编写于 作者: W walloollaw 提交者: qingqing01

fix bug in sharedmemory (#3229)

* fix bug in sharedmemory
* fix bug in sharedmemory
上级 cd8b44c5
......@@ -90,18 +90,15 @@ def create_reader(feed, max_iter=0, args_path=None, my_source=None):
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
data_config = _prepare_data_config(feed, args_path)
bufsize = 10
use_process = False
if getattr(feed, 'bufsize', None) is not None:
bufsize = feed.bufsize
if getattr(feed, 'use_process', None) is not None:
use_process = feed.use_process
bufsize = getattr(feed, 'bufsize', 10)
use_process = getattr(feed, 'use_process', False)
memsize = getattr(feed, 'memsize', '3G')
transform_config = {
'WORKER_CONF': {
'bufsize': bufsize,
'worker_num': feed.num_workers,
'use_process': use_process
'use_process': use_process,
'memsize': memsize
},
'BATCH_SIZE': feed.batch_size,
'DROP_LAST': feed.drop_last,
......@@ -282,6 +279,10 @@ class DataFeed(object):
shuffle (bool): if samples should be shuffled
drop_last (bool): drop last batch if size is uneven
num_workers (int): number of workers processes (or threads)
bufsize (int): size of the queue used to buffer results from workers
use_process (bool): whether to use processes (True) or threads (False) as workers
memsize (str): size of the shared memory used by the worker queues
when 'use_process' is True; a string ending in 'G' or 'g' (e.g. '2G'), defaults to '3G'
"""
__category__ = 'data'
......@@ -299,6 +300,7 @@ class DataFeed(object):
num_workers=2,
bufsize=10,
use_process=False,
memsize=None,
use_padded_im_info=False):
super(DataFeed, self).__init__()
self.fields = fields
......@@ -313,6 +315,7 @@ class DataFeed(object):
self.num_workers = num_workers
self.bufsize = bufsize
self.use_process = use_process
self.memsize = memsize
self.dataset = dataset
self.use_padded_im_info = use_padded_im_info
if isinstance(dataset, dict):
......@@ -337,7 +340,8 @@ class TrainFeed(DataFeed):
with_background=True,
num_workers=2,
bufsize=10,
use_process=True):
use_process=True,
memsize=None):
super(TrainFeed, self).__init__(
dataset,
fields,
......@@ -351,7 +355,8 @@ class TrainFeed(DataFeed):
with_background=with_background,
num_workers=num_workers,
bufsize=bufsize,
use_process=use_process, )
use_process=use_process,
memsize=memsize)
@register
......@@ -439,8 +444,10 @@ class FasterRCNNTrainFeed(DataFeed):
shuffle=True,
samples=-1,
drop_last=False,
bufsize=10,
num_workers=2,
use_process=False):
use_process=False,
memsize=None):
# XXX this should be handled by the data loader, since `fields` is
# given, just collect them
sample_transforms.append(ArrangeRCNN())
......@@ -454,8 +461,10 @@ class FasterRCNNTrainFeed(DataFeed):
shuffle=shuffle,
samples=samples,
drop_last=drop_last,
bufsize=bufsize,
num_workers=num_workers,
use_process=use_process)
use_process=use_process,
memsize=memsize)
# XXX these modes should be unified
self.mode = 'TRAIN'
......@@ -722,7 +731,8 @@ class SSDTrainFeed(DataFeed):
drop_last=True,
num_workers=8,
bufsize=10,
use_process=True):
use_process=True,
memsize=None):
sample_transforms.append(ArrangeSSD())
super(SSDTrainFeed, self).__init__(
dataset,
......@@ -736,7 +746,8 @@ class SSDTrainFeed(DataFeed):
drop_last=drop_last,
num_workers=num_workers,
bufsize=bufsize,
use_process=use_process)
use_process=use_process,
memsize=None)
self.mode = 'TRAIN'
......@@ -767,7 +778,8 @@ class SSDEvalFeed(DataFeed):
drop_last=True,
num_workers=8,
bufsize=10,
use_process=False):
use_process=False,
memsize=None):
sample_transforms.append(ArrangeEvalSSD())
super(SSDEvalFeed, self).__init__(
dataset,
......@@ -781,7 +793,8 @@ class SSDEvalFeed(DataFeed):
drop_last=drop_last,
num_workers=num_workers,
bufsize=bufsize,
use_process=use_process)
use_process=use_process,
memsize=memsize)
self.mode = 'VAL'
......@@ -809,7 +822,8 @@ class SSDTestFeed(DataFeed):
drop_last=False,
num_workers=8,
bufsize=10,
use_process=False):
use_process=False,
memsize=None):
sample_transforms.append(ArrangeTestSSD())
if isinstance(dataset, dict):
dataset = SimpleDataSet(**dataset)
......@@ -825,7 +839,8 @@ class SSDTestFeed(DataFeed):
drop_last=drop_last,
num_workers=num_workers,
bufsize=bufsize,
use_process=use_process)
use_process=use_process,
memsize=memsize)
self.mode = 'TEST'
......@@ -873,6 +888,7 @@ class YoloTrainFeed(DataFeed):
num_workers=8,
bufsize=128,
use_process=True,
memsize=None,
num_max_boxes=50,
mixup_epoch=250):
sample_transforms.append(ArrangeYOLO())
......@@ -889,7 +905,8 @@ class YoloTrainFeed(DataFeed):
with_background=with_background,
num_workers=num_workers,
bufsize=bufsize,
use_process=use_process)
use_process=use_process,
memsize=memsize)
self.num_max_boxes = num_max_boxes
self.mixup_epoch = mixup_epoch
self.mode = 'TRAIN'
......@@ -923,7 +940,8 @@ class YoloEvalFeed(DataFeed):
with_background=False,
num_workers=8,
num_max_boxes=50,
use_process=False):
use_process=False,
memsize=None):
sample_transforms.append(ArrangeEvalYOLO())
super(YoloEvalFeed, self).__init__(
dataset,
......@@ -937,7 +955,8 @@ class YoloEvalFeed(DataFeed):
drop_last=drop_last,
with_background=with_background,
num_workers=num_workers,
use_process=use_process)
use_process=use_process,
memsize=memsize)
self.num_max_boxes = num_max_boxes
self.mode = 'VAL'
self.bufsize = 128
......@@ -976,7 +995,8 @@ class YoloTestFeed(DataFeed):
with_background=False,
num_workers=8,
num_max_boxes=50,
use_process=False):
use_process=False,
memsize=None):
sample_transforms.append(ArrangeTestYOLO())
if isinstance(dataset, dict):
dataset = SimpleDataSet(**dataset)
......@@ -992,7 +1012,8 @@ class YoloTestFeed(DataFeed):
drop_last=drop_last,
with_background=with_background,
num_workers=num_workers,
use_process=use_process)
use_process=use_process,
memsize=memsize)
self.mode = 'TEST'
self.bufsize = 128
......
......@@ -29,7 +29,11 @@ TRANSFORM:
BATCH_SIZE: 1
IS_PADDING: True
DROP_LAST: False
WORKER_CONF:
BUFSIZE: 100
WORKER_NUM: 4
USE_PROCESS: True
MEMSIZE: 2G
VAL:
OPS:
- OP: DecodeImage
......@@ -39,6 +43,6 @@ TRANSFORM:
- OP: ArrangeSSD
BATCH_SIZE: 1
WORKER_CONF:
BUFSIZE: 200
WORKER_NUM: 8
USE_PROCESS: False
BUFSIZE: 100
WORKER_NUM: 4
USE_PROCESS: True
......@@ -26,5 +26,7 @@ TRANSFORM:
IS_PADDING: True
DROP_LAST: False
WORKER_CONF:
BUFSIZE: 10
WORKER_NUM: 2
BUFSIZE: 100
WORKER_NUM: 4
MEMSIZE: 2G
USE_PROCESS: True
......@@ -49,8 +49,15 @@ class ParallelMappedDataset(ProxiedDataset):
super(ParallelMappedDataset, self).__init__(source)
worker_args = {k.lower(): v for k, v in worker_args.items()}
args = {'bufsize': 100, 'worker_num': 8}
args = {'bufsize': 100, 'worker_num': 8,
'use_process': False, 'memsize': '3G'}
args.update(worker_args)
if args['use_process'] and type(args['memsize']) is str:
assert args['memsize'][-1].lower() == 'g', \
"invalid param for memsize[%s], should be ended with 'G' or 'g'" % (args['memsize'])
gb = args['memsize'][:-1]
args['memsize'] = int(gb) * 1024 ** 3
self._worker_args = args
self._started = False
self._source = source
......@@ -60,9 +67,7 @@ class ParallelMappedDataset(ProxiedDataset):
def _setup(self):
"""setup input/output queues and workers """
use_process = False
if 'use_process' in self._worker_args:
use_process = self._worker_args['use_process']
use_process = self._worker_args.get('use_process', False)
if use_process and sys.platform == "win32":
logger.info("Use multi-thread reader instead of "
"multi-process reader on Windows.")
......@@ -73,6 +78,9 @@ class ParallelMappedDataset(ProxiedDataset):
from .shared_queue import SharedQueue as Queue
from multiprocessing import Process as Worker
from multiprocessing import Event
memsize = self._worker_args['memsize']
self._inq = Queue(bufsize, memsize=memsize)
self._outq = Queue(bufsize, memsize=memsize)
else:
if six.PY3:
from queue import Queue
......@@ -80,11 +88,10 @@ class ParallelMappedDataset(ProxiedDataset):
from Queue import Queue
from threading import Thread as Worker
from threading import Event
self._inq = Queue(bufsize)
self._outq = Queue(bufsize)
consumer_num = self._worker_args['worker_num']
consumer_num = self._worker_args['worker_num']
id = str(uuid.uuid4())[-3:]
self._producer = threading.Thread(
target=self._produce,
......
......@@ -318,16 +318,14 @@ class PageAllocator(object):
while True:
# maybe flags already has some '0' pages,
# so just check 'page_num - len(flags)' pages
flags += self.get_page_status(
pos, page_num - len(flags), ret_flag=True)
flags = self.get_page_status(
pos, page_num, ret_flag=True)
if flags.count('0') == page_num:
break
# not found enough pages, so shift to next few pages
free_pos = flags.rfind('1') + 1
flags = flags[free_pos:]
pos += free_pos
end = pos + page_num
if end > pages:
......@@ -355,9 +353,6 @@ class PageAllocator(object):
self.set_page_status(pos, page_num, '1')
used += page_num
self.set_alloc_info(end, used)
assert self.get_page_status(pos, page_num) == (page_num, 1), \
'faild to validate the page status'
return pos
def free_page(self, start, page_num):
......@@ -530,7 +525,7 @@ class SharedMemoryMgr(object):
logger.info('destroy [%s]' % (self))
if not self._released and not self._allocator.empty():
logger.warn('not empty when delete this SharedMemoryMgr[%s]' %
logger.debug('not empty when delete this SharedMemoryMgr[%s]' %
(self))
else:
self._released = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册