Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
73e41c89
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
73e41c89
编写于
8月 24, 2022
作者:
S
ShenLiang
提交者:
GitHub
8月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Solve the random state serialization (#45327)
* fix utest * fix utest * fix utest * fix log * fix random utest
上级
728d5b3a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
79 addition
and
1 deletion
+79
-1
paddle/fluid/framework/generator.cc
paddle/fluid/framework/generator.cc
+10
-0
paddle/fluid/pybind/generator_py.cc
paddle/fluid/pybind/generator_py.cc
+32
-1
python/paddle/fluid/tests/unittests/test_cuda_random_seed.py
python/paddle/fluid/tests/unittests/test_cuda_random_seed.py
+37
-0
未找到文件。
paddle/fluid/framework/generator.cc
浏览文件 @
73e41c89
...
...
@@ -131,6 +131,11 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
phi
::
Generator
::
GeneratorState
Generator
::
GetState
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mu_
);
state_
.
cpu_engine
=
*
engine_
;
VLOG
(
4
)
<<
"Get Random state: "
<<
"device id: "
<<
(
uint64_t
)(
this
->
state_
.
device
)
<<
", current_seed: "
<<
this
->
state_
.
current_seed
<<
", thread_offset: "
<<
this
->
state_
.
thread_offset
<<
", cpu engine: "
<<
*
(
this
->
engine_
);
return
this
->
state_
;
}
...
...
@@ -138,6 +143,11 @@ void Generator::SetState(const phi::Generator::GeneratorState& state) {
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mu_
);
this
->
state_
=
state
;
this
->
engine_
=
std
::
make_shared
<
std
::
mt19937_64
>
(
state
.
cpu_engine
);
VLOG
(
4
)
<<
"Set Random state: "
<<
"device id: "
<<
(
uint64_t
)(
this
->
state_
.
device
)
<<
", current_seed: "
<<
this
->
state_
.
current_seed
<<
", thread_offset: "
<<
this
->
state_
.
thread_offset
<<
", cpu engine: "
<<
*
(
this
->
engine_
);
}
uint64_t
Generator
::
GetCurrentSeed
()
{
...
...
paddle/fluid/pybind/generator_py.cc
浏览文件 @
73e41c89
...
...
@@ -39,7 +39,38 @@ void BindGenerator(py::module* m_ptr) {
.
def
(
"current_seed"
,
[](
std
::
shared_ptr
<
phi
::
Generator
::
GeneratorState
>&
self
)
{
return
self
->
current_seed
;
});
})
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// NOTE(shenliang03): Due to the inability to serialize mt19937_64
// type, resulting in a problem with precision under the cpu.
.
def
(
py
::
pickle
(
[](
const
phi
::
Generator
::
GeneratorState
&
s
)
{
// __getstate__
return
py
::
make_tuple
(
s
.
device
,
s
.
current_seed
,
s
.
thread_offset
);
},
[](
py
::
tuple
s
)
{
// __setstate__
if
(
s
.
size
()
!=
3
)
throw
std
::
runtime_error
(
"Invalid Random state. Please check the format(device, "
"current_seed, thread_offset)."
);
phi
::
Generator
::
GeneratorState
state
;
state
.
device
=
s
[
0
].
cast
<
std
::
int64_t
>
();
state
.
current_seed
=
s
[
1
].
cast
<
std
::
uint64_t
>
();
state
.
thread_offset
=
s
[
2
].
cast
<
std
::
uint64_t
>
();
std
::
seed_seq
seq
({
state
.
current_seed
});
auto
engine
=
std
::
make_shared
<
std
::
mt19937_64
>
(
seq
);
state
.
cpu_engine
=
*
engine
;
return
state
;
}))
#endif
.
def
(
"__str__"
,
[](
const
phi
::
Generator
::
GeneratorState
&
self
)
{
std
::
stringstream
ostr
;
ostr
<<
self
.
device
<<
" "
<<
self
.
current_seed
<<
" "
<<
self
.
thread_offset
<<
" "
<<
self
.
cpu_engine
;
return
ostr
.
str
();
});
py
::
class_
<
std
::
mt19937_64
>
(
m
,
"mt19937_64"
,
""
);
py
::
class_
<
framework
::
Generator
,
std
::
shared_ptr
<
framework
::
Generator
>>
(
m
,
"Generator"
)
...
...
python/paddle/fluid/tests/unittests/test_cuda_random_seed.py
浏览文件 @
73e41c89
...
...
@@ -23,6 +23,8 @@ import paddle.fluid as fluid
import
numpy
as
np
import
paddle
import
paddle.fluid.core
as
core
import
shutil
import
tempfile
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -169,6 +171,41 @@ class TestGeneratorSeed(unittest.TestCase):
np
.
testing
.
assert_allclose
(
out1_res2
,
out2_res2
,
rtol
=
1e-05
)
self
.
assertTrue
(
not
np
.
allclose
(
out1_res2
,
out1_res1
))
def
test_generator_pickle
(
self
):
output_dir
=
tempfile
.
mkdtemp
()
random_file
=
os
.
path
.
join
(
output_dir
,
"random.pdmodel"
)
fluid
.
enable_dygraph
()
x0
=
paddle
.
randn
([
120
],
dtype
=
"float32"
)
st
=
paddle
.
get_cuda_rng_state
()
st_dict
=
{
"random_state"
:
st
}
print
(
"state: "
,
st
[
0
])
paddle
.
save
(
st_dict
,
random_file
)
x1
=
paddle
.
randn
([
120
],
dtype
=
"float32"
)
lt_dict
=
paddle
.
load
(
random_file
)
st
=
lt_dict
[
"random_state"
]
paddle
.
set_cuda_rng_state
(
st
)
x2
=
paddle
.
randn
([
120
],
dtype
=
"float32"
)
lt_dict
=
paddle
.
load
(
random_file
)
st
=
lt_dict
[
"random_state"
]
paddle
.
set_cuda_rng_state
(
st
)
x3
=
paddle
.
randn
([
120
],
dtype
=
"float32"
)
x1_np
=
x1
.
numpy
()
x2_np
=
x2
.
numpy
()
x3_np
=
x3
.
numpy
()
print
(
">>>>>>> gaussian random dygraph state load/save >>>>>>>"
)
np
.
testing
.
assert_equal
(
x1_np
,
x2_np
)
np
.
testing
.
assert_equal
(
x1_np
,
x2_np
)
shutil
.
rmtree
(
output_dir
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录