提交 4f575500 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!29 Add some prompt information for ease of use

Merge pull request !29 from jonyguo/add_more_log_info_and_testcase
......@@ -66,11 +66,10 @@ def _alter_node(node):
class Iterator:
"""
General Iterator over a dataset.
Attributes:
dataset: Dataset to be iterated over
General Iterator over a dataset.
Attributes:
dataset: Dataset to be iterated over
"""
def __init__(self, dataset):
......@@ -86,6 +85,7 @@ class Iterator:
root = self.__convert_node_postorder(self.dataset)
self.depipeline.AssignRootNode(root)
self.depipeline.LaunchTreeExec()
self._index = 0
def __is_tree_node(self, node):
"""Check if a node is tree node."""
......@@ -185,10 +185,7 @@ class Iterator:
Iterator.__print_local(input_op, level + 1)
def print(self):
"""
Print the dataset tree
"""
"""Print the dataset tree"""
self.__print_local(self.dataset, 0)
def release(self):
......@@ -202,7 +199,10 @@ class Iterator:
def __next__(self):
data = self.get_next()
if not data:
if self._index == 0:
logger.warning("No records available.")
raise StopIteration
self._index += 1
return data
def get_output_shapes(self):
......@@ -234,7 +234,7 @@ class DictIterator(Iterator):
def get_next(self):
"""
Returns the next record in the dataset as dictionary
Returns the next record in the dataset as dictionary
Returns:
Dict, the next record in the dataset.
......@@ -260,7 +260,7 @@ class TupleIterator(Iterator):
def get_next(self):
"""
Returns the next record in the dataset as a list
Returns the next record in the dataset as a list
Returns:
List, the next record in the dataset.
......
......@@ -328,13 +328,20 @@ class FileWriter:
self._generator.build()
self._generator.write_to_db()
mindrecord_files = []
index_files = []
# change the file mode to 600
for item in self._paths:
if os.path.exists(item):
os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
mindrecord_files.append(item)
index_file = item + ".db"
if os.path.exists(index_file):
os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR)
index_files.append(index_file)
logger.info("The list of mindrecord files created are: {}, and the list of index files are: {}".format(
mindrecord_files, index_files))
return ret
......
......@@ -25,6 +25,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
import numpy as np
import pytest
from mindspore._c_dataengine import InterpolationMode
from mindspore.dataset.transforms.vision import Inter
from mindspore import log as logger
import mindspore.dataset as ds
......@@ -151,6 +152,51 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
assert data_set.get_dataset_size() == 3
def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "label"]
num_readers = 4
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
decode_op = vision.Decode()
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
data_set = data_set.batch(2)
data_set = data_set.repeat(2)
num_iter = 0
labels = []
for item in data_set.create_dict_iterator():
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
num_iter += 1
labels.append(item["label"])
assert num_iter == 10
logger.info("repeat shuffle: {}".format(labels))
assert len(labels) == 10
assert labels[0:5] == labels[0:5]
assert labels[0:5] != labels[5:5]
def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "label"]
num_readers = 4
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
decode_op = vision.Decode()
data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
data_set = data_set.batch(32, drop_remainder=True)
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- get dataset size {} -----------------".format(num_iter))
logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
num_iter += 1
assert num_iter == 0
def test_cv_minddataset_issue_888(add_and_remove_cv_file):
"""issue 888 test."""
columns_list = ["data", "label"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册