diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index adfe54a02e02d8ce474d342aeba42e00cef47624..b4d22a4a013adcb23da23b58c713540179cc9b0f 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -243,6 +243,8 @@ def check_param_type(param_list, param_dict, param_type):
         if param_dict.get(param_name) is not None:
             if param_name == 'num_parallel_workers':
                 check_num_parallel_workers(param_dict.get(param_name))
+            if param_name == 'num_samples':
+                check_num_samples(param_dict.get(param_name))
             else:
                 check_type(param_dict.get(param_name), param_name, param_type)
 
@@ -262,6 +264,12 @@ def check_num_parallel_workers(value):
         raise ValueError("num_parallel_workers exceeds the boundary between 0 and {}!".format(cpu_count()))
 
 
+def check_num_samples(value):
+    check_type(value, 'num_samples', int)
+    if value <= 0:
+        raise ValueError("num_samples must be greater than 0!")
+
+
 def check_dataset_dir(dataset_dir):
     if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
         raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
diff --git a/tests/ut/python/dataset/test_datasets_sharding.py b/tests/ut/python/dataset/test_datasets_sharding.py
index b178298e33e1fd413e0f2dc61f6e9d054b8461b2..b398391fb70008cc913dedaeb72e3c5129b6eab3 100644
--- a/tests/ut/python/dataset/test_datasets_sharding.py
+++ b/tests/ut/python/dataset/test_datasets_sharding.py
@@ -33,14 +33,14 @@ def test_imagefolder_shardings(print_res=False):
     # total 44 rows in dataset
     assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1])  # 5 rows
     assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3])  # 11 rows
-    assert (sharding_config(4, 3, 0, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])  # 11 rows
+    assert (sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])  # 11 rows
     # total 22 in dataset rows because of class indexing which takes only 2 folders
-    assert (len(sharding_config(4, 0, 0, True, {"class1": 111, "class2": 999})) == 6)
+    assert (len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6)
     assert (len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3)
     # test with repeat
     assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3)
     assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5)
-    assert (len(sharding_config(5, 1, 0, True, {"class1": 111, "class2": 999}, 4)) == 20)
+    assert (len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20)
 
 
 def test_manifest_shardings(print_res=False):
diff --git a/tests/ut/python/dataset/test_exceptions.py b/tests/ut/python/dataset/test_exceptions.py
index 7668eeb2a8d2d5b69d2304a594ed8aba8ef70264..631f2ddcbcb671bdd0e1b764fe8e3858d2823cbe 100644
--- a/tests/ut/python/dataset/test_exceptions.py
+++ b/tests/ut/python/dataset/test_exceptions.py
@@ -18,6 +18,8 @@ import pytest
 import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as vision
 
 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 
 
 def skip_test_exception():
@@ -29,5 +31,23 @@
     assert "The shape size 1 of input tensor is invalid" in str(info.value)
 
 
+def test_sample_exception():
+    num_samples = 0
+    with pytest.raises(ValueError) as info:
+        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
+    assert "num_samples must be greater than 0" in str(info.value)
+    num_samples = -1
+    with pytest.raises(ValueError) as info:
+        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
+    assert "num_samples must be greater than 0" in str(info.value)
+    num_samples = 1
+    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
+    data = data.map(input_columns=["image"], operations=vision.Decode())
+    data = data.map(input_columns=["image"], operations=vision.Resize((100, 100)))
+    num_iters = 0
+    for item in data.create_dict_iterator():
+        num_iters += 1
+    assert num_iters == 1
+
 if __name__ == '__main__':
     test_exception()