未验证 提交 b20d22df 编写于 作者: D duanyanhui 提交者: GitHub

fix generator pickle for custom device (#55247)

上级 31bf1e88
...@@ -40,7 +40,8 @@ void BindGenerator(py::module* m_ptr) { ...@@ -40,7 +40,8 @@ void BindGenerator(py::module* m_ptr) {
[](std::shared_ptr<phi::Generator::GeneratorState>& self) { [](std::shared_ptr<phi::Generator::GeneratorState>& self) {
return self->current_seed; return self->current_seed;
}) })
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU)
// NOTE(shenliang03): Due to the inability to serialize mt19937_64 // NOTE(shenliang03): Due to the inability to serialize mt19937_64
// type, resulting in a problem with precision under the cpu. // type, resulting in a problem with precision under the cpu.
.def(py::pickle( .def(py::pickle(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册