diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 98d66e9764061a834d62586fb9d567a9d92d167c..29904f1a9ef25d3418feec13ff9f56e9dbd44ce4 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -25,7 +25,7 @@ from mindspore._c_expression import typing from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ - check_columns, check_positive, check_pos_int32 + check_columns, check_pos_int32 from . import datasets from . import samplers @@ -319,10 +319,9 @@ def check_generatordataset(method): # These two parameters appear together. raise ValueError("num_shards and shard_id need to be passed in together") if num_shards is not None: - type_check(num_shards, (int,), "num_shards") - check_positive(num_shards, "num_shards") + check_pos_int32(num_shards, "num_shards") if shard_id >= num_shards: - raise ValueError("shard_id should be less than num_shards") + raise ValueError("shard_id should be less than num_shards.") sampler = param_dict.get("sampler") if sampler is not None: @@ -417,7 +416,7 @@ def check_bucket_batch_by_length(method): all_non_negative = all(item > 0 for item in bucket_boundaries) if not all_non_negative: - raise ValueError("bucket_boundaries cannot contain any negative numbers.") + raise ValueError("bucket_boundaries must only contain positive numbers.") for i in range(len(bucket_boundaries) - 1): if not bucket_boundaries[i + 1] > bucket_boundaries[i]: @@ -1044,7 +1043,8 @@ def check_numpyslicesdataset(method): data = param_dict.get("data") column_names = param_dict.get("column_names") - + if not data: + raise ValueError("Argument data cannot be empty") type_check(data, (list, tuple, dict, np.ndarray), "data") if isinstance(data, tuple): type_check(data[0], (list, np.ndarray), "data[0]") diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index a93d569810fab309b0db7c20dbdf3d6032d4989e..14c0ffe7c178ce8db62547083103c84f4713f160 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -62,7 +62,8 @@ def check_from_file(method): def new_method(self, *args, **kwargs): [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs) - check_unique_list_of_words(special_tokens, "special_tokens") + if special_tokens is not None: + check_unique_list_of_words(special_tokens, "special_tokens") type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"]) if vocab_size is not None: check_value(vocab_size, (-1, INT32_MAX), "vocab_size") diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py index 5da7b1636da3090a79957e890e74d364b7d86573..405b8741103c622703bd4e7201620c31b519201a 100644 --- a/tests/ut/python/dataset/test_bucket_batch_by_length.py +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -45,6 +45,7 @@ def test_bucket_batch_invalid_input(): bucket_boundaries = [1, 2, 3] empty_bucket_boundaries = [] invalid_bucket_boundaries = ["1", "2", "3"] + zero_start_bucket_boundaries = [0, 2, 3] negative_bucket_boundaries = [1, 2, -3] decreasing_bucket_boundaries = [3, 2, 1] non_increasing_bucket_boundaries = [1, 2, 2] @@ -69,9 +70,13 @@ def test_bucket_batch_invalid_input(): _ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes) assert "bucket_boundaries should be a list of int" in str(info.value) + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, zero_start_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries must only contain positive numbers." in str(info.value) + with pytest.raises(ValueError) as info: _ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes) - assert "bucket_boundaries cannot contain any negative numbers" in str(info.value) + assert "bucket_boundaries must only contain positive numbers." in str(info.value) with pytest.raises(ValueError) as info: _ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes) diff --git a/tests/ut/python/dataset/test_concatenate_op.py b/tests/ut/python/dataset/test_concatenate_op.py index fa293c3b34f6022a70fb66abb52965cf9b56b788..f7a432e4716253256909c376c9c602853bf652e3 100644 --- a/tests/ut/python/dataset/test_concatenate_op.py +++ b/tests/ut/python/dataset/test_concatenate_op.py @@ -108,7 +108,7 @@ def test_concatenate_op_type_mismatch(): with pytest.raises(RuntimeError) as error_info: for _ in data: pass - assert "Tensor types do not match" in repr(error_info.value) + assert "Tensor types do not match" in str(error_info.value) def test_concatenate_op_type_mismatch2(): @@ -123,7 +123,7 @@ def test_concatenate_op_type_mismatch2(): with pytest.raises(RuntimeError) as error_info: for _ in data: pass - assert "Tensor types do not match" in repr(error_info.value) + assert "Tensor types do not match" in str(error_info.value) def test_concatenate_op_incorrect_dim(): @@ -138,13 +138,13 @@ def test_concatenate_op_incorrect_dim(): with pytest.raises(RuntimeError) as error_info: for _ in data: pass - assert "Only 1D tensors supported" in repr(error_info.value) + assert "Only 1D tensors supported" in str(error_info.value) def test_concatenate_op_wrong_axis(): with pytest.raises(ValueError) as error_info: data_trans.Concatenate(2) - assert "only 1D concatenation supported." in repr(error_info.value) + assert "only 1D concatenation supported." in str(error_info.value) def test_concatenate_op_negative_axis(): @@ -167,7 +167,7 @@ def test_concatenate_op_incorrect_input_dim(): with pytest.raises(ValueError) as error_info: data_trans.Concatenate(0, prepend_tensor) - assert "can only prepend 1D arrays." in repr(error_info.value) + assert "can only prepend 1D arrays." in str(error_info.value) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_dataset_numpy_slices.py b/tests/ut/python/dataset/test_dataset_numpy_slices.py index fe773b0328f5a910375f425f53cbfe9e7225e8f1..791a5674088472be7ae395f528d8b910208b7b21 100644 --- a/tests/ut/python/dataset/test_dataset_numpy_slices.py +++ b/tests/ut/python/dataset/test_dataset_numpy_slices.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -import numpy as np +import sys import pytest +import numpy as np +import pandas as pd import mindspore.dataset as de from mindspore import log as logger import mindspore.dataset.transforms.vision.c_transforms as vision -import pandas as pd def test_numpy_slices_list_1(): @@ -173,6 +174,25 @@ def test_numpy_slices_distributed_sampler(): assert sum([1 for _ in ds]) == 2 +def test_numpy_slices_distributed_shard_limit(): + logger.info("Test Slicing a 1D list.") + + np_data = [1, 2, 3] + num = sys.maxsize + with pytest.raises(ValueError) as err: + de.NumpySlicesDataset(np_data, num_shards=num, shard_id=0, shuffle=False) + assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value) + + +def test_numpy_slices_distributed_zero_shard(): + logger.info("Test Slicing a 1D list.") + + np_data = [1, 2, 3] + with pytest.raises(ValueError) as err: + de.NumpySlicesDataset(np_data, num_shards=0, shard_id=0, shuffle=False) + assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value) + + def test_numpy_slices_sequential_sampler(): logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.") @@ -210,6 +230,15 @@ def test_numpy_slices_invalid_empty_column_names(): assert "column_names should not be empty" in str(err.value) +def test_numpy_slices_invalid_empty_data_column(): + logger.info("Test incorrect column_names input") + np_data = [] + + with pytest.raises(ValueError) as err: + de.NumpySlicesDataset(np_data, shuffle=False) + assert "Argument data cannot be empty" in str(err.value) + + if __name__ == "__main__": test_numpy_slices_list_1() test_numpy_slices_list_2() @@ -223,7 +252,10 @@ if __name__ == "__main__": test_numpy_slices_csv_dict() test_numpy_slices_num_samplers() test_numpy_slices_distributed_sampler() + test_numpy_slices_distributed_shard_limit() + test_numpy_slices_distributed_zero_shard() test_numpy_slices_sequential_sampler() test_numpy_slices_invalid_column_names_type() test_numpy_slices_invalid_column_names_string() test_numpy_slices_invalid_empty_column_names() + test_numpy_slices_invalid_empty_data_column() diff --git a/tests/ut/python/dataset/test_fill_op.py b/tests/ut/python/dataset/test_fill_op.py index f138dd15ec9373c6319c66d36cf4bd09b3e9d2a8..657a5297235b993081794a91f530807254cfbf2a 100644 --- a/tests/ut/python/dataset/test_fill_op.py +++ b/tests/ut/python/dataset/test_fill_op.py @@ -82,9 +82,9 @@ def test_fillop_error_handling(): data = data.map(input_columns=["col"], operations=fill_op) with pytest.raises(RuntimeError) as error_info: - for data_row in data: - print(data_row) - assert "Types do not match" in repr(error_info.value) + for _ in data: + pass + assert "Types do not match" in str(error_info.value) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index 5ecaeff13ac16d07349270115cf29d0f0dad926b..0b4d0dfc8fe9a2438682bdb51d00c81b78f306e3 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) + assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) @@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) + assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) @@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info) + assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info) + assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) diff --git a/tests/ut/python/dataset/test_nlp.py b/tests/ut/python/dataset/test_nlp.py index 0678316f7bf12df940e16cd4e8733046886f031e..cb517160a19f1d0b536d1768ef553c85943f3fcb 100644 --- a/tests/ut/python/dataset/test_nlp.py +++ b/tests/ut/python/dataset/test_nlp.py @@ -39,8 +39,27 @@ def test_on_tokenized_line(): res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14], [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32) for i, d in enumerate(data.create_dict_iterator()): - _ = (np.testing.assert_array_equal(d["text"], res[i]), i) + np.testing.assert_array_equal(d["text"], res[i]) + + +def test_on_tokenized_line_with_no_special_tokens(): + data = ds.TextFileDataset("../data/dataset/testVocab/lines.txt", shuffle=False) + jieba_op = text.JiebaTokenizer(HMM_FILE, MP_FILE, mode=text.JiebaMode.MP) + with open(VOCAB_FILE, 'r') as f: + for line in f: + word = line.split(',')[0] + jieba_op.add_word(word) + + data = data.map(input_columns=["text"], operations=jieba_op) + vocab = text.Vocab.from_file(VOCAB_FILE, ",") + lookup = text.Lookup(vocab, "not") + data = data.map(input_columns=["text"], operations=lookup) + res = np.array([[8, 0, 9, 0, 10, 0, 13, 0, 11, 0, 12], + [9, 0, 10, 0, 8, 0, 12, 0, 11, 0, 13]], dtype=np.int32) + for i, d in enumerate(data.create_dict_iterator()): + np.testing.assert_array_equal(d["text"], res[i]) if __name__ == '__main__': test_on_tokenized_line() + test_on_tokenized_line_with_no_special_tokens() diff --git a/tests/ut/python/dataset/test_sync_wait.py b/tests/ut/python/dataset/test_sync_wait.py index a5727a299117940c455914b6b3fb4e4ada1f4b8c..eb2261a5d345a7cb8e6c74535dafc7122658dca3 100644 --- a/tests/ut/python/dataset/test_sync_wait.py +++ b/tests/ut/python/dataset/test_sync_wait.py @@ -14,7 +14,7 @@ # ============================================================================== import numpy as np - +import pytest import mindspore.dataset as ds from mindspore import log as logger @@ -163,7 +163,6 @@ def test_sync_exception_01(): """ logger.info("test_sync_exception_01") shuffle_size = 4 - batch_size = 10 dataset = ds.GeneratorDataset(gen, column_names=["input"]) @@ -171,11 +170,9 @@ def test_sync_exception_01(): dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) - try: - dataset = dataset.shuffle(shuffle_size) - except Exception as e: - assert "shuffle" in str(e) - dataset = dataset.batch(batch_size) + with pytest.raises(RuntimeError) as e: + dataset.shuffle(shuffle_size) + assert "No shuffle after sync operators" in str(e.value) def test_sync_exception_02(): @@ -183,7 +180,6 @@ def test_sync_exception_02(): Test sync: with duplicated condition name """ logger.info("test_sync_exception_02") - batch_size = 6 dataset = ds.GeneratorDataset(gen, column_names=["input"]) @@ -192,11 +188,9 @@ def test_sync_exception_02(): dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) - try: - dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") - except Exception as e: - assert "name" in str(e) - dataset = dataset.batch(batch_size) + with pytest.raises(RuntimeError) as e: + dataset.sync_wait(num_batch=2, condition_name="every batch") + assert "Condition name is already in use" in str(e.value) def test_sync_exception_03(): @@ -209,12 +203,9 @@ def test_sync_exception_03(): aug = Augment(0) # try to create dataset with batch_size < 0 - try: - dataset = dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update) - except Exception as e: - assert "num_batch" in str(e) - - dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + with pytest.raises(ValueError) as e: + dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update) + assert "num_batch need to be greater than 0." in str(e.value) def test_sync_exception_04(): @@ -230,14 +221,13 @@ def test_sync_exception_04(): dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) count = 0 - try: + with pytest.raises(RuntimeError) as e: for _ in dataset.create_dict_iterator(): count += 1 data = {"loss": count} - # dataset.disable_sync() dataset.sync_update(condition_name="every batch", num_batch=-1, data=data) - except Exception as e: - assert "batch" in str(e) + assert "Sync_update batch size can only be positive" in str(e.value) + def test_sync_exception_05(): """ @@ -251,15 +241,15 @@ def test_sync_exception_05(): # try to create dataset with batch_size < 0 dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) - try: + with pytest.raises(RuntimeError) as e: for _ in dataset.create_dict_iterator(): dataset.disable_sync() count += 1 data = {"loss": count} dataset.disable_sync() dataset.sync_update(condition_name="every", data=data) - except Exception as e: - assert "name" in str(e) + assert "Condition name not found" in str(e.value) + if __name__ == "__main__": test_simple_sync_wait() diff --git a/tests/ut/python/dataset/test_uniform_augment.py b/tests/ut/python/dataset/test_uniform_augment.py index 2edd832d79aacdb7dc6df4092ce20e4102ac06bf..e5b66696eaf333fac64889d3032c7a20a6963775 100644 --- a/tests/ut/python/dataset/test_uniform_augment.py +++ b/tests/ut/python/dataset/test_uniform_augment.py @@ -16,6 +16,7 @@ Testing UniformAugment in DE """ import numpy as np +import pytest import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.c_transforms as C @@ -164,14 +165,13 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2): C.RandomRotation(degrees=45), F.Invert()] - try: + with pytest.raises(TypeError) as e: _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) - except Exception as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "Argument tensor_op_5 with value" \ - " ,)" in str(e) + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Argument tensor_op_5 with value" \ + " ,)" in str(e.value) def test_cpp_uniform_augment_exception_large_numops(num_ops=6):