diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py
index 69dd9ce0a9768dc05a7f87c07c57c7db443b46bc..d498d57c85f3b7a74f98c2fc33b2001d08dc3af3 100644
--- a/mindspore/dataset/engine/iterators.py
+++ b/mindspore/dataset/engine/iterators.py
@@ -28,10 +28,10 @@ ITERATORS_LIST = list()
 
 
 def _cleanup():
-    for itr in ITERATORS_LIST:
-        iter_ref = itr()
+    for itr_ref in ITERATORS_LIST:
+        itr = itr_ref()
         if itr is not None:
-            iter_ref.release()
+            itr.release()
 
 
 def alter_tree(node):
diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py
index d2518e11198a6f6ba1bc531053885d9936393d58..102fd0eea1cbae8fa72702156745d4f402b7a72c 100644
--- a/tests/ut/python/dataset/test_iterator.py
+++ b/tests/ut/python/dataset/test_iterator.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+import pytest
 
 import mindspore.dataset as ds
+from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup
 
 DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
 SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
@@ -41,3 +43,41 @@ def test_case_iterator():
     check(COLUMNS[0:7])
     check(COLUMNS[7:8])
     check(COLUMNS[0:2:8])
+
+
+def test_iterator_weak_ref():
+    ITERATORS_LIST.clear()
+    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
+    itr1 = data.create_tuple_iterator()
+    itr2 = data.create_tuple_iterator()
+    itr3 = data.create_tuple_iterator()
+
+    assert len(ITERATORS_LIST) == 3
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 3
+
+    del itr1
+    assert len(ITERATORS_LIST) == 3
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 2
+
+    del itr2
+    assert len(ITERATORS_LIST) == 3
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 1
+
+    del itr3
+    assert len(ITERATORS_LIST) == 3
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 0
+
+    itr1 = data.create_tuple_iterator()
+    itr2 = data.create_tuple_iterator()
+    itr3 = data.create_tuple_iterator()
+
+    _cleanup()
+    with pytest.raises(AttributeError) as info:
+        itr2.get_next()
+    assert "object has no attribute 'depipeline'" in str(info.value)
+
+    del itr1
+    assert len(ITERATORS_LIST) == 6
+    assert sum(itr() is not None for itr in ITERATORS_LIST) == 2
+
+    _cleanup()