From 3c02c82771ad4075d90875a00619688f0901ca30 Mon Sep 17 00:00:00 2001 From: hesham <h.farahat@huawei.com> Date: Fri, 17 Apr 2020 10:12:50 -0400 Subject: [PATCH] Bug in weak reference. Add new test cases --- mindspore/dataset/engine/iterators.py | 6 ++-- tests/ut/python/dataset/test_iterator.py | 40 ++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 69dd9ce0a..d498d57c8 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -28,10 +28,10 @@ ITERATORS_LIST = list() def _cleanup(): - for itr in ITERATORS_LIST: - iter_ref = itr() + for itr_ref in ITERATORS_LIST: + itr = itr_ref() if itr is not None: - iter_ref.release() + itr.release() def alter_tree(node): diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py index d2518e111..102fd0eea 100644 --- a/tests/ut/python/dataset/test_iterator.py +++ b/tests/ut/python/dataset/test_iterator.py @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== import numpy as np +import pytest import mindspore.dataset as ds +from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json" @@ -41,3 +43,41 @@ def test_case_iterator(): check(COLUMNS[0:7]) check(COLUMNS[7:8]) check(COLUMNS[0:2:8]) + + +def test_iterator_weak_ref(): + ITERATORS_LIST.clear() + data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + itr1 = data.create_tuple_iterator() + itr2 = data.create_tuple_iterator() + itr3 = data.create_tuple_iterator() + + assert len(ITERATORS_LIST) == 3 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 3 + + del itr1 + assert len(ITERATORS_LIST) == 3 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 2 + + del itr2 + assert len(ITERATORS_LIST) == 3 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 1 + + del itr3 + assert len(ITERATORS_LIST) == 3 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 0 + + itr1 = data.create_tuple_iterator() + itr2 = data.create_tuple_iterator() + itr3 = data.create_tuple_iterator() + + _cleanup() + with pytest.raises(AttributeError) as info: + itr2.get_next() + assert "object has no attribute 'depipeline'" in str(info.value) + + del itr1 + assert len(ITERATORS_LIST) == 6 + assert sum(itr() is not None for itr in ITERATORS_LIST) == 2 + + _cleanup() -- GitLab