diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index 9b8c130a829cf1bf9b3854348d4cb961ca7073a3..2e69e2f0ec9c77f3a228f1ad3c569c4b55dea690 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -534,6 +534,7 @@ def check_minddataset(method):
                 check_dataset_file(f)
         else:
             check_dataset_file(dataset_file)
+
         check_param_type(nreq_param_int, param_dict, int)
 
         check_param_type(nreq_param_list, param_dict, list)
@@ -544,6 +545,8 @@ def check_minddataset(method):
         if (num_shards is not None and shard_id is None) or (num_shards is None and shard_id is not None):
             raise ValueError("num_shards and shard_id need to be set or not set at the same time")
 
+        check_sampler_shuffle_shard_options(param_dict)
+
         return method(*args, **kwargs)
 
     return new_method
diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py
index 53a719a985b284742340bcf4437a5706d707fd03..c9f56eb40604852b14fc5814d7c209a7e2f027dd 100644
--- a/tests/ut/python/dataset/test_minddataset_exception.py
+++ b/tests/ut/python/dataset/test_minddataset_exception.py
@@ -128,7 +128,7 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
     sampler = ds.PKSampler(2)
-    with pytest.raises(Exception, match="shuffle not allowed when use sampler"):
+    with pytest.raises(Exception, match="sampler and shuffle cannot be specified at the same time."):
         data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler,
                                   shuffle=False)
         num_iter = 0
@@ -168,3 +168,46 @@ def test_cv_minddataset_reader_different_page_size():
     os.remove("{}.db".format(CV_FILE_NAME))
     os.remove(CV1_FILE_NAME)
     os.remove("{}.db".format(CV1_FILE_NAME))
+
+def test_minddataset_invalidate_num_shards():
+    create_cv_mindrecord(1)
+    columns_list = ["data", "label"]
+    num_readers = 4
+    with pytest.raises(Exception, match="shard_id is invalid, "):
+        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1)
+        num_iter = 0
+        for item in data_set.create_dict_iterator():
+            num_iter += 1
+    os.remove(CV_FILE_NAME)
+    os.remove("{}.db".format(CV_FILE_NAME))
+
+def test_minddataset_invalidate_shard_id():
+    create_cv_mindrecord(1)
+    columns_list = ["data", "label"]
+    num_readers = 4
+    with pytest.raises(Exception, match="shard_id is invalid, "):
+        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
+        num_iter = 0
+        for item in data_set.create_dict_iterator():
+            num_iter += 1
+    os.remove(CV_FILE_NAME)
+    os.remove("{}.db".format(CV_FILE_NAME))
+
+def test_minddataset_shard_id_bigger_than_num_shard():
+    create_cv_mindrecord(1)
+    columns_list = ["data", "label"]
+    num_readers = 4
+    with pytest.raises(Exception, match="shard_id is invalid, "):
+        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
+        num_iter = 0
+        for item in data_set.create_dict_iterator():
+            num_iter += 1
+
+    with pytest.raises(Exception, match="shard_id is invalid, "):
+        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
+        num_iter = 0
+        for item in data_set.create_dict_iterator():
+            num_iter += 1
+
+    os.remove(CV_FILE_NAME)
+    os.remove("{}.db".format(CV_FILE_NAME))