未验证 提交 2106f668 编写于 作者: Z Zhang Ting 提交者: GitHub

Fix the bug that _DataLoaderIterMultiProcess uses the current time to generate the seed (#43318)

* Fix the bug that _DataLoaderIterMultiProcess uses the current time to generate the seed

* use np.random.randint to generate a base seed
上级 403b127b
...@@ -374,6 +374,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -374,6 +374,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# see _try_put_indices # see _try_put_indices
self._thread_lock = threading.Lock() self._thread_lock = threading.Lock()
self._base_seed = np.random.randint(low=0, high=sys.maxsize)
# init workers and indices queues and put 2 indices in each indices queue # init workers and indices queues and put 2 indices in each indices queue
self._init_workers() self._init_workers()
for _ in range(self._outstanding_capacity): for _ in range(self._outstanding_capacity):
...@@ -406,7 +408,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -406,7 +408,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._data_queue, self._workers_done_event, self._data_queue, self._workers_done_event,
self._auto_collate_batch, self._collate_fn, self._auto_collate_batch, self._collate_fn,
self._drop_last, self._worker_init_fn, i, self._drop_last, self._worker_init_fn, i,
self._num_workers, self._use_shared_memory)) self._num_workers, self._use_shared_memory,
self._base_seed))
worker.daemon = True worker.daemon = True
worker.start() worker.start()
self._workers.append(worker) self._workers.append(worker)
......
...@@ -257,7 +257,7 @@ def _generate_states(base_seed=0, worker_id=0): ...@@ -257,7 +257,7 @@ def _generate_states(base_seed=0, worker_id=0):
def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
auto_collate_batch, collate_fn, drop_last, init_fn, worker_id, auto_collate_batch, collate_fn, drop_last, init_fn, worker_id,
num_workers, use_shared_memory): num_workers, use_shared_memory, base_seed):
try: try:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly, # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet # some shared memory objects may have been applied for but have not yet
...@@ -272,15 +272,20 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, ...@@ -272,15 +272,20 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
try: try:
import numpy as np import numpy as np
import time import time
import random
except ImportError: except ImportError:
pass pass
else: else:
np.random.seed(_generate_states(int(time.time()), worker_id)) seed = base_seed + worker_id
random.seed(seed)
paddle.seed(seed)
np.random.seed(_generate_states(base_seed, worker_id))
global _worker_info global _worker_info
_worker_info = WorkerInfo(id=worker_id, _worker_info = WorkerInfo(id=worker_id,
num_workers=num_workers, num_workers=num_workers,
dataset=dataset) dataset=dataset,
seed=base_seed)
init_exception = None init_exception = None
try: try:
......
...@@ -181,10 +181,11 @@ class TestDataLoaderWorkerLoop(unittest.TestCase): ...@@ -181,10 +181,11 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
for i in range(10): for i in range(10):
indices_queue.put([i, i + 10]) indices_queue.put([i, i + 10])
indices_queue.put(None) indices_queue.put(None)
base_seed = 1234
_worker_loop(loader._dataset, 0, indices_queue, _worker_loop(loader._dataset, 0, indices_queue,
loader._data_queue, loader._workers_done_event, loader._data_queue, loader._workers_done_event,
True, _collate_fn, True, _init_fn, 0, 1, True, _collate_fn, True, _init_fn, 0, 1,
loader._use_shared_memory) loader._use_shared_memory, base_seed)
self.assertTrue(False) self.assertTrue(False)
except AssertionError: except AssertionError:
pass pass
...@@ -223,10 +224,11 @@ class TestDataLoaderWorkerLoop(unittest.TestCase): ...@@ -223,10 +224,11 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
indices_queue.put([i, i + 10]) indices_queue.put([i, i + 10])
indices_queue.put(None) indices_queue.put(None)
loader._workers_done_event.set() loader._workers_done_event.set()
base_seed = 1234
_worker_loop(loader._dataset, 0, indices_queue, _worker_loop(loader._dataset, 0, indices_queue,
loader._data_queue, loader._workers_done_event, loader._data_queue, loader._workers_done_event,
True, _collate_fn, True, _init_fn, 0, 1, True, _collate_fn, True, _init_fn, 0, 1,
loader._use_shared_memory) loader._use_shared_memory, base_seed)
self.assertTrue(True) self.assertTrue(True)
except AssertionError: except AssertionError:
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册