From 7fa0d9e7e443aa342e42d4c8d44692d3df025e93 Mon Sep 17 00:00:00 2001 From: ms_yan <6576637+ms_yan@user.noreply.gitee.com> Date: Mon, 29 Jun 2020 17:58:51 +0800 Subject: [PATCH] add parameter check for numpyslices and num_shards --- mindspore/dataset/engine/datasets.py | 9 +++++++-- mindspore/dataset/engine/validators.py | 2 ++ tests/ut/python/dataset/test_minddataset_exception.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 49cab9002..b66fb70c1 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3069,7 +3069,7 @@ class GeneratorDataset(MappableDataset): sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is required (default=None, expected order behavior shown in the table). num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). - This argument should be specified only when 'num_samples' is "None". Random accessible input is required. + When this argument is specified, 'num_samples' will not take effect. Random accessible input is required. shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only when num_shards is also specified. Random accessible input is required. @@ -4878,6 +4878,11 @@ class _NumpySlicesDataset: else: self.data = (np.array(data),) + # check whether the data length in each column is equal + data_len = [len(data_item) for data_item in self.data] + if data_len[1:] != data_len[:-1]: + raise ValueError("Data length in each column is not equal.") + # Init column_name if column_list is not None: self.column_list = column_list @@ -4966,7 +4971,7 @@ class NumpySlicesDataset(GeneratorDataset): sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset.
Random accessible input is required (default=None, expected order behavior shown in the table). num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). - This argument should be specified only when 'num_samples' is "None". Random accessible input is required. + When this argument is specified, 'num_samples' will not take effect. Random accessible input is required. shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only when num_shards is also specified. Random accessible input is required. diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 76a5f1b81..9857608c1 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -153,6 +153,7 @@ def check_sampler_shuffle_shard_options(param_dict): raise RuntimeError("sampler and sharding cannot be specified at the same time.") if num_shards is not None: + check_positive_int32(num_shards, "num_shards") if shard_id is None: raise RuntimeError("num_shards is specified and currently requires shard_id as well.") if shard_id < 0 or shard_id >= num_shards: @@ -529,6 +530,7 @@ def check_generatordataset(method): # These two parameters appear together.
raise ValueError("num_shards and shard_id need to be passed in together") if num_shards is not None: + check_positive_int32(num_shards, "num_shards") if shard_id >= num_shards: raise ValueError("shard_id should be less than num_shards") diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index 14cffc8fb..b15944d76 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -185,7 +185,7 @@ def test_minddataset_invalidate_num_shards(): columns_list = ["data", "label"] num_readers = 4 with pytest.raises(Exception, match="shard_id is invalid, "): - data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1) + data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 -- GitLab