提交 83987d01 编写于 作者: W walloollaw 提交者: qingqing01

fix bug in sharedmemory (#3229)

* fix bug in sharedmemory
* fix bug in sharedmemory
上级 cd8b44c5
...@@ -90,18 +90,15 @@ def create_reader(feed, max_iter=0, args_path=None, my_source=None): ...@@ -90,18 +90,15 @@ def create_reader(feed, max_iter=0, args_path=None, my_source=None):
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download # named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
data_config = _prepare_data_config(feed, args_path) data_config = _prepare_data_config(feed, args_path)
bufsize = 10 bufsize = getattr(feed, 'bufsize', 10)
use_process = False use_process = getattr(feed, 'use_process', False)
if getattr(feed, 'bufsize', None) is not None: memsize = getattr(feed, 'memsize', '3G')
bufsize = feed.bufsize
if getattr(feed, 'use_process', None) is not None:
use_process = feed.use_process
transform_config = { transform_config = {
'WORKER_CONF': { 'WORKER_CONF': {
'bufsize': bufsize, 'bufsize': bufsize,
'worker_num': feed.num_workers, 'worker_num': feed.num_workers,
'use_process': use_process 'use_process': use_process,
'memsize': memsize
}, },
'BATCH_SIZE': feed.batch_size, 'BATCH_SIZE': feed.batch_size,
'DROP_LAST': feed.drop_last, 'DROP_LAST': feed.drop_last,
...@@ -282,6 +279,10 @@ class DataFeed(object): ...@@ -282,6 +279,10 @@ class DataFeed(object):
shuffle (bool): if samples should be shuffled shuffle (bool): if samples should be shuffled
drop_last (bool): drop last batch if size is uneven drop_last (bool): drop last batch if size is uneven
num_workers (int): number of workers processes (or threads) num_workers (int): number of workers processes (or threads)
bufsize (int): size of queue used to buffer results from workers
use_process (bool): use process or thread as workers
memsize (str): size of shared memory used in result queue
when 'use_process' is True, default to '3G'
""" """
__category__ = 'data' __category__ = 'data'
...@@ -299,6 +300,7 @@ class DataFeed(object): ...@@ -299,6 +300,7 @@ class DataFeed(object):
num_workers=2, num_workers=2,
bufsize=10, bufsize=10,
use_process=False, use_process=False,
memsize=None,
use_padded_im_info=False): use_padded_im_info=False):
super(DataFeed, self).__init__() super(DataFeed, self).__init__()
self.fields = fields self.fields = fields
...@@ -313,6 +315,7 @@ class DataFeed(object): ...@@ -313,6 +315,7 @@ class DataFeed(object):
self.num_workers = num_workers self.num_workers = num_workers
self.bufsize = bufsize self.bufsize = bufsize
self.use_process = use_process self.use_process = use_process
self.memsize = memsize
self.dataset = dataset self.dataset = dataset
self.use_padded_im_info = use_padded_im_info self.use_padded_im_info = use_padded_im_info
if isinstance(dataset, dict): if isinstance(dataset, dict):
...@@ -337,7 +340,8 @@ class TrainFeed(DataFeed): ...@@ -337,7 +340,8 @@ class TrainFeed(DataFeed):
with_background=True, with_background=True,
num_workers=2, num_workers=2,
bufsize=10, bufsize=10,
use_process=True): use_process=True,
memsize=None):
super(TrainFeed, self).__init__( super(TrainFeed, self).__init__(
dataset, dataset,
fields, fields,
...@@ -351,7 +355,8 @@ class TrainFeed(DataFeed): ...@@ -351,7 +355,8 @@ class TrainFeed(DataFeed):
with_background=with_background, with_background=with_background,
num_workers=num_workers, num_workers=num_workers,
bufsize=bufsize, bufsize=bufsize,
use_process=use_process, ) use_process=use_process,
memsize=memsize)
@register @register
...@@ -439,8 +444,10 @@ class FasterRCNNTrainFeed(DataFeed): ...@@ -439,8 +444,10 @@ class FasterRCNNTrainFeed(DataFeed):
shuffle=True, shuffle=True,
samples=-1, samples=-1,
drop_last=False, drop_last=False,
bufsize=10,
num_workers=2, num_workers=2,
use_process=False): use_process=False,
memsize=None):
# XXX this should be handled by the data loader, since `fields` is # XXX this should be handled by the data loader, since `fields` is
# given, just collect them # given, just collect them
sample_transforms.append(ArrangeRCNN()) sample_transforms.append(ArrangeRCNN())
...@@ -454,8 +461,10 @@ class FasterRCNNTrainFeed(DataFeed): ...@@ -454,8 +461,10 @@ class FasterRCNNTrainFeed(DataFeed):
shuffle=shuffle, shuffle=shuffle,
samples=samples, samples=samples,
drop_last=drop_last, drop_last=drop_last,
bufsize=bufsize,
num_workers=num_workers, num_workers=num_workers,
use_process=use_process) use_process=use_process,
memsize=memsize)
# XXX these modes should be unified # XXX these modes should be unified
self.mode = 'TRAIN' self.mode = 'TRAIN'
...@@ -722,7 +731,8 @@ class SSDTrainFeed(DataFeed): ...@@ -722,7 +731,8 @@ class SSDTrainFeed(DataFeed):
drop_last=True, drop_last=True,
num_workers=8, num_workers=8,
bufsize=10, bufsize=10,
use_process=True): use_process=True,
memsize=None):
sample_transforms.append(ArrangeSSD()) sample_transforms.append(ArrangeSSD())
super(SSDTrainFeed, self).__init__( super(SSDTrainFeed, self).__init__(
dataset, dataset,
...@@ -736,7 +746,8 @@ class SSDTrainFeed(DataFeed): ...@@ -736,7 +746,8 @@ class SSDTrainFeed(DataFeed):
drop_last=drop_last, drop_last=drop_last,
num_workers=num_workers, num_workers=num_workers,
bufsize=bufsize, bufsize=bufsize,
use_process=use_process) use_process=use_process,
memsize=None)
self.mode = 'TRAIN' self.mode = 'TRAIN'
...@@ -767,7 +778,8 @@ class SSDEvalFeed(DataFeed): ...@@ -767,7 +778,8 @@ class SSDEvalFeed(DataFeed):
drop_last=True, drop_last=True,
num_workers=8, num_workers=8,
bufsize=10, bufsize=10,
use_process=False): use_process=False,
memsize=None):
sample_transforms.append(ArrangeEvalSSD()) sample_transforms.append(ArrangeEvalSSD())
super(SSDEvalFeed, self).__init__( super(SSDEvalFeed, self).__init__(
dataset, dataset,
...@@ -781,7 +793,8 @@ class SSDEvalFeed(DataFeed): ...@@ -781,7 +793,8 @@ class SSDEvalFeed(DataFeed):
drop_last=drop_last, drop_last=drop_last,
num_workers=num_workers, num_workers=num_workers,
bufsize=bufsize, bufsize=bufsize,
use_process=use_process) use_process=use_process,
memsize=memsize)
self.mode = 'VAL' self.mode = 'VAL'
...@@ -809,7 +822,8 @@ class SSDTestFeed(DataFeed): ...@@ -809,7 +822,8 @@ class SSDTestFeed(DataFeed):
drop_last=False, drop_last=False,
num_workers=8, num_workers=8,
bufsize=10, bufsize=10,
use_process=False): use_process=False,
memsize=None):
sample_transforms.append(ArrangeTestSSD()) sample_transforms.append(ArrangeTestSSD())
if isinstance(dataset, dict): if isinstance(dataset, dict):
dataset = SimpleDataSet(**dataset) dataset = SimpleDataSet(**dataset)
...@@ -825,7 +839,8 @@ class SSDTestFeed(DataFeed): ...@@ -825,7 +839,8 @@ class SSDTestFeed(DataFeed):
drop_last=drop_last, drop_last=drop_last,
num_workers=num_workers, num_workers=num_workers,
bufsize=bufsize, bufsize=bufsize,
use_process=use_process) use_process=use_process,
memsize=memsize)
self.mode = 'TEST' self.mode = 'TEST'
...@@ -873,6 +888,7 @@ class YoloTrainFeed(DataFeed): ...@@ -873,6 +888,7 @@ class YoloTrainFeed(DataFeed):
num_workers=8, num_workers=8,
bufsize=128, bufsize=128,
use_process=True, use_process=True,
memsize=None,
num_max_boxes=50, num_max_boxes=50,
mixup_epoch=250): mixup_epoch=250):
sample_transforms.append(ArrangeYOLO()) sample_transforms.append(ArrangeYOLO())
...@@ -889,7 +905,8 @@ class YoloTrainFeed(DataFeed): ...@@ -889,7 +905,8 @@ class YoloTrainFeed(DataFeed):
with_background=with_background, with_background=with_background,
num_workers=num_workers, num_workers=num_workers,
bufsize=bufsize, bufsize=bufsize,
use_process=use_process) use_process=use_process,
memsize=memsize)
self.num_max_boxes = num_max_boxes self.num_max_boxes = num_max_boxes
self.mixup_epoch = mixup_epoch self.mixup_epoch = mixup_epoch
self.mode = 'TRAIN' self.mode = 'TRAIN'
...@@ -923,7 +940,8 @@ class YoloEvalFeed(DataFeed): ...@@ -923,7 +940,8 @@ class YoloEvalFeed(DataFeed):
with_background=False, with_background=False,
num_workers=8, num_workers=8,
num_max_boxes=50, num_max_boxes=50,
use_process=False): use_process=False,
memsize=None):
sample_transforms.append(ArrangeEvalYOLO()) sample_transforms.append(ArrangeEvalYOLO())
super(YoloEvalFeed, self).__init__( super(YoloEvalFeed, self).__init__(
dataset, dataset,
...@@ -937,7 +955,8 @@ class YoloEvalFeed(DataFeed): ...@@ -937,7 +955,8 @@ class YoloEvalFeed(DataFeed):
drop_last=drop_last, drop_last=drop_last,
with_background=with_background, with_background=with_background,
num_workers=num_workers, num_workers=num_workers,
use_process=use_process) use_process=use_process,
memsize=memsize)
self.num_max_boxes = num_max_boxes self.num_max_boxes = num_max_boxes
self.mode = 'VAL' self.mode = 'VAL'
self.bufsize = 128 self.bufsize = 128
...@@ -976,7 +995,8 @@ class YoloTestFeed(DataFeed): ...@@ -976,7 +995,8 @@ class YoloTestFeed(DataFeed):
with_background=False, with_background=False,
num_workers=8, num_workers=8,
num_max_boxes=50, num_max_boxes=50,
use_process=False): use_process=False,
memsize=None):
sample_transforms.append(ArrangeTestYOLO()) sample_transforms.append(ArrangeTestYOLO())
if isinstance(dataset, dict): if isinstance(dataset, dict):
dataset = SimpleDataSet(**dataset) dataset = SimpleDataSet(**dataset)
...@@ -992,7 +1012,8 @@ class YoloTestFeed(DataFeed): ...@@ -992,7 +1012,8 @@ class YoloTestFeed(DataFeed):
drop_last=drop_last, drop_last=drop_last,
with_background=with_background, with_background=with_background,
num_workers=num_workers, num_workers=num_workers,
use_process=use_process) use_process=use_process,
memsize=memsize)
self.mode = 'TEST' self.mode = 'TEST'
self.bufsize = 128 self.bufsize = 128
......
...@@ -29,7 +29,11 @@ TRANSFORM: ...@@ -29,7 +29,11 @@ TRANSFORM:
BATCH_SIZE: 1 BATCH_SIZE: 1
IS_PADDING: True IS_PADDING: True
DROP_LAST: False DROP_LAST: False
WORKER_CONF:
BUFSIZE: 100
WORKER_NUM: 4
USE_PROCESS: True
MEMSIZE: 2G
VAL: VAL:
OPS: OPS:
- OP: DecodeImage - OP: DecodeImage
...@@ -39,6 +43,6 @@ TRANSFORM: ...@@ -39,6 +43,6 @@ TRANSFORM:
- OP: ArrangeSSD - OP: ArrangeSSD
BATCH_SIZE: 1 BATCH_SIZE: 1
WORKER_CONF: WORKER_CONF:
BUFSIZE: 200 BUFSIZE: 100
WORKER_NUM: 8 WORKER_NUM: 4
USE_PROCESS: False USE_PROCESS: True
...@@ -26,5 +26,7 @@ TRANSFORM: ...@@ -26,5 +26,7 @@ TRANSFORM:
IS_PADDING: True IS_PADDING: True
DROP_LAST: False DROP_LAST: False
WORKER_CONF: WORKER_CONF:
BUFSIZE: 10 BUFSIZE: 100
WORKER_NUM: 2 WORKER_NUM: 4
MEMSIZE: 2G
USE_PROCESS: True
...@@ -49,8 +49,15 @@ class ParallelMappedDataset(ProxiedDataset): ...@@ -49,8 +49,15 @@ class ParallelMappedDataset(ProxiedDataset):
super(ParallelMappedDataset, self).__init__(source) super(ParallelMappedDataset, self).__init__(source)
worker_args = {k.lower(): v for k, v in worker_args.items()} worker_args = {k.lower(): v for k, v in worker_args.items()}
args = {'bufsize': 100, 'worker_num': 8} args = {'bufsize': 100, 'worker_num': 8,
'use_process': False, 'memsize': '3G'}
args.update(worker_args) args.update(worker_args)
if args['use_process'] and type(args['memsize']) is str:
assert args['memsize'][-1].lower() == 'g', \
"invalid param for memsize[%s], should be ended with 'G' or 'g'" % (args['memsize'])
gb = args['memsize'][:-1]
args['memsize'] = int(gb) * 1024 ** 3
self._worker_args = args self._worker_args = args
self._started = False self._started = False
self._source = source self._source = source
...@@ -60,9 +67,7 @@ class ParallelMappedDataset(ProxiedDataset): ...@@ -60,9 +67,7 @@ class ParallelMappedDataset(ProxiedDataset):
def _setup(self): def _setup(self):
"""setup input/output queues and workers """ """setup input/output queues and workers """
use_process = False use_process = self._worker_args.get('use_process', False)
if 'use_process' in self._worker_args:
use_process = self._worker_args['use_process']
if use_process and sys.platform == "win32": if use_process and sys.platform == "win32":
logger.info("Use multi-thread reader instead of " logger.info("Use multi-thread reader instead of "
"multi-process reader on Windows.") "multi-process reader on Windows.")
...@@ -73,6 +78,9 @@ class ParallelMappedDataset(ProxiedDataset): ...@@ -73,6 +78,9 @@ class ParallelMappedDataset(ProxiedDataset):
from .shared_queue import SharedQueue as Queue from .shared_queue import SharedQueue as Queue
from multiprocessing import Process as Worker from multiprocessing import Process as Worker
from multiprocessing import Event from multiprocessing import Event
memsize = self._worker_args['memsize']
self._inq = Queue(bufsize, memsize=memsize)
self._outq = Queue(bufsize, memsize=memsize)
else: else:
if six.PY3: if six.PY3:
from queue import Queue from queue import Queue
...@@ -80,11 +88,10 @@ class ParallelMappedDataset(ProxiedDataset): ...@@ -80,11 +88,10 @@ class ParallelMappedDataset(ProxiedDataset):
from Queue import Queue from Queue import Queue
from threading import Thread as Worker from threading import Thread as Worker
from threading import Event from threading import Event
self._inq = Queue(bufsize) self._inq = Queue(bufsize)
self._outq = Queue(bufsize) self._outq = Queue(bufsize)
consumer_num = self._worker_args['worker_num']
consumer_num = self._worker_args['worker_num']
id = str(uuid.uuid4())[-3:] id = str(uuid.uuid4())[-3:]
self._producer = threading.Thread( self._producer = threading.Thread(
target=self._produce, target=self._produce,
......
...@@ -318,16 +318,14 @@ class PageAllocator(object): ...@@ -318,16 +318,14 @@ class PageAllocator(object):
while True: while True:
# maybe flags already has some '0' pages, # maybe flags already has some '0' pages,
# so just check 'page_num - len(flags)' pages # so just check 'page_num - len(flags)' pages
flags += self.get_page_status( flags = self.get_page_status(
pos, page_num - len(flags), ret_flag=True) pos, page_num, ret_flag=True)
if flags.count('0') == page_num: if flags.count('0') == page_num:
break break
# not found enough pages, so shift to next few pages # not found enough pages, so shift to next few pages
free_pos = flags.rfind('1') + 1 free_pos = flags.rfind('1') + 1
flags = flags[free_pos:]
pos += free_pos pos += free_pos
end = pos + page_num end = pos + page_num
if end > pages: if end > pages:
...@@ -355,9 +353,6 @@ class PageAllocator(object): ...@@ -355,9 +353,6 @@ class PageAllocator(object):
self.set_page_status(pos, page_num, '1') self.set_page_status(pos, page_num, '1')
used += page_num used += page_num
self.set_alloc_info(end, used) self.set_alloc_info(end, used)
assert self.get_page_status(pos, page_num) == (page_num, 1), \
'faild to validate the page status'
return pos return pos
def free_page(self, start, page_num): def free_page(self, start, page_num):
...@@ -530,7 +525,7 @@ class SharedMemoryMgr(object): ...@@ -530,7 +525,7 @@ class SharedMemoryMgr(object):
logger.info('destroy [%s]' % (self)) logger.info('destroy [%s]' % (self))
if not self._released and not self._allocator.empty(): if not self._released and not self._allocator.empty():
logger.warn('not empty when delete this SharedMemoryMgr[%s]' % logger.debug('not empty when delete this SharedMemoryMgr[%s]' %
(self)) (self))
else: else:
self._released = True self._released = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册