提交 24938289 编写于 作者: W Wei Luning

fix bug in cell pickle and copy

上级 f480e482
...@@ -45,6 +45,19 @@ REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) { ...@@ -45,6 +45,19 @@ REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) {
.def("_del_attr", &Cell::DelAttr, "Delete Cell attr.") .def("_del_attr", &Cell::DelAttr, "Delete Cell attr.")
.def( .def(
"construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; }, "construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; },
"construct"); "construct")
.def(py::pickle(
[](const Cell &cell) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::str(cell.name()));
},
[](const py::tuple &tup) { // __setstate__
if (tup.size() != 1) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
Cell data(tup[0].cast<std::string>());
return data;
}));
})); }));
} // namespace mindspore } // namespace mindspore
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" test cell """ """ test cell """
import copy
import numpy as np import numpy as np
import pytest import pytest
...@@ -200,6 +201,11 @@ def test_exceptions(): ...@@ -200,6 +201,11 @@ def test_exceptions():
m.construct() m.construct()
def test_cell_copy():
net = ConvNet()
copy.deepcopy(net)
def test_del(): def test_del():
""" test_del """ """ test_del """
ta = Tensor(np.ones([2, 3])) ta = Tensor(np.ones([2, 3]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册