From 3c02c82771ad4075d90875a00619688f0901ca30 Mon Sep 17 00:00:00 2001
From: hesham <h.farahat@huawei.com>
Date: Fri, 17 Apr 2020 10:12:50 -0400
Subject: [PATCH] Bug in weak reference. Add new test cases

---
 mindspore/dataset/engine/iterators.py    |  6 ++--
 tests/ut/python/dataset/test_iterator.py | 40 ++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py
index 69dd9ce0a..d498d57c8 100644
--- a/mindspore/dataset/engine/iterators.py
+++ b/mindspore/dataset/engine/iterators.py
@@ -28,10 +28,10 @@ ITERATORS_LIST = list()
 
 
 def _cleanup():
-    for itr in ITERATORS_LIST:
-        iter_ref = itr()
+    for itr_ref in ITERATORS_LIST:
+        itr = itr_ref()
         if itr is not None:
-            iter_ref.release()
+            itr.release()
 
 
 def alter_tree(node):
diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py
index d2518e111..102fd0eea 100644
--- a/tests/ut/python/dataset/test_iterator.py
+++ b/tests/ut/python/dataset/test_iterator.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+import pytest
 
 import mindspore.dataset as ds
+from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup
 
 DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
 SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
@@ -41,3 +43,41 @@ def test_case_iterator():
     check(COLUMNS[0:7])
     check(COLUMNS[7:8])
     check(COLUMNS[0:2:8])
+
+
+def test_iterator_weak_ref():
+    ITERATORS_LIST.clear()
+    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
+    itr1 = data.create_tuple_iterator()
+    itr2 = data.create_tuple_iterator()
+    itr3 = data.create_tuple_iterator()
+
+    assert len(ITERATORS_LIST) == 3
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 3
+
+    del itr1
+    assert len(ITERATORS_LIST) == 3
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 2
+
+    del itr2
+    assert len(ITERATORS_LIST) == 3
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 1
+
+    del itr3
+    assert len(ITERATORS_LIST) == 3
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 0
+
+    itr1 = data.create_tuple_iterator()
+    itr2 = data.create_tuple_iterator()
+    itr3 = data.create_tuple_iterator()
+
+    _cleanup()
+    with pytest.raises(AttributeError) as info:
+        itr2.get_next()
+    assert "object has no attribute 'depipeline'" in str(info.value)
+
+    del itr1
+    assert len(ITERATORS_LIST) == 6
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 2
+
+    _cleanup()
-- 
GitLab