提交 c0c0b098 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!378 Multiple Iterators can cause performance problem

Merge pull request !378 from h.farahat/multi_itr
...@@ -225,11 +225,13 @@ void bindTensor(py::module *m) { ...@@ -225,11 +225,13 @@ void bindTensor(py::module *m) {
(void)py::class_<DataType>(*m, "DataType") (void)py::class_<DataType>(*m, "DataType")
.def(py::init<std::string>()) .def(py::init<std::string>())
.def(py::self == py::self) .def(py::self == py::self)
.def("__str__", &DataType::ToString); .def("__str__", &DataType::ToString)
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
} }
void bindTensorOps1(py::module *m) { void bindTensorOps1(py::module *m) {
(void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp"); (void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
.def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
(void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>( (void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
*m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.") *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
"""Built-in iterators. """Built-in iterators.
""" """
from abc import abstractmethod from abc import abstractmethod
import copy
import weakref
from mindspore._c_dataengine import DEPipeline from mindspore._c_dataengine import DEPipeline
from mindspore._c_dataengine import OpName from mindspore._c_dataengine import OpName
...@@ -27,7 +29,9 @@ ITERATORS_LIST = list() ...@@ -27,7 +29,9 @@ ITERATORS_LIST = list()
def _cleanup(): def _cleanup():
for itr in ITERATORS_LIST: for itr in ITERATORS_LIST:
itr.release() iter_ref = itr()
if itr is not None:
iter_ref.release()
def alter_tree(node): def alter_tree(node):
...@@ -73,8 +77,10 @@ class Iterator: ...@@ -73,8 +77,10 @@ class Iterator:
""" """
def __init__(self, dataset): def __init__(self, dataset):
ITERATORS_LIST.append(self) ITERATORS_LIST.append(weakref.ref(self))
self.dataset = alter_tree(dataset) # create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset)
self.dataset = alter_tree(self.dataset)
if not self.__is_tree(): if not self.__is_tree():
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)") raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
self.depipeline = DEPipeline() self.depipeline = DEPipeline()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册