Unverified commit 76710e5f, authored by Kaipeng Deng, committed by GitHub

add persistent_workers (#34017)

* add persistent_workers. test=develop
Parent b451ff26
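The change adds a persistent_workers flag to the DataLoader so that the worker subprocesses started for one epoch are kept alive and reused for later epochs, instead of being torn down and re-forked every time __iter__ is called. A minimal usage sketch follows; the RandomDataset class and its sizes are illustrative only, and it assumes the public paddle.io API that re-exports this DataLoader:

import numpy as np
from paddle.io import Dataset, DataLoader

class RandomDataset(Dataset):
    # illustrative map-style dataset; any Dataset works the same way
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return np.random.random([10]).astype('float32')

loader = DataLoader(
    RandomDataset(),
    num_workers=2,
    batch_size=16,
    drop_last=True,
    persistent_workers=True)  # keep worker subprocesses alive across epochs

for epoch in range(3):
    # the same iterator object is reused and only reset between epochs,
    # so workers are not re-created on every pass over the data
    for batch in loader:
        pass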
@@ -37,7 +37,8 @@ from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
 from .batch_sampler import _InfiniteIterableSampler
 from .collate import default_collate_fn, default_convert_fn
 from .worker import ParentWatchDog, get_worker_info, _worker_loop, \
-    _DatasetKind, _IterableDatasetStopIteration, _WorkerException
+    _DatasetKind, _IterableDatasetStopIteration, _WorkerException, \
+    _ResumeIteration
 from .flat import _flatten_batch, _restore_batch

 __all__ = ['get_worker_info']
@@ -67,15 +68,10 @@ class _DataLoaderIterBase(object):
         self._dataset_kind = loader.dataset_kind
         self._pin_memory = loader.pin_memory

+        self._sampler_iter = iter(self._index_sampler)
         if self._auto_collate_batch:
-            self._sampler_iter = iter(loader.batch_sampler)
             self._collate_fn = loader.collate_fn or default_collate_fn
         else:
-            if self._dataset_kind == _DatasetKind.MAP:
-                self._sampler_iter = iter(list(range(len(self._dataset))))
-            else:
-                self._sampler_iter = iter(
-                    _InfiniteIterableSampler(self._dataset, 1))
             self._collate_fn = loader.collate_fn or default_convert_fn

         # LoDTensorBlockingQueue instance for create_py_reader and a thread
@@ -87,6 +83,16 @@ class _DataLoaderIterBase(object):
         self._thread = None
         self._thread_done_event = threading.Event()

+    @property
+    def _index_sampler(self):
+        if self._auto_collate_batch:
+            return self._batch_sampler
+        else:
+            if self._dataset_kind == _DatasetKind.MAP:
+                return list(range(len(self._dataset)))
+            else:
+                return _InfiniteIterableSampler(self._dataset, 1)
+
     def __iter__(self):
         return self
@@ -242,6 +248,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
     def __init__(self, loader):
         super(_DataLoaderIterMultiProcess, self).__init__(loader)

+        self._persistent_workers = loader._persistent_workers
+        self._resume_worker_cnt = 0
+
         assert self._num_workers > 0, "Multi-process DataLoader " \
             "invalid num_workers({})".format(self._num_workers)
@@ -336,13 +345,65 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             self._pin_memory)

         self._thread_done_event = threading.Event()
+        # the thread event is only needed in multi-processing mode
         self._thread = threading.Thread(
             target=self._thread_loop, args=(_current_expected_place(), ))
         self._thread.daemon = True
         self._thread.start()

-    def _shutdown_worker(self, worker_id):
-        if self._worker_status[worker_id]:
+    def _reset(self):
+        # resume iteration in the following steps
+        # 1. Resume workers and clear worker caches:
+        #    put a _ResumeIteration flag into each worker's indices queue
+        with self._thread_lock:
+            self._resume_worker_cnt = self._num_workers
+            for worker_id in range(self._num_workers):
+                self._indices_queues[worker_id].put(_ResumeIteration())
+                self._batches_outstanding += 1
+        # every flag is checked in _thread_loop, simply wait here
+        while self._resume_worker_cnt > 0:
+            time.sleep(0.5)
+
+        # 2. Clear blocking_queue caches:
+        #    in order not to restart the thread, we just clear the
+        #    blocking_queue caches instead of recreating the queue
+        while self._blocking_queue.size() >= len(self._places):
+            if in_dygraph_mode():
+                self._reader.read_next_var_list()
+            elif self._return_list:
+                self._reader.read_next_list()
+            else:
+                data = self._reader.read_next()
+
+        # 3. Reset all states
+        self._send_idx = 0
+        self._rcvd_idx = 0
+        self._batches_outstanding = 0
+        self._task_infos = {}
+        self._structure_infos = []
+
+        # mark all workers as available
+        self._worker_status = [True] * self._num_workers
+
+        # 4. Reset _sampler_iter and put prefetch indices to start the next
+        #    epoch: init workers and indices queues and put 2 indices in
+        #    each indices queue
+        self._sampler_iter = iter(self._index_sampler)
+        for _ in range(self._outstanding_capacity):
+            self._try_put_indices()
+
+    def _clear_and_remove_data_queue(self):
+        if self._data_queue is not None:
+            while True:
+                try:
+                    self._data_queue.get_nowait()
+                except:
+                    self._data_queue.cancel_join_thread()
+                    self._data_queue.close()
+                    break
+
+    def _shutdown_worker(self, worker_id, shutdown=False):
+        if self._worker_status[worker_id] or (self._persistent_workers and
+                                              shutdown):
             self._indices_queues[worker_id].put(None)
             self._worker_status[worker_id] = False
@@ -357,7 +418,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             # indices_queue
             self._workers_done_event.set()
             for i in range(self._num_workers):
-                self._shutdown_worker(i)
+                self._shutdown_worker(i, shutdown=True)

             if not self._shutdown:
                 for w in self._workers:
@@ -392,6 +453,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                 if batch is None:
                     self._exit_thread_expectedly()
                 else:
+                    if isinstance(batch, _ResumeIteration):
+                        assert self._resume_worker_cnt > 0
+                        self._resume_worker_cnt -= 1
+                        continue
                     try:
                         # pack as LoDTensorArray
                         array = core.LoDTensorArray()
@@ -412,7 +477,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                         if not self._blocking_queue.push(array):
                             self._blocking_queue.close()
-                    except:
+                    except Exception as e:
                         self._exit_thread_unexpectedly()
                         six.reraise(*sys.exc_info())
                     finally:
@@ -428,7 +493,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
         #   batch indices and increase _rcvd_idx
         if self._dataset_kind == _DatasetKind.ITER:
             while self._rcvd_idx < self._send_idx:
-                sys.stdout.flush()
                 info = self._task_infos[self._rcvd_idx]
                 if len(info) == 3 or self._worker_status[info[0]]:
                     break
@@ -436,6 +500,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                 self._rcvd_idx += 1
                 self._batches_outstanding -= 1
         else:
+            # NOTE: in persistent workers mode, do not check whether the data
+            #       is drained here; simply let it go to the _data_queue
+            #       reading step to pick up the _ResumeIteration flag
+            if not self._persistent_workers:
                 # NOTE: _rcvd_idx and _send_idx only record batches among
                 #       workers, if batches among workers drained, there
                 #       may also be data in blocking queue
@@ -493,12 +561,20 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                 #   is discarded, the outstanding batch number should be
                 #   decreased and another set of indices should be put, since
                 #   other workers may still be working.
+                if self._persistent_workers:
+                    self._worker_status[data.worker_id] = False
+                else:
                     self._shutdown_worker(data.worker_id)
                     self._batches_outstanding -= 1
                     self._try_put_indices()
                 continue

             idx, batch, structure = data
+
+            if isinstance(idx, _ResumeIteration) and batch is None \
+                    and structure is None:
+                return idx
+
             if isinstance(batch, _WorkerException):
                 self._exit_thread_unexpectedly()
                 batch.reraise()
@@ -557,6 +633,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
         # set _thread_done_event here, py_reader will raise StopIteration,
         # end workers and indices_queues in StopIteration handling
         if self._batches_outstanding < len(self._places):
+            if self._persistent_workers:
+                raise StopIteration
+            else:
                 self._thread_done_event.set()
                 self._blocking_queue.close()
@@ -583,6 +662,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             self._on_output_batch()
             return data
         except StopIteration:
+            if not self._persistent_workers:
                 self._reader.shutdown()
                 self._try_shutdown_all()
             six.reraise(*sys.exc_info())
......
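The core of the persistent-workers mode is the _reset handshake above: the main process puts a _ResumeIteration sentinel into every worker's indices queue, and the reader thread decrements _resume_worker_cnt each time a worker echoes the sentinel back, so _reset only has to wait until the count reaches zero before reusing the workers for the next epoch. The standalone sketch below imitates that handshake with plain multiprocessing queues; every name in it is invented for illustration, and none of Paddle's internals are used:

import multiprocessing as mp

class ResumeSentinel(object):
    pass

def worker(indices_q, out_q):
    while True:
        msg = indices_q.get()
        if msg is None:                 # poison pill: shut the worker down
            break
        if isinstance(msg, ResumeSentinel):
            out_q.put(msg)              # echo so the main process can count resumes
            continue
        out_q.put(msg * 2)              # stand-in for fetching a real batch

if __name__ == '__main__':
    num_workers = 2
    indices_qs = [mp.Queue() for _ in range(num_workers)]
    out_q = mp.Queue()
    workers = [mp.Process(target=worker, args=(q, out_q)) for q in indices_qs]
    for w in workers:
        w.daemon = True
        w.start()

    # the "_reset" step: one sentinel per worker, then wait for all echoes
    # (the real code polls a shared counter with time.sleep(0.5) instead)
    resume_cnt = num_workers
    for q in indices_qs:
        q.put(ResumeSentinel())
    while resume_cnt > 0:
        if isinstance(out_q.get(), ResumeSentinel):
            resume_cnt -= 1

    for q in indices_qs:                # workers stay alive until this poison pill
        q.put(None)
    for w in workers:
        w.join()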
@@ -36,6 +36,10 @@ class _IterableDatasetStopIteration(object):
         self.worker_id = worker_id


+class _ResumeIteration(object):
+    pass
+
+
 class _DatasetKind(object):
     MAP = 0
     ITER = 1
@@ -292,6 +296,13 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
             except queue.Empty:
                 continue

+            if isinstance(data, _ResumeIteration):
+                out_queue.put((data, None, None))
+                iterator_drained = False
+                fetcher = _DatasetKind.create_fetcher(
+                    dataset_kind, dataset, auto_collate_batch, collate_fn, True)
+                continue
+
             # None as poison pill, so the worker event should be set
             if data is None:
                 assert done_event.is_set() or iterator_drained, \
......
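On the worker side, a _ResumeIteration message is echoed back on out_queue (the echo that _reset counts) and a new fetcher is created, because a fetcher over an IterableDataset is a one-shot iterator that the previous epoch has already drained. A small illustration of why the fetcher has to be rebuilt, using a made-up IterableFetcher class rather than Paddle's _DatasetKind.create_fetcher:

class IterableFetcher(object):
    # minimal stand-in: wraps a dataset in a one-shot iterator
    def __init__(self, dataset):
        self._iter = iter(dataset)

    def fetch(self):
        return next(self._iter)   # raises StopIteration once the epoch is drained

dataset = range(3)
fetcher = IterableFetcher(dataset)
assert [fetcher.fetch() for _ in range(3)] == [0, 1, 2]

# the old fetcher is now exhausted; reusing it next epoch would stop
# immediately, so a resumed worker constructs a fresh one instead
fetcher = IterableFetcher(dataset)
assert fetcher.fetch() == 0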
@@ -325,7 +325,8 @@ class DataLoader(object):
                  use_buffer_reader=True,
                  use_shared_memory=True,
                  timeout=0,
-                 worker_init_fn=None):
+                 worker_init_fn=None,
+                 persistent_workers=False):
         self.return_list = return_list
         self.collate_fn = collate_fn
         self.use_buffer_reader = use_buffer_reader
@@ -407,6 +408,9 @@ class DataLoader(object):
         self.pin_memory = True if use_pinned_memory(
         ) is None else use_pinned_memory()

+        self._persistent_workers = persistent_workers
+        self._iterator = None
+
     def __len__(self):
         if self.dataset_kind == _DatasetKind.ITER:
             raise ValueError("length of IterableDataset not supported")
@@ -419,6 +423,12 @@ class DataLoader(object):
     def __iter__(self):
         if self.num_workers == 0:
             return _DataLoaderIterSingleProcess(self)
+        elif self._persistent_workers:
+            if self._iterator is None:
+                self._iterator = _DataLoaderIterMultiProcess(self)
+            else:
+                self._iterator._reset()
+            return self._iterator
         else:
             return _DataLoaderIterMultiProcess(self)
......
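With the flag enabled, __iter__ builds the multi-process iterator once, caches it on self._iterator, and calls _reset() on it for every subsequent epoch; with the flag off it keeps the old behaviour of constructing a fresh iterator per epoch. The caching pattern, reduced to generic names that are not Paddle's:

class EpochIterator(object):
    # stand-in for the expensive multi-process iterator
    def __init__(self, data):
        self._data = data
        self._pos = 0

    def _reset(self):
        self._pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._pos >= len(self._data):
            raise StopIteration
        item = self._data[self._pos]
        self._pos += 1
        return item

class PersistentLoader(object):
    def __init__(self, make_iterator):
        self._make_iterator = make_iterator   # factory for the iterator
        self._iterator = None

    def __iter__(self):
        if self._iterator is None:
            self._iterator = self._make_iterator()   # first epoch: build once
        else:
            self._iterator._reset()                   # later epochs: reuse and reset
        return self._iterator

loader = PersistentLoader(lambda: EpochIterator([1, 2, 3]))
assert list(loader) == [1, 2, 3]   # first epoch constructs the iterator
assert list(loader) == [1, 2, 3]   # second epoch reuses it after _reset()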
@@ -66,7 +66,7 @@ class SimpleFCNet(fluid.dygraph.Layer):


 class TestDygraphDataLoader(unittest.TestCase):
-    def run_main(self, num_workers, places):
+    def run_main(self, num_workers, places, persistent_workers):
         fluid.default_startup_program().random_seed = 1
         fluid.default_main_program().random_seed = 1
         with fluid.dygraph.guard(places[0]):
@@ -78,7 +78,8 @@ class TestDygraphDataLoader(unittest.TestCase):
                 dataset,
                 num_workers=num_workers,
                 batch_size=BATCH_SIZE,
-                drop_last=True)
+                drop_last=True,
+                persistent_workers=persistent_workers)
             assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)

             step_list = []
@@ -110,11 +111,16 @@ class TestDygraphDataLoader(unittest.TestCase):
     def test_main(self):
         # dynamic graph does not run with_data_parallel
         for p in prepare_places(False):
+            for persistent_workers in [False, True]:
                 results = []
                 for num_workers in [0, 2]:
-                    print(self.__class__.__name__, p, num_workers)
+                    print(self.__class__.__name__, p, num_workers,
+                          persistent_workers)
                     sys.stdout.flush()
-                    ret = self.run_main(num_workers=num_workers, places=p)
+                    ret = self.run_main(
+                        num_workers=num_workers,
+                        places=p,
+                        persistent_workers=persistent_workers)
                     results.append(ret)
                 diff = np.max(
                     np.abs(results[0]['loss'] - results[1]['loss']) /
@@ -123,7 +129,7 @@ class TestDygraphDataLoader(unittest.TestCase):


 class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
-    def run_main(self, num_workers, places):
+    def run_main(self, num_workers, places, persistent_workers):
         fluid.default_startup_program().random_seed = 1
         fluid.default_main_program().random_seed = 1
         with fluid.dygraph.guard(places[0]):
@@ -135,7 +141,8 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
                 dataset,
                 num_workers=num_workers,
                 batch_size=None,
-                drop_last=True)
+                drop_last=True,
+                persistent_workers=persistent_workers)
             assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)

             step_list = []
......
@@ -66,7 +66,7 @@ class SimpleFCNet(fluid.dygraph.Layer):


 class TestDygraphDataLoader(unittest.TestCase):
-    def run_main(self, num_workers, places):
+    def run_main(self, num_workers, places, persistent_workers):
         fluid.default_startup_program().random_seed = 1
         fluid.default_main_program().random_seed = 1
         with fluid.dygraph.guard(places[0]):
@@ -78,7 +78,8 @@ class TestDygraphDataLoader(unittest.TestCase):
                 dataset,
                 num_workers=num_workers,
                 batch_size=BATCH_SIZE,
-                drop_last=True)
+                drop_last=True,
+                persistent_workers=persistent_workers)

             step_list = []
             loss_list = []
@@ -109,18 +110,23 @@ class TestDygraphDataLoader(unittest.TestCase):
     def test_main(self):
         # dynamic graph does not run with_data_parallel
         for p in prepare_places(False):
+            for persistent_workers in [False, True]:
                 results = []
                 for num_workers in [0, 2]:
-                    print(self.__class__.__name__, p, num_workers)
+                    print(self.__class__.__name__, p, num_workers,
+                          persistent_workers)
                     sys.stdout.flush()
-                    ret = self.run_main(num_workers=num_workers, places=p)
+                    ret = self.run_main(
+                        num_workers=num_workers,
+                        places=p,
+                        persistent_workers=persistent_workers)
                     results.append(ret)
-                assert results[0]['loss'].shape[0] * 2 == results[1]['loss'].shape[
-                    0]
+                assert results[0]['loss'].shape[0] * 2 == results[1][
+                    'loss'].shape[0]


 class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
-    def run_main(self, num_workers, places):
+    def run_main(self, num_workers, places, persistent_workers):
         fluid.default_startup_program().random_seed = 1
         fluid.default_main_program().random_seed = 1
         with fluid.dygraph.guard(places[0]):
@@ -132,7 +138,8 @@ class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
                 dataset,
                 num_workers=num_workers,
                 batch_size=None,
-                drop_last=True)
+                drop_last=True,
+                persistent_workers=persistent_workers)

             step_list = []
             loss_list = []
......
@@ -93,14 +93,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
     if with_gpu and fluid.core.is_compiled_with_cuda():
         tmp = fluid.cuda_places()[:2]
         assert len(tmp) > 0, "no gpu detected"
-        if with_data_parallel:
+        if with_data_parallel and len(tmp) > 1:
             places.append(tmp)
         places.append([tmp[0]])
     return places


 class TestStaticDataLoader(unittest.TestCase):
-    def run_main(self, num_workers, places):
+    def run_main(self, num_workers, places, persistent_workers):
         scope = fluid.Scope()
         with fluid.scope_guard(scope):
             startup_prog, main_prog, image, label, loss = simple_fc_net_static()
@@ -113,7 +113,8 @@ class TestStaticDataLoader(unittest.TestCase):
                 num_workers=num_workers,
                 batch_size=BATCH_SIZE,
                 return_list=False,
-                drop_last=True)
+                drop_last=True,
+                persistent_workers=persistent_workers)
             # assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)

             exe = fluid.Executor(place=places[0])
@@ -158,14 +159,19 @@ class TestStaticDataLoader(unittest.TestCase):
     def test_main(self):
         for p in prepare_places(True):
+            for persistent_workers in [False, True]:
                 results = []
                 for num_workers in [0, 2]:
-                    print(self.__class__.__name__, p, num_workers)
+                    print(self.__class__.__name__, p, num_workers,
+                          persistent_workers)
                     sys.stdout.flush()
-                    ret = self.run_main(num_workers=num_workers, places=p)
+                    ret = self.run_main(
+                        num_workers=num_workers,
+                        places=p,
+                        persistent_workers=persistent_workers)
                     results.append(ret)
-            assert results[0]['loss'].shape[0] * 2 == results[1]['loss'].shape[
-                0]
+                assert results[0]['loss'].shape[0] * 2 == results[1][
+                    'loss'].shape[0]


 class RandomBatchedDataset(IterableDataset):
@@ -188,7 +194,7 @@ class RandomBatchedDataset(IterableDataset):


 class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
-    def run_main(self, num_workers, places):
+    def run_main(self, num_workers, places, persistent_workers):
         scope = fluid.Scope()
         with fluid.scope_guard(scope):
             startup_prog, main_prog, image, label, loss = simple_fc_net_static()
@@ -201,7 +207,8 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
                 num_workers=num_workers,
                 batch_size=None,
                 return_list=False,
-                drop_last=True)
+                drop_last=True,
+                persistent_workers=persistent_workers)

             exe = fluid.Executor(place=places[0])
             exe.run(startup_prog)
......
@@ -94,14 +94,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
     if with_gpu and fluid.core.is_compiled_with_cuda():
         tmp = fluid.cuda_places()[:2]
         assert len(tmp) > 0, "no gpu detected"
-        if with_data_parallel:
+        if with_data_parallel and len(tmp) > 1:
             places.append(tmp)
         places.append([tmp[0]])
     return places


 class TestStaticDataLoader(unittest.TestCase):
-    def run_main(self, num_workers, places, use_pe=True):
+    def run_main(self, num_workers, places, persistent_workers, use_pe=True):
         scope = fluid.Scope()
         with fluid.scope_guard(scope):
             startup_prog, main_prog, image, label, loss = simple_fc_net_static()
@@ -114,7 +114,8 @@ class TestStaticDataLoader(unittest.TestCase):
                 num_workers=num_workers,
                 batch_size=BATCH_SIZE,
                 return_list=False,
-                drop_last=True)
+                drop_last=True,
+                persistent_workers=persistent_workers)
             assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)

             exe = fluid.Executor(place=places[0])
@@ -162,11 +163,16 @@ class TestStaticDataLoader(unittest.TestCase):
     def test_main(self):
         for p in prepare_places(True):
+            for persistent_workers in [True, False]:
                 results = []
                 for num_workers in [0, 2]:
-                    print(self.__class__.__name__, p, num_workers)
+                    print(self.__class__.__name__, p, num_workers,
+                          persistent_workers)
                     sys.stdout.flush()
-                    ret = self.run_main(num_workers=num_workers, places=p)
+                    ret = self.run_main(
+                        num_workers=num_workers,
+                        places=p,
+                        persistent_workers=persistent_workers)
                     results.append(ret)
                 diff = np.max(
                     np.abs(results[0]['loss'] - results[1]['loss']) /
@@ -241,7 +247,7 @@ class RandomBatchedDataset(Dataset):


 class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
-    def run_main(self, num_workers, places):
+    def run_main(self, num_workers, places, persistent_workers):
         scope = fluid.Scope()
         with fluid.scope_guard(scope):
             startup_prog, main_prog, image, label, loss = simple_fc_net_static()
@@ -254,7 +260,8 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
                 num_workers=num_workers,
                 batch_size=None,
                 return_list=False,
-                drop_last=True)
+                drop_last=True,
+                persistent_workers=persistent_workers)
             assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)

             exe = fluid.Executor(place=places[0])
......