diff --git a/mindspore/ccsrc/pybind_api/ir/cell_py.cc b/mindspore/ccsrc/pybind_api/ir/cell_py.cc index efd80209a260929d5278f7d3c5e306ce8c8b3f5c..6ec01e9021792b25a5ffe24141486337f7b17b5c 100644 --- a/mindspore/ccsrc/pybind_api/ir/cell_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/cell_py.cc @@ -45,6 +45,19 @@ REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) { .def("_del_attr", &Cell::DelAttr, "Delete Cell attr.") .def( "construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; }, - "construct"); + "construct") + .def(py::pickle( + [](const Cell &cell) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::str(cell.name())); + }, + [](const py::tuple &tup) { // __setstate__ + if (tup.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + Cell data(tup[0].cast()); + return data; + })); })); } // namespace mindspore diff --git a/tests/ut/python/nn/test_cell.py b/tests/ut/python/nn/test_cell.py index 30066ee855663a42b337606e204176bc334e4f81..0c4668403961924c8bec927387c680e0da2c19ea 100644 --- a/tests/ut/python/nn/test_cell.py +++ b/tests/ut/python/nn/test_cell.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """ test cell """ +import copy import numpy as np import pytest @@ -200,6 +201,11 @@ def test_exceptions(): m.construct() +def test_cell_copy(): + net = ConvNet() + copy.deepcopy(net) + + def test_del(): """ test_del """ ta = Tensor(np.ones([2, 3]))