diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 69dd9ce0a9768dc05a7f87c07c57c7db443b46bc..d498d57c85f3b7a74f98c2fc33b2001d08dc3af3 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -28,10 +28,10 @@ ITERATORS_LIST = list() def _cleanup(): - for itr in ITERATORS_LIST: - iter_ref = itr() + for itr_ref in ITERATORS_LIST: + itr = itr_ref() if itr is not None: - iter_ref.release() + itr.release() def alter_tree(node): diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py index d2518e11198a6f6ba1bc531053885d9936393d58..102fd0eea1cbae8fa72702156745d4f402b7a72c 100644 --- a/tests/ut/python/dataset/test_iterator.py +++ b/tests/ut/python/dataset/test_iterator.py @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== import numpy as np +import pytest import mindspore.dataset as ds +from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json" @@ -41,3 +43,41 @@ def test_case_iterator(): check(COLUMNS[0:7]) check(COLUMNS[7:8]) check(COLUMNS[0:2:8]) + + +def test_iterator_weak_ref(): + ITERATORS_LIST.clear() + data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + itr1 = data.create_tuple_iterator() + itr2 = data.create_tuple_iterator() + itr3 = data.create_tuple_iterator() + + assert len(ITERATORS_LIST) == 3 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 3 + + del itr1 + assert len(ITERATORS_LIST) == 3 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 2 + + del itr2 + assert len(ITERATORS_LIST) == 3 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 1 + + del itr3 + assert len(ITERATORS_LIST) == 3 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 0 + + itr1 = data.create_tuple_iterator() + itr2 = data.create_tuple_iterator() + itr3 = data.create_tuple_iterator() + + _cleanup() + with pytest.raises(AttributeError) as info: + itr2.get_next() + assert "object has no attribute 'depipeline'" in str(info.value) + + del itr1 + assert len(ITERATORS_LIST) == 6 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 2 + + _cleanup()