未验证 提交 73e41c89 编写于 作者: S ShenLiang 提交者: GitHub

Solve the random state serialization (#45327)

* fix utest

* fix utest

* fix utest

* fix log

* fix random utest
上级 728d5b3a
...@@ -131,6 +131,11 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) { ...@@ -131,6 +131,11 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
phi::Generator::GeneratorState Generator::GetState() { phi::Generator::GeneratorState Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mu_); std::lock_guard<std::mutex> lock(this->mu_);
state_.cpu_engine = *engine_; state_.cpu_engine = *engine_;
VLOG(4) << "Get Random state: "
<< "device id: " << (uint64_t)(this->state_.device)
<< ", current_seed: " << this->state_.current_seed
<< ", thread_offset: " << this->state_.thread_offset
<< ", cpu engine: " << *(this->engine_);
return this->state_; return this->state_;
} }
...@@ -138,6 +143,11 @@ void Generator::SetState(const phi::Generator::GeneratorState& state) { ...@@ -138,6 +143,11 @@ void Generator::SetState(const phi::Generator::GeneratorState& state) {
std::lock_guard<std::mutex> lock(this->mu_); std::lock_guard<std::mutex> lock(this->mu_);
this->state_ = state; this->state_ = state;
this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine); this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine);
VLOG(4) << "Set Random state: "
<< "device id: " << (uint64_t)(this->state_.device)
<< ", current_seed: " << this->state_.current_seed
<< ", thread_offset: " << this->state_.thread_offset
<< ", cpu engine: " << *(this->engine_);
} }
uint64_t Generator::GetCurrentSeed() { uint64_t Generator::GetCurrentSeed() {
......
...@@ -39,7 +39,38 @@ void BindGenerator(py::module* m_ptr) { ...@@ -39,7 +39,38 @@ void BindGenerator(py::module* m_ptr) {
.def("current_seed", .def("current_seed",
[](std::shared_ptr<phi::Generator::GeneratorState>& self) { [](std::shared_ptr<phi::Generator::GeneratorState>& self) {
return self->current_seed; return self->current_seed;
})
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// NOTE(shenliang03): Due to the inability to serialize mt19937_64
// type, resulting in a problem with precision under the cpu.
.def(py::pickle(
[](const phi::Generator::GeneratorState& s) { // __getstate__
return py::make_tuple(s.device, s.current_seed, s.thread_offset);
},
[](py::tuple s) { // __setstate__
if (s.size() != 3)
throw std::runtime_error(
"Invalid Random state. Please check the format(device, "
"current_seed, thread_offset).");
phi::Generator::GeneratorState state;
state.device = s[0].cast<std::int64_t>();
state.current_seed = s[1].cast<std::uint64_t>();
state.thread_offset = s[2].cast<std::uint64_t>();
std::seed_seq seq({state.current_seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
state.cpu_engine = *engine;
return state;
}))
#endif
.def("__str__", [](const phi::Generator::GeneratorState& self) {
std::stringstream ostr;
ostr << self.device << " " << self.current_seed << " "
<< self.thread_offset << " " << self.cpu_engine;
return ostr.str();
}); });
py::class_<std::mt19937_64>(m, "mt19937_64", ""); py::class_<std::mt19937_64>(m, "mt19937_64", "");
py::class_<framework::Generator, std::shared_ptr<framework::Generator>>( py::class_<framework::Generator, std::shared_ptr<framework::Generator>>(
m, "Generator") m, "Generator")
......
...@@ -23,6 +23,8 @@ import paddle.fluid as fluid ...@@ -23,6 +23,8 @@ import paddle.fluid as fluid
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import shutil
import tempfile
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
...@@ -169,6 +171,41 @@ class TestGeneratorSeed(unittest.TestCase): ...@@ -169,6 +171,41 @@ class TestGeneratorSeed(unittest.TestCase):
np.testing.assert_allclose(out1_res2, out2_res2, rtol=1e-05) np.testing.assert_allclose(out1_res2, out2_res2, rtol=1e-05)
self.assertTrue(not np.allclose(out1_res2, out1_res1)) self.assertTrue(not np.allclose(out1_res2, out1_res1))
def test_generator_pickle(self):
output_dir = tempfile.mkdtemp()
random_file = os.path.join(output_dir, "random.pdmodel")
fluid.enable_dygraph()
x0 = paddle.randn([120], dtype="float32")
st = paddle.get_cuda_rng_state()
st_dict = {"random_state": st}
print("state: ", st[0])
paddle.save(st_dict, random_file)
x1 = paddle.randn([120], dtype="float32")
lt_dict = paddle.load(random_file)
st = lt_dict["random_state"]
paddle.set_cuda_rng_state(st)
x2 = paddle.randn([120], dtype="float32")
lt_dict = paddle.load(random_file)
st = lt_dict["random_state"]
paddle.set_cuda_rng_state(st)
x3 = paddle.randn([120], dtype="float32")
x1_np = x1.numpy()
x2_np = x2.numpy()
x3_np = x3.numpy()
print(">>>>>>> gaussian random dygraph state load/save >>>>>>>")
np.testing.assert_equal(x1_np, x2_np)
np.testing.assert_equal(x1_np, x2_np)
shutil.rmtree(output_dir)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册