diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py
index 806df64fdc4efb90bfc7c06e0631b00548a71a5e..268a66c0cf0b90ad87b539cb1e27c7307c68ef86 100644
--- a/mindspore/dataset/engine/iterators.py
+++ b/mindspore/dataset/engine/iterators.py
@@ -66,11 +66,10 @@ def _alter_node(node):
 
 class Iterator:
     """
-    General Iterator over a dataset.
-
-    Attributes:
-        dataset: Dataset to be iterated over
+    General Iterator over a dataset.
+    Attributes:
+        dataset: Dataset to be iterated over
     """
 
     def __init__(self, dataset):
@@ -86,6 +85,7 @@ class Iterator:
         root = self.__convert_node_postorder(self.dataset)
         self.depipeline.AssignRootNode(root)
         self.depipeline.LaunchTreeExec()
+        self._index = 0
 
     def __is_tree_node(self, node):
         """Check if a node is tree node."""
@@ -185,10 +185,7 @@ class Iterator:
             Iterator.__print_local(input_op, level + 1)
 
     def print(self):
-        """
-        Print the dataset tree
-
-        """
+        """Print the dataset tree."""
         self.__print_local(self.dataset, 0)
 
     def release(self):
@@ -202,7 +199,10 @@ class Iterator:
     def __next__(self):
         data = self.get_next()
         if not data:
+            if self._index == 0:
+                logger.warning("No records available.")
             raise StopIteration
+        self._index += 1
         return data
 
     def get_output_shapes(self):
@@ -234,7 +234,7 @@ class DictIterator(Iterator):
 
     def get_next(self):
         """
-        Returns the next record in the dataset as dictionary
+        Returns the next record in the dataset as a dictionary.
 
         Returns:
             Dict, the next record in the dataset.
@@ -260,7 +260,7 @@ class TupleIterator(Iterator):
 
     def get_next(self):
         """
-        Returns the next record in the dataset as a list
+        Returns the next record in the dataset as a list.
 
         Returns:
             List, the next record in the dataset.
diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py
index 0e44b618579b53070135b51ca37e45373c77f995..d1471f47cb72b2338af270171b44ccdba06ee374 100644
--- a/mindspore/mindrecord/filewriter.py
+++ b/mindspore/mindrecord/filewriter.py
@@ -328,13 +328,20 @@ class FileWriter:
         self._generator.build()
         self._generator.write_to_db()
 
+        mindrecord_files = []
+        index_files = []
         # change the file mode to 600
         for item in self._paths:
             if os.path.exists(item):
                 os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
+                mindrecord_files.append(item)
             index_file = item + ".db"
             if os.path.exists(index_file):
                 os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR)
+                index_files.append(index_file)
+
+        logger.info("The list of mindrecord files created is: {}, and the list of index files is: {}".format(
+            mindrecord_files, index_files))
 
         return ret
 
diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py
index 99b31b64ec2af1aff71e1e2aff4007d6f59e51d7..8b8cbc807a9ed5ebe63a0796f1d2fe1cae3a35ae 100644
--- a/tests/ut/python/dataset/test_minddataset.py
+++ b/tests/ut/python/dataset/test_minddataset.py
@@ -25,6 +25,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
 import numpy as np
 import pytest
 from mindspore._c_dataengine import InterpolationMode
+from mindspore.dataset.transforms.vision import Inter
 from mindspore import log as logger
 
 import mindspore.dataset as ds
@@ -151,6 +152,52 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 3
 
 
+def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
+    """test that repeat reshuffles the records between epochs."""
+    columns_list = ["data", "label"]
+    num_readers = 4
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
+    decode_op = vision.Decode()
+    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
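+    # resize after decode so all images share a fixed 32x32 shape before batching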
+    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
+    data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
+    data_set = data_set.batch(2)
+    data_set = data_set.repeat(2)
+    num_iter = 0
+    labels = []
+    for item in data_set.create_dict_iterator():
+        logger.info("-------------- get dataset size {} -----------------".format(num_iter))
+        logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
+        logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
+        num_iter += 1
+        labels.append(item["label"].tolist())
+    assert num_iter == 10
+    logger.info("repeat shuffle: {}".format(labels))
+    assert len(labels) == 10
+    # reshuffling between repeats should change the batch order across the two epochs
+    assert labels[0:5] != labels[5:10]
+
+
+def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
+    """test that a batch size larger than the record count yields no batches with drop_remainder."""
+    columns_list = ["data", "label"]
+    num_readers = 4
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
+    decode_op = vision.Decode()
+    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
+    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
+    data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
+    data_set = data_set.batch(32, drop_remainder=True)
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info("-------------- get dataset size {} -----------------".format(num_iter))
+        logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
+        logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
+        num_iter += 1
+    assert num_iter == 0
+
+
 def test_cv_minddataset_issue_888(add_and_remove_cv_file):
     """issue 888 test."""
     columns_list = ["data", "label"]
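
For context, a minimal usage sketch of where the new FileWriter.commit() logging surfaces, assuming the
public FileWriter API (add_schema/write_raw_data/commit); the output path, shard count, and schema below
are hypothetical:

    from mindspore.mindrecord import FileWriter

    # hypothetical two-shard writer; shard files are suffixed 0..shard_num-1
    writer = FileWriter(file_name="demo.mindrecord", shard_num=2)
    writer.add_schema({"data": {"type": "bytes"}, "label": {"type": "int32"}}, "demo schema")
    writer.write_raw_data([{"data": b"\x00\x01", "label": 0}, {"data": b"\x02\x03", "label": 1}])
    # with this patch, commit() also logs the lists of created mindrecord
    # files and their .db index files at INFO level
    writer.commit()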