提交 7fa0d9e7 编写于 作者: M ms_yan

add parameter check for numpyslices and num_shards

上级 2f1b0dc5
...@@ -3069,7 +3069,7 @@ class GeneratorDataset(MappableDataset): ...@@ -3069,7 +3069,7 @@ class GeneratorDataset(MappableDataset):
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
required (default=None, expected order behavior shown in the table). required (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
This argument should be specified only when 'num_samples' is "None". Random accessible input is required. When this argument is specified, 'num_samples' will not take effect. Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required. when num_shards is also specified. Random accessible input is required.
...@@ -4878,6 +4878,11 @@ class _NumpySlicesDataset: ...@@ -4878,6 +4878,11 @@ class _NumpySlicesDataset:
else: else:
self.data = (np.array(data),) self.data = (np.array(data),)
# check whether the data length in each column is equal
data_len = [len(data_item) for data_item in self.data]
if data_len[1:] != data_len[:-1]:
raise ValueError("Data length in each column is not equal.")
# Init column_name # Init column_name
if column_list is not None: if column_list is not None:
self.column_list = column_list self.column_list = column_list
...@@ -4966,7 +4971,7 @@ class NumpySlicesDataset(GeneratorDataset): ...@@ -4966,7 +4971,7 @@ class NumpySlicesDataset(GeneratorDataset):
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
required (default=None, expected order behavior shown in the table). required (default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
This argument should be specified only when 'num_samples' is "None". Random accessible input is required. When this argument is specified, 'num_samples' will not take effect. Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required. when num_shards is also specified. Random accessible input is required.
......
...@@ -153,6 +153,7 @@ def check_sampler_shuffle_shard_options(param_dict): ...@@ -153,6 +153,7 @@ def check_sampler_shuffle_shard_options(param_dict):
raise RuntimeError("sampler and sharding cannot be specified at the same time.") raise RuntimeError("sampler and sharding cannot be specified at the same time.")
if num_shards is not None: if num_shards is not None:
check_positive_int32(num_shards, "num_shards")
if shard_id is None: if shard_id is None:
raise RuntimeError("num_shards is specified and currently requires shard_id as well.") raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
if shard_id < 0 or shard_id >= num_shards: if shard_id < 0 or shard_id >= num_shards:
...@@ -529,6 +530,7 @@ def check_generatordataset(method): ...@@ -529,6 +530,7 @@ def check_generatordataset(method):
# These two parameters appear together. # These two parameters appear together.
raise ValueError("num_shards and shard_id need to be passed in together") raise ValueError("num_shards and shard_id need to be passed in together")
if num_shards is not None: if num_shards is not None:
check_positive_int32(num_shards, "num_shards")
if shard_id >= num_shards: if shard_id >= num_shards:
raise ValueError("shard_id should be less than num_shards") raise ValueError("shard_id should be less than num_shards")
......
...@@ -185,7 +185,7 @@ def test_minddataset_invalidate_num_shards(): ...@@ -185,7 +185,7 @@ def test_minddataset_invalidate_num_shards():
columns_list = ["data", "label"] columns_list = ["data", "label"]
num_readers = 4 num_readers = 4
with pytest.raises(Exception, match="shard_id is invalid, "): with pytest.raises(Exception, match="shard_id is invalid, "):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1) data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
num_iter = 0 num_iter = 0
for _ in data_set.create_dict_iterator(): for _ in data_set.create_dict_iterator():
num_iter += 1 num_iter += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册