未验证 提交 1cdf69b2 编写于 作者: K Kaipeng Deng 提交者: GitHub

[cherry pick] add random state generate in DataLoader worker (#33434)

* add random state generate in DataLoader worker. test=develop

* fix license and __all__. test=develop

* fix unittest. test=develop
上级 9035fd2e
......@@ -168,6 +168,89 @@ class _WorkerException(object):
raise self.exc_type(msg)
# The function `_generate_states` is adapted from `numpy.random.SeedSequence`
# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
# Here is the copyright:
# SeedSequence is derived from Melissa E. O'Neill's C++11 `std::seed_seq`
# implementation, as it has a lot of nice properties that we want.
# https://gist.github.com/imneme/540829265469e673d045
# http://www.pcg-random.org/posts/developing-a-seed_seq-alternative.html
# The MIT License (MIT)
# Copyright (c) 2015 Melissa E. O'Neill
# Copyright (c) 2019 NumPy Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# All arithmetic in `_generate_states` is carried out on 32-bit words.
MASK32 = (1 << 32) - 1
# Half the bit width of a uint32 word; used for the xor-shift steps.
XSHIFT = (np.dtype(np.uint32).itemsize * 8) // 2

# Start value and step multiplier of the running constant used while
# hashing entropy words into the pool.
INIT_A = 0x43b0d7e5
MULT_A = 0x931e8875

# Start value and step multiplier of the running constant used while
# producing the output states from the pool.
INIT_B = 0x8b51f9dd
MULT_B = 0x58f38ded

# Multipliers applied to the left / right operand when mixing two words.
MIX_MULT_L = 0xca01f9dd
MIX_MULT_R = 0x4973f715
def _generate_states(base_seed=0, worker_id=0):
    """Derive four 32-bit seed words from a base seed and a worker id.

    The procedure is adapted from `numpy.random.SeedSequence`: the inputs
    are hashed into a pool of four words, every pool word is mixed into
    every other one, and the pool is then hashed once more with a second
    running constant to produce the result.

    Args:
        base_seed (int): seed shared by all workers. Its low 32 bits and
            the bits above them are used as two separate entropy words.
        worker_id (int): id of the worker, used as the first entropy word.

    Returns:
        list[int]: four integers, each in the range [0, 2**32).
    """
    # Running hash constants; each one is advanced on every use, so the
    # order of the calls below is significant.
    running_a = INIT_A
    running_b = INIT_B

    def _scramble(word):
        # Hash one word with the current A constant, then advance it.
        nonlocal running_a
        word = (word ^ running_a) & MASK32
        running_a = (running_a * MULT_A) & MASK32
        word = (word * running_a) & MASK32
        return (word ^ (word >> XSHIFT)) & MASK32

    def _combine(x, y):
        # Blend two words into one 32-bit word.
        left = (MIX_MULT_L * x) & MASK32
        right = (MIX_MULT_R * y) & MASK32
        blended = (left - right) & MASK32
        return (blended ^ (blended >> XSHIFT)) & MASK32

    # Entropy words: worker id, low and high part of the base seed, and a
    # zero word to fill the pool up to four entries.
    entropy_words = (worker_id, base_seed & MASK32, base_seed >> 32, 0)
    pool = [_scramble(word) for word in entropy_words]

    # Mix every pool entry into every other pool entry.
    size = len(pool)
    for src in range(size):
        for dst in range(size):
            if src == dst:
                continue
            pool[dst] = _combine(pool[dst], _scramble(pool[src]))

    # Final pass over the pool with the B constant yields the states.
    states = []
    for word in pool:
        word = (word ^ running_b) & MASK32
        running_b = (running_b * MULT_B) & MASK32
        word = (word * running_b) & MASK32
        states.append((word ^ (word >> XSHIFT)) & MASK32)
    return states
def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
auto_collate_batch, collate_fn, init_fn, worker_id,
num_workers, use_shared_memory):
......@@ -181,6 +264,15 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
# set signal handler
core._set_process_signal_handler()
# set different numpy seed for each worker
try:
import numpy as np
import time
except ImportError:
pass
else:
np.random.seed(_generate_states(int(time.time()), worker_id))
global _worker_info
_worker_info = WorkerInfo(
id=worker_id, num_workers=num_workers, dataset=dataset)
......
......@@ -330,5 +330,19 @@ class TestComplextDataset(unittest.TestCase):
self.run_main(num_workers)
class TestDataLoaderGenerateStates(unittest.TestCase):
    """Check `_generate_states` against fixed reference vectors."""

    def setUp(self):
        # Positional arguments passed to `_generate_states`, i.e.
        # (base_seed, worker_id) pairs.
        self.inputs = [(0, 1), (0, 2), (1, 3)]
        # Expected four-word result for the input at the same index.
        self.outputs = [[1835504127, 1731038949, 1320224556, 2330041505],
                        [2834126987, 2358157858, 1860244682, 1437227251],
                        [457190280, 2660306227, 859341110, 354512857]]

    def test_main(self):
        from paddle.fluid.dataloader.worker import _generate_states
        for inp, outp in zip(self.inputs, self.outputs):
            # Use unittest assertions rather than a bare `assert`: they are
            # not stripped under `python -O` and they report both values on
            # failure. `subTest` lets every vector be checked and reported
            # independently.
            with self.subTest(inp=inp):
                self.assertEqual(_generate_states(*inp), outp)
# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册