From 73e41c893b258628c8faeb8841b0708c9f20ad48 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Wed, 24 Aug 2022 20:18:53 +0800 Subject: [PATCH] Solve the random state serialization (#45327) * fix utest * fix utest * fix utest * fix log * fix random utest --- paddle/fluid/framework/generator.cc | 10 +++++ paddle/fluid/pybind/generator_py.cc | 33 ++++++++++++++++- .../tests/unittests/test_cuda_random_seed.py | 37 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc index 3ebf09cc7aa..cc2d76a57a9 100644 --- a/paddle/fluid/framework/generator.cc +++ b/paddle/fluid/framework/generator.cc @@ -131,6 +131,11 @@ std::shared_ptr GetCPURandomEngine(uint64_t seed) { phi::Generator::GeneratorState Generator::GetState() { std::lock_guard lock(this->mu_); state_.cpu_engine = *engine_; + VLOG(4) << "Get Random state: " + << "device id: " << (uint64_t)(this->state_.device) + << ", current_seed: " << this->state_.current_seed + << ", thread_offset: " << this->state_.thread_offset + << ", cpu engine: " << *(this->engine_); return this->state_; } @@ -138,6 +143,11 @@ void Generator::SetState(const phi::Generator::GeneratorState& state) { std::lock_guard lock(this->mu_); this->state_ = state; this->engine_ = std::make_shared(state.cpu_engine); + VLOG(4) << "Set Random state: " + << "device id: " << (uint64_t)(this->state_.device) + << ", current_seed: " << this->state_.current_seed + << ", thread_offset: " << this->state_.thread_offset + << ", cpu engine: " << *(this->engine_); } uint64_t Generator::GetCurrentSeed() { diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc index e456526f844..00ee7b23368 100644 --- a/paddle/fluid/pybind/generator_py.cc +++ b/paddle/fluid/pybind/generator_py.cc @@ -39,7 +39,38 @@ void BindGenerator(py::module* m_ptr) { .def("current_seed", [](std::shared_ptr& self) { return self->current_seed; - }); + }) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + // NOTE(shenliang03): Due to the inability to serialize mt19937_64 + // type, resulting in a problem with precision under the cpu. + .def(py::pickle( + [](const phi::Generator::GeneratorState& s) { // __getstate__ + return py::make_tuple(s.device, s.current_seed, s.thread_offset); + }, + [](py::tuple s) { // __setstate__ + if (s.size() != 3) + throw std::runtime_error( + "Invalid Random state. Please check the format(device, " + "current_seed, thread_offset)."); + + phi::Generator::GeneratorState state; + state.device = s[0].cast(); + state.current_seed = s[1].cast(); + state.thread_offset = s[2].cast(); + + std::seed_seq seq({state.current_seed}); + auto engine = std::make_shared(seq); + state.cpu_engine = *engine; + return state; + })) +#endif + .def("__str__", [](const phi::Generator::GeneratorState& self) { + std::stringstream ostr; + ostr << self.device << " " << self.current_seed << " " + << self.thread_offset << " " << self.cpu_engine; + return ostr.str(); + }); + py::class_(m, "mt19937_64", ""); py::class_>( m, "Generator") diff --git a/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py b/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py index ad854bebd01..dab8745dbbb 100644 --- a/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py +++ b/python/paddle/fluid/tests/unittests/test_cuda_random_seed.py @@ -23,6 +23,8 @@ import paddle.fluid as fluid import numpy as np import paddle import paddle.fluid.core as core +import shutil +import tempfile @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -169,6 +171,41 @@ class TestGeneratorSeed(unittest.TestCase): np.testing.assert_allclose(out1_res2, out2_res2, rtol=1e-05) self.assertTrue(not np.allclose(out1_res2, out1_res1)) + def test_generator_pickle(self): + output_dir = tempfile.mkdtemp() + random_file = os.path.join(output_dir, "random.pdmodel") + + fluid.enable_dygraph() + x0 = paddle.randn([120], dtype="float32") + + st = paddle.get_cuda_rng_state() + st_dict = {"random_state": st} + print("state: ", st[0]) + + paddle.save(st_dict, random_file) + x1 = paddle.randn([120], dtype="float32") + + lt_dict = paddle.load(random_file) + st = lt_dict["random_state"] + + paddle.set_cuda_rng_state(st) + x2 = paddle.randn([120], dtype="float32") + + lt_dict = paddle.load(random_file) + st = lt_dict["random_state"] + paddle.set_cuda_rng_state(st) + x3 = paddle.randn([120], dtype="float32") + + x1_np = x1.numpy() + x2_np = x2.numpy() + x3_np = x3.numpy() + + print(">>>>>>> gaussian random dygraph state load/save >>>>>>>") + np.testing.assert_equal(x1_np, x2_np) + np.testing.assert_equal(x1_np, x2_np) + + shutil.rmtree(output_dir) + if __name__ == "__main__": unittest.main() -- GitLab