未验证 提交 48fc16f2 编写于 作者: C chentianyu03 提交者: GitHub

add varbase_copy support CUDAPinnedPlace (#32883)

上级 c3ae0d40
...@@ -1699,6 +1699,7 @@ void BindImperative(py::module *m_ptr) { ...@@ -1699,6 +1699,7 @@ void BindImperative(py::module *m_ptr) {
m.def("varbase_copy", &VarBaseCopy<platform::CPUPlace>); m.def("varbase_copy", &VarBaseCopy<platform::CPUPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::CUDAPlace>); m.def("varbase_copy", &VarBaseCopy<platform::CUDAPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::XPUPlace>); m.def("varbase_copy", &VarBaseCopy<platform::XPUPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::CUDAPinnedPlace>);
m.def( m.def(
"dygraph_partial_grad", "dygraph_partial_grad",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册