Unverified commit 76710e5f, authored by Kaipeng Deng, committed by GitHub

add persistent_workers (#34017)

* add persistent_workers. test=develop
Parent commit: b451ff26
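Below is a hedged usage sketch of the new flag: with `persistent_workers=True`, the worker processes spawned for the first epoch are kept alive and reused by later epochs instead of being torn down and respawned at every epoch boundary. It assumes the public `paddle.io.DataLoader` API; `RandomDataset` and the tensor shapes are illustrative placeholders, not part of this commit.

```python
import numpy as np
from paddle.io import DataLoader, Dataset


class RandomDataset(Dataset):
    """Illustrative dataset returning random image/label pairs."""

    def __init__(self, num_samples=100):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples


loader = DataLoader(
    RandomDataset(),
    batch_size=16,
    num_workers=2,
    drop_last=True,
    persistent_workers=True)  # new flag from this commit: reuse workers

for epoch in range(3):
    for image, label in loader:
        pass  # the same worker processes serve every epoch
```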
......@@ -37,7 +37,8 @@ from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
from .batch_sampler import _InfiniteIterableSampler
from .collate import default_collate_fn, default_convert_fn
from .worker import ParentWatchDog, get_worker_info, _worker_loop, \
_DatasetKind, _IterableDatasetStopIteration, _WorkerException
_DatasetKind, _IterableDatasetStopIteration, _WorkerException, \
_ResumeIteration
from .flat import _flatten_batch, _restore_batch
__all__ = ['get_worker_info']
......@@ -67,15 +68,10 @@ class _DataLoaderIterBase(object):
self._dataset_kind = loader.dataset_kind
self._pin_memory = loader.pin_memory
self._sampler_iter = iter(self._index_sampler)
if self._auto_collate_batch:
self._sampler_iter = iter(loader.batch_sampler)
self._collate_fn = loader.collate_fn or default_collate_fn
else:
if self._dataset_kind == _DatasetKind.MAP:
self._sampler_iter = iter(list(range(len(self._dataset))))
else:
self._sampler_iter = iter(
_InfiniteIterableSampler(self._dataset, 1))
self._collate_fn = loader.collate_fn or default_convert_fn
# LoDTensorBlockingQueue instance for create_py_reader and a thread
......@@ -87,6 +83,16 @@ class _DataLoaderIterBase(object):
self._thread = None
self._thread_done_event = threading.Event()
@property
def _index_sampler(self):
if self._auto_collate_batch:
return self._batch_sampler
else:
if self._dataset_kind == _DatasetKind.MAP:
return list(range(len(self._dataset)))
else:
return _InfiniteIterableSampler(self._dataset, 1)
def __iter__(self):
return self
......@@ -242,6 +248,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
def __init__(self, loader):
super(_DataLoaderIterMultiProcess, self).__init__(loader)
self._persistent_workers = loader._persistent_workers
self._resume_worker_cnt = 0
assert self._num_workers > 0, "Multi-process DataLoader " \
"invalid num_workers({})".format(self._num_workers)
......@@ -336,13 +345,65 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._pin_memory)
self._thread_done_event = threading.Event()
# thread event is only needed in multi-processing mode
self._thread = threading.Thread(
target=self._thread_loop, args=(_current_expected_place(), ))
self._thread.daemon = True
self._thread.start()
def _shutdown_worker(self, worker_id):
if self._worker_status[worker_id]:
def _reset(self):
# resume iteration in the following steps
# 1. Resume workers and clear worker caches:
# put _ResumeIteration into each worker's indices queue as a resume flag
with self._thread_lock:
self._resume_worker_cnt = self._num_workers
for worker_id in range(self._num_workers):
self._indices_queues[worker_id].put(_ResumeIteration())
self._batches_outstanding += 1
# all flags will be checked in _thread_loop, simply wait here
while self._resume_worker_cnt > 0:
time.sleep(0.5)
# 2. clear blocking_queue caches
# in order not to restart the thread, we just clear
# the blocking_queue caches instead of recreating one
while self._blocking_queue.size() >= len(self._places):
if in_dygraph_mode():
self._reader.read_next_var_list()
elif self._return_list:
self._reader.read_next_list()
else:
data = self._reader.read_next()
# 3. reset all states
self._send_idx = 0
self._rcvd_idx = 0
self._batches_outstanding = 0
self._task_infos = {}
self._structure_infos = []
# set all worker status available
self._worker_status = [True] * self._num_workers
# 4. reset _sampler_iter and put prefetch indices to start next epoch
# refill the indices queues with up to _outstanding_capacity prefetch indices
self._sampler_iter = iter(self._index_sampler)
for _ in range(self._outstanding_capacity):
self._try_put_indices()
def _clear_and_remove_data_queue(self):
if self._data_queue is not None:
while True:
try:
self._data_queue.get_nowait()
except:
self._data_queue.cancel_join_thread()
self._data_queue.close()
break
def _shutdown_worker(self, worker_id, shutdown=False):
if self._worker_status[worker_id] or (self._persistent_workers and
shutdown):
self._indices_queues[worker_id].put(None)
self._worker_status[worker_id] = False
......@@ -357,7 +418,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# indices_queue
self._workers_done_event.set()
for i in range(self._num_workers):
self._shutdown_worker(i)
self._shutdown_worker(i, shutdown=True)
if not self._shutdown:
for w in self._workers:
......@@ -392,6 +453,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if batch is None:
self._exit_thread_expectedly()
else:
if isinstance(batch, _ResumeIteration):
assert self._resume_worker_cnt > 0
self._resume_worker_cnt -= 1
continue
try:
# pack as LoDTensorArray
array = core.LoDTensorArray()
......@@ -412,7 +477,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if not self._blocking_queue.push(array):
self._blocking_queue.close()
except:
except Exception as e:
self._exit_thread_unexpectedly()
six.reraise(*sys.exc_info())
finally:
......@@ -428,7 +493,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# batch indices and increase _rcvd_idx
if self._dataset_kind == _DatasetKind.ITER:
while self._rcvd_idx < self._send_idx:
sys.stdout.flush()
info = self._task_infos[self._rcvd_idx]
if len(info) == 3 or self._worker_status[info[0]]:
break
......@@ -436,12 +500,16 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._rcvd_idx += 1
self._batches_outstanding -= 1
else:
# NOTE: _rcvd_idx and _send_idx only record batches among
# workers, if batches among workers drained, there
# may also be data in blocking queue
if self._batches_outstanding < len(self._places):
return None
continue
# NOTE: in persistent workers mode, do not check whether data is
# drained here; simply fall through to reading _data_queue to
# get the _ResumeIteration flag
if not self._persistent_workers:
# NOTE: _rcvd_idx and _send_idx only record batches among
# workers; if the batches among workers are drained, there
# may still be data in the blocking queue
if self._batches_outstanding < len(self._places):
return None
continue
if self._rcvd_idx in self._task_infos and \
len(self._task_infos[self._rcvd_idx]) == 3:
......@@ -493,12 +561,20 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# is discarded, the outstanding batch number should be decreased
# and other indices should be put back because other workers
# may still be working.
self._shutdown_worker(data.worker_id)
self._batches_outstanding -= 1
if self._persistent_workers:
self._worker_status[data.worker_id] = False
else:
self._shutdown_worker(data.worker_id)
self._batches_outstanding -= 1
self._try_put_indices()
continue
idx, batch, structure = data
if isinstance(idx, _ResumeIteration) and batch is None \
and structure is None:
return idx
if isinstance(batch, _WorkerException):
self._exit_thread_unexpectedly()
batch.reraise()
......@@ -557,8 +633,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# set _thread_done_event here, py_reader will raise StopIteration,
# end workers and indices_queues in StopIteration handling
if self._batches_outstanding < len(self._places):
self._thread_done_event.set()
self._blocking_queue.close()
if self._persistent_workers:
raise StopIteration
else:
self._thread_done_event.set()
self._blocking_queue.close()
if in_dygraph_mode():
data = self._reader.read_next_var_list()
......@@ -583,8 +662,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._on_output_batch()
return data
except StopIteration:
self._reader.shutdown()
self._try_shutdown_all()
if not self._persistent_workers:
self._reader.shutdown()
self._try_shutdown_all()
six.reraise(*sys.exc_info())
# python2 compatibility
......
......@@ -36,6 +36,10 @@ class _IterableDatasetStopIteration(object):
self.worker_id = worker_id
class _ResumeIteration(object):
pass
class _DatasetKind(object):
MAP = 0
ITER = 1
......@@ -292,6 +296,13 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
except queue.Empty:
continue
if isinstance(data, _ResumeIteration):
out_queue.put((data, None, None))
iterator_drained = False
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collate_batch, collate_fn, True)
continue
# None as poison pill, so worker event should be set
if data is None:
assert done_event.is_set() or iterator_drained, \
......
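The loader's `_reset` and the worker loop's handling of `_ResumeIteration` above form a small handshake: the main process puts a resume flag into every worker's indices queue, each worker acknowledges it on the output queue and rebuilds its fetcher, and the main thread counts the acknowledgements down before refilling prefetch indices. The following is a minimal standalone sketch of that pattern, modelled with threads and `queue.Queue` rather than Paddle's multiprocessing queues; `ResumeFlag`, `worker_loop`, and `reset_workers` are illustrative names, not Paddle APIs.

```python
import queue
import threading


class ResumeFlag(object):
    """Sentinel telling a worker to drop cached state and start a new epoch."""
    pass


def worker_loop(index_q, out_q):
    while True:
        task = index_q.get()
        if task is None:              # poison pill: shut the worker down
            break
        if isinstance(task, ResumeFlag):
            out_q.put((task, None))   # acknowledge the reset to the main side
            continue
        out_q.put((task, task * 2))   # stand-in for fetching a real batch


def reset_workers(index_queues, out_q):
    """Main-side reset: flag every worker, then wait for all acknowledgements."""
    pending = len(index_queues)
    for q in index_queues:
        q.put(ResumeFlag())
    while pending > 0:
        flag, _ = out_q.get()
        if isinstance(flag, ResumeFlag):
            pending -= 1
        # non-flag items are leftovers of the previous epoch and are dropped,
        # mirroring the cache clearing done in _reset above


if __name__ == "__main__":
    num_workers = 2
    out_q = queue.Queue()
    index_queues = [queue.Queue() for _ in range(num_workers)]
    workers = [
        threading.Thread(target=worker_loop, args=(q, out_q), daemon=True)
        for q in index_queues
    ]
    for w in workers:
        w.start()

    reset_workers(index_queues, out_q)  # epoch boundary: workers stay alive
    index_queues[0].put(7)              # reuse the same workers afterwards
    print(out_q.get())                  # (7, 14)
```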
......@@ -325,7 +325,8 @@ class DataLoader(object):
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None):
worker_init_fn=None,
persistent_workers=False):
self.return_list = return_list
self.collate_fn = collate_fn
self.use_buffer_reader = use_buffer_reader
......@@ -407,6 +408,9 @@ class DataLoader(object):
self.pin_memory = True if use_pinned_memory(
) is None else use_pinned_memory()
self._persistent_workers = persistent_workers
self._iterator = None
def __len__(self):
if self.dataset_kind == _DatasetKind.ITER:
raise ValueError("length of IterableDataset not supported")
......@@ -419,6 +423,12 @@ class DataLoader(object):
def __iter__(self):
if self.num_workers == 0:
return _DataLoaderIterSingleProcess(self)
elif self._persistent_workers:
if self._iterator is None:
self._iterator = _DataLoaderIterMultiProcess(self)
else:
self._iterator._reset()
return self._iterator
else:
return _DataLoaderIterMultiProcess(self)
......
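The `__iter__` change above is the user-visible half of persistent workers: the multi-process iterator is created once, cached on the loader, and only `_reset()` between epochs, so the worker pool survives across epochs. A minimal sketch of that caching pattern follows; `_FakeMultiProcessIter` is an illustrative stand-in for `_DataLoaderIterMultiProcess`, not Paddle code.

```python
class _FakeMultiProcessIter(object):
    """Illustrative stand-in for the multi-process iterator."""

    def __init__(self):
        print("spawning worker processes (expensive)")

    def _reset(self):
        print("resetting live workers for the next epoch")

    def __iter__(self):
        return self

    def __next__(self):            # an empty epoch is enough for the sketch
        raise StopIteration


class Loader(object):
    def __init__(self, persistent_workers=False):
        self._persistent_workers = persistent_workers
        self._iterator = None

    def __iter__(self):
        if self._persistent_workers:
            if self._iterator is None:
                self._iterator = _FakeMultiProcessIter()
            else:
                self._iterator._reset()
            return self._iterator
        # non-persistent mode: fresh workers every epoch
        return _FakeMultiProcessIter()


loader = Loader(persistent_workers=True)
for epoch in range(3):
    for batch in loader:           # workers spawned once, then reused
        pass
```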
......@@ -66,7 +66,7 @@ class SimpleFCNet(fluid.dygraph.Layer):
class TestDygraphDataLoader(unittest.TestCase):
def run_main(self, num_workers, places):
def run_main(self, num_workers, places, persistent_workers):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
with fluid.dygraph.guard(places[0]):
......@@ -78,7 +78,8 @@ class TestDygraphDataLoader(unittest.TestCase):
dataset,
num_workers=num_workers,
batch_size=BATCH_SIZE,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
step_list = []
......@@ -110,20 +111,25 @@ class TestDygraphDataLoader(unittest.TestCase):
def test_main(self):
# dynamic graph does not run with_data_parallel
for p in prepare_places(False):
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers)
sys.stdout.flush()
ret = self.run_main(num_workers=num_workers, places=p)
results.append(ret)
diff = np.max(
np.abs(results[0]['loss'] - results[1]['loss']) /
np.abs(results[0]['loss']))
self.assertLess(diff, 1e-2)
for persistent_workers in [False, True]:
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers,
persistent_workers)
sys.stdout.flush()
ret = self.run_main(
num_workers=num_workers,
places=p,
persistent_workers=persistent_workers)
results.append(ret)
diff = np.max(
np.abs(results[0]['loss'] - results[1]['loss']) /
np.abs(results[0]['loss']))
self.assertLess(diff, 1e-2)
class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
def run_main(self, num_workers, places):
def run_main(self, num_workers, places, persistent_workers):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
with fluid.dygraph.guard(places[0]):
......@@ -135,7 +141,8 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
dataset,
num_workers=num_workers,
batch_size=None,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
step_list = []
......
......@@ -66,7 +66,7 @@ class SimpleFCNet(fluid.dygraph.Layer):
class TestDygraphDataLoader(unittest.TestCase):
def run_main(self, num_workers, places):
def run_main(self, num_workers, places, persistent_workers):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
with fluid.dygraph.guard(places[0]):
......@@ -78,7 +78,8 @@ class TestDygraphDataLoader(unittest.TestCase):
dataset,
num_workers=num_workers,
batch_size=BATCH_SIZE,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
step_list = []
loss_list = []
......@@ -109,18 +110,23 @@ class TestDygraphDataLoader(unittest.TestCase):
def test_main(self):
# dynamic graph does not run with_data_parallel
for p in prepare_places(False):
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers)
sys.stdout.flush()
ret = self.run_main(num_workers=num_workers, places=p)
results.append(ret)
assert results[0]['loss'].shape[0] * 2 == results[1]['loss'].shape[
0]
for persistent_workers in [False, True]:
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers,
persistent_workers)
sys.stdout.flush()
ret = self.run_main(
num_workers=num_workers,
places=p,
persistent_workers=persistent_workers)
results.append(ret)
assert results[0]['loss'].shape[0] * 2 == results[1][
'loss'].shape[0]
class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
def run_main(self, num_workers, places):
def run_main(self, num_workers, places, persistent_workers):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
with fluid.dygraph.guard(places[0]):
......@@ -132,7 +138,8 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
dataset,
num_workers=num_workers,
batch_size=None,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
step_list = []
loss_list = []
......
......@@ -93,14 +93,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
if with_gpu and fluid.core.is_compiled_with_cuda():
tmp = fluid.cuda_places()[:2]
assert len(tmp) > 0, "no gpu detected"
if with_data_parallel:
if with_data_parallel and len(tmp) > 1:
places.append(tmp)
places.append([tmp[0]])
return places
class TestStaticDataLoader(unittest.TestCase):
def run_main(self, num_workers, places):
def run_main(self, num_workers, places, persistent_workers):
scope = fluid.Scope()
with fluid.scope_guard(scope):
startup_prog, main_prog, image, label, loss = simple_fc_net_static()
......@@ -113,7 +113,8 @@ class TestStaticDataLoader(unittest.TestCase):
num_workers=num_workers,
batch_size=BATCH_SIZE,
return_list=False,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
# assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
exe = fluid.Executor(place=places[0])
......@@ -158,14 +159,19 @@ class TestStaticDataLoader(unittest.TestCase):
def test_main(self):
for p in prepare_places(True):
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers)
sys.stdout.flush()
ret = self.run_main(num_workers=num_workers, places=p)
results.append(ret)
assert results[0]['loss'].shape[0] * 2 == results[1]['loss'].shape[
0]
for persistent_workers in [False, True]:
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers,
persistent_workers)
sys.stdout.flush()
ret = self.run_main(
num_workers=num_workers,
places=p,
persistent_workers=persistent_workers)
results.append(ret)
assert results[0]['loss'].shape[0] * 2 == results[1][
'loss'].shape[0]
class RandomBatchedDataset(IterableDataset):
......@@ -188,7 +194,7 @@ class RandomBatchedDataset(IterableDataset):
class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
def run_main(self, num_workers, places):
def run_main(self, num_workers, places, persistent_workers):
scope = fluid.Scope()
with fluid.scope_guard(scope):
startup_prog, main_prog, image, label, loss = simple_fc_net_static()
......@@ -201,7 +207,8 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
num_workers=num_workers,
batch_size=None,
return_list=False,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
exe = fluid.Executor(place=places[0])
exe.run(startup_prog)
......
......@@ -94,14 +94,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
if with_gpu and fluid.core.is_compiled_with_cuda():
tmp = fluid.cuda_places()[:2]
assert len(tmp) > 0, "no gpu detected"
if with_data_parallel:
if with_data_parallel and len(tmp) > 1:
places.append(tmp)
places.append([tmp[0]])
return places
class TestStaticDataLoader(unittest.TestCase):
def run_main(self, num_workers, places, use_pe=True):
def run_main(self, num_workers, places, persistent_workers, use_pe=True):
scope = fluid.Scope()
with fluid.scope_guard(scope):
startup_prog, main_prog, image, label, loss = simple_fc_net_static()
......@@ -114,7 +114,8 @@ class TestStaticDataLoader(unittest.TestCase):
num_workers=num_workers,
batch_size=BATCH_SIZE,
return_list=False,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
exe = fluid.Executor(place=places[0])
......@@ -162,16 +163,21 @@ class TestStaticDataLoader(unittest.TestCase):
def test_main(self):
for p in prepare_places(True):
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers)
sys.stdout.flush()
ret = self.run_main(num_workers=num_workers, places=p)
results.append(ret)
diff = np.max(
np.abs(results[0]['loss'] - results[1]['loss']) /
np.abs(results[0]['loss']))
self.assertLess(diff, 1e-2)
for persistent_workers in [True, False]:
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers,
persistent_workers)
sys.stdout.flush()
ret = self.run_main(
num_workers=num_workers,
places=p,
persistent_workers=persistent_workers)
results.append(ret)
diff = np.max(
np.abs(results[0]['loss'] - results[1]['loss']) /
np.abs(results[0]['loss']))
self.assertLess(diff, 1e-2)
class TestStaticDataLoaderReturnList(unittest.TestCase):
......@@ -241,7 +247,7 @@ class RandomBatchedDataset(Dataset):
class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
def run_main(self, num_workers, places):
def run_main(self, num_workers, places, persistent_workers):
scope = fluid.Scope()
with fluid.scope_guard(scope):
startup_prog, main_prog, image, label, loss = simple_fc_net_static()
......@@ -254,7 +260,8 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
num_workers=num_workers,
batch_size=None,
return_list=False,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
exe = fluid.Executor(place=places[0])
......