Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
291fc0f0
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
291fc0f0
编写于
6月 09, 2021
作者:
K
Kaipeng Deng
提交者:
GitHub
6月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add random state generate in DataLoader worker (#33310)
* add random state generate in DataLoader worker. test=develop
上级
52007915
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
106 addition
and
0 deletion
+106
-0
python/paddle/fluid/dataloader/worker.py
python/paddle/fluid/dataloader/worker.py
+92
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py
...d/tests/unittests/test_multiprocess_dataloader_dataset.py
+14
-0
未找到文件。
python/paddle/fluid/dataloader/worker.py
浏览文件 @
291fc0f0
...
...
@@ -168,6 +168,89 @@ class _WorkerException(object):
raise
self
.
exc_type
(
msg
)
# The function `_generate_states` is adapted from `numpy.random.SeedSequence`
# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
# Here is the copyright:
# SeedSequence is derived from Melissa E. O'Neill's C++11 `std::seed_seq`
# implementation, as it has a lot of nice properties that we want.
# https://gist.github.com/imneme/540829265469e673d045
# http://www.pcg-random.org/posts/developing-a-seed_seq-alternative.html
# The MIT License (MIT)
# Copyright (c) 2015 Melissa E. O'Neill
# Copyright (c) 2019 NumPy Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
INIT_A
=
0x43b0d7e5
MULT_A
=
0x931e8875
INIT_B
=
0x8b51f9dd
MULT_B
=
0x58f38ded
MIX_MULT_L
=
0xca01f9dd
MIX_MULT_R
=
0x4973f715
XSHIFT
=
np
.
dtype
(
np
.
uint32
).
itemsize
*
8
//
2
MASK32
=
0xFFFFFFFF
def
_generate_states
(
base_seed
=
0
,
worker_id
=
0
):
# init hash constant
hash_const_A
=
INIT_A
hash_const_B
=
INIT_B
def
hash
(
value
):
nonlocal
hash_const_A
value
=
(
value
^
hash_const_A
)
&
MASK32
hash_const_A
=
(
hash_const_A
*
MULT_A
)
&
MASK32
value
=
(
value
*
hash_const_A
)
&
MASK32
value
=
(
value
^
(
value
>>
XSHIFT
))
&
MASK32
return
value
def
mix
(
x
,
y
):
result_x
=
(
MIX_MULT_L
*
x
)
&
MASK32
result_y
=
(
MIX_MULT_R
*
y
)
&
MASK32
result
=
(
result_x
-
result_y
)
&
MASK32
result
=
(
result
^
(
result
>>
XSHIFT
))
&
MASK32
return
result
# init entropys with based_seed and worker_id and calculate pool
entropys
=
[
worker_id
,
base_seed
&
MASK32
,
base_seed
>>
32
,
0
]
pool
=
[
hash
(
entropy
)
for
entropy
in
entropys
]
# mix all bits together
for
i
in
range
(
len
(
pool
)):
for
j
in
range
(
len
(
pool
)):
if
i
!=
j
:
pool
[
j
]
=
mix
(
pool
[
j
],
hash
(
pool
[
i
]))
states
=
[]
for
p
in
pool
:
state
=
(
p
^
hash_const_B
)
&
MASK32
hash_const_B
=
(
hash_const_B
*
MULT_B
)
&
MASK32
state
=
(
state
*
hash_const_B
)
&
MASK32
state
=
(
state
^
(
state
>>
XSHIFT
))
&
MASK32
states
.
append
(
state
)
return
states
def
_worker_loop
(
dataset
,
dataset_kind
,
indices_queue
,
out_queue
,
done_event
,
auto_collate_batch
,
collate_fn
,
init_fn
,
worker_id
,
num_workers
,
use_shared_memory
):
...
...
@@ -181,6 +264,15 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
# set signal handler
core
.
_set_process_signal_handler
()
# set different numpy seed for each worker
try
:
import
numpy
as
np
import
time
except
ImportError
:
pass
else
:
np
.
random
.
seed
(
_generate_states
(
int
(
time
.
time
()),
worker_id
))
global
_worker_info
_worker_info
=
WorkerInfo
(
id
=
worker_id
,
num_workers
=
num_workers
,
dataset
=
dataset
)
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py
浏览文件 @
291fc0f0
...
...
@@ -330,5 +330,19 @@ class TestComplextDataset(unittest.TestCase):
self
.
run_main
(
num_workers
)
class TestDataLoaderGenerateStates(unittest.TestCase):
    """Check `_generate_states` against fixed reference vectors."""

    def setUp(self):
        # Each input is a (base_seed, worker_id) pair; `outputs` holds the
        # expected four state words for the input at the same position.
        self.inputs = [(0, 1), (0, 2), (1, 3)]
        self.outputs = [
            [1835504127, 1731038949, 1320224556, 2330041505],
            [2834126987, 2358157858, 1860244682, 1437227251],
            [457190280, 2660306227, 859341110, 354512857],
        ]

    def test_main(self):
        from paddle.fluid.dataloader.worker import _generate_states

        for inp, outp in zip(self.inputs, self.outputs):
            out = _generate_states(*inp)
            # assertEqual (rather than a bare `assert`) still runs under
            # `python -O` and reports both lists when they differ.
            self.assertEqual(out, outp)
# Run the unittest test runner when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录