Commit 219a716e authored by mindspore-ci-bot, committed by Gitee

!3066 fix some batch's get_dataset_size and some text validator inconsistency

Merge pull request !3066 from ZiruiWu/fix_validator
@@ -1563,7 +1563,7 @@ class BatchDataset(DatasetOp):
             Number, number of batches.
         """
         child_size = self.children[0].get_dataset_size()
-        if child_size is not None:
+        if child_size is not None and isinstance(self.batch_size, int):
             if self.drop_remainder:
                 return math.floor(child_size / self.batch_size)
             return math.ceil(child_size / self.batch_size)
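Note: the added isinstance(self.batch_size, int) guard matters because batch() also accepts a callable batch_size (evaluated per batch, via BatchInfo in MindSpore's API), in which case the batch count cannot be derived by division. A minimal sketch of the patched logic, where batched_size is a hypothetical standalone stand-in for the method, not MindSpore API:

    import math

    def batched_size(child_size, batch_size, drop_remainder):
        # Mirrors the patched BatchDataset.get_dataset_size: only a plain int
        # batch_size allows computing the count; otherwise the size is unknown.
        if child_size is not None and isinstance(batch_size, int):
            if drop_remainder:
                return math.floor(child_size / batch_size)   # partial batch dropped
            return math.ceil(child_size / batch_size)        # partial batch kept
        return None

    assert batched_size(10, 3, True) == 3
    assert batched_size(10, 3, False) == 4
    assert batched_size(10, lambda batch_info: 2, False) is None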
@@ -3915,7 +3915,6 @@ class RandomDataset(SourceDataset):
         return self.sampler.is_sharded()

-
 class Schema:
     """
     Class to represent a schema of dataset.
...
@@ -23,7 +23,8 @@ import mindspore._c_dataengine as cde
 from mindspore._c_expression import typing
 from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
-    INT32_MAX, check_value
+    INT32_MAX, check_value, check_positive

 def check_unique_list_of_words(words, arg_name):
     """Check that words is a list and each element is a str without any duplication"""
@@ -109,7 +110,7 @@ def check_from_dict(method):
         for word, word_id in word_dict.items():
             type_check(word, (str,), "word")
             type_check(word_id, (int,), "word_id")
-            check_value(word_id, (-1, INT32_MAX), "word_id")
+            check_value(word_id, (0, INT32_MAX), "word_id")
         return method(self, *args, **kwargs)

     return new_method
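Note: the lower bound for word_id moves from -1 to 0, so negative ids passed to Vocab.from_dict are now rejected at validation time; the new test_from_dict_exception later in this diff exercises exactly this. A sketch of the check_value contract this hunk assumes (inclusive-interval check; the real helper lives in validator_helpers):

    def check_value(value, bounds, arg_name=""):
        # Assumed behavior: raise if value falls outside the inclusive [low, high]
        # interval, with a message matching "is not within the required interval".
        low, high = bounds
        if not low <= value <= high:
            raise ValueError("Input {0} is not within the required interval of "
                             "({1} to {2}).".format(arg_name, low, high))

    check_value(0, (0, 2147483647), "word_id")    # 0 is now the smallest legal id
    # check_value(-1, (0, 2147483647), "word_id") would raise ValueError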
@@ -196,7 +197,7 @@ def check_wordpiece_tokenizer(method):
     @wraps(method)
     def new_method(self, *args, **kwargs):
-        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ =\
+        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
             parse_user_args(method, *args, **kwargs)
         if vocab is None:
             raise ValueError("vocab is not provided.")
@@ -238,7 +239,7 @@ def check_basic_tokenizer(method):
     @wraps(method)
     def new_method(self, *args, **kwargs):
-        [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ =\
+        [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
             parse_user_args(method, *args, **kwargs)
         if not isinstance(lower_case, bool):
             raise TypeError("Wrong input type for lower_case, should be boolean.")
@@ -317,7 +318,7 @@ def check_from_dataset(method):
         type_check(top_k, (int, type(None)), "top_k")
         if isinstance(top_k, int):
-            check_value(top_k, (0, INT32_MAX), "top_k")
+            check_positive(top_k, "top_k")
         type_check(special_first, (bool,), "special_first")

         if special_tokens is not None:
@@ -343,7 +344,7 @@ def check_ngram(method):
         for i, gram in enumerate(n):
             type_check(gram, (int,), "gram[{0}]".format(i))
-            check_value(gram, (0, INT32_MAX), "gram_{}".format(i))
+            check_positive(gram, "gram_{}".format(i))
         if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
                 left_pad[1], int)):
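Note: check_value with a (0, INT32_MAX) range still admits 0, while the updated tests expect zero to fail ("top_k must be greater than 0", "gram_0 must be greater than 0"). A plausible minimal sketch of the new helper, assuming it sits beside check_value in validator_helpers:

    def check_positive(value, arg_name=""):
        # Assumed behavior: reject zero and negatives with a message containing
        # "<arg_name> must be greater than 0", as the updated tests assert.
        arg_name = arg_name + " " if arg_name else ""
        if value <= 0:
            raise ValueError("Input {0}must be greater than 0.".format(arg_name))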
...
@@ -128,7 +128,7 @@ def test_from_dataset_exceptions():
             data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
             vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
             assert isinstance(vocab, text.Vocab)
-        except (TypeError, ValueError, RuntimeError) as e:
+        except (TypeError, ValueError) as e:
             assert s in str(e), str(e)

     test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.")
@@ -136,8 +136,8 @@ def test_from_dataset_exceptions():
                 "Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)")
     test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)")
     test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")
-    test_config("text", (2, 3), 0, "top_k needs to be positive number")
-    test_config([123], (2, 3), 0, "top_k needs to be positive number")
+    test_config("text", (2, 3), 0, "top_k must be greater than 0")
+    test_config([123], (2, 3), -1, "top_k must be greater than 0")

 if __name__ == '__main__':
...
@@ -72,43 +72,36 @@ def test_simple_ngram():
 def test_corner_cases():
     """ testing various corner cases and exceptions"""

-    def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
+    def test_config(input_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
         def gen(texts):
             yield (np.array(texts.split(" "), dtype='S'),)

-        dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
-        dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep))
-        for data in dataset.create_dict_iterator():
-            assert [d.decode("utf8") for d in data["text"]] == output_line, output_line
-        return
+        try:
+            dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
+            dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep))
+            for data in dataset.create_dict_iterator():
+                return [d.decode("utf8") for d in data["text"]]
+        except (ValueError, TypeError) as e:
+            return str(e)

     # test tensor length smaller than n
-    test_config("Lone Star", ["Lone Star", "", "", ""], [2, 3, 4, 5])
+    assert test_config("Lone Star", [2, 3, 4, 5]) == ["Lone Star", "", "", ""]
     # test empty separator
-    test_config("Beautiful British Columbia", ['BeautifulBritish', 'BritishColumbia'], 2, sep="")
+    assert test_config("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia']
     # test separator with longer length
-    test_config("Beautiful British Columbia", ['Beautiful^-^British^-^Columbia'], 3, sep="^-^")
+    assert test_config("Beautiful British Columbia", 3, sep="^-^") == ['Beautiful^-^British^-^Columbia']
     # test left pad != right pad
-    test_config("Lone Star", ['The Lone Star State'], 4, ("The", 1), ("State", 1))
+    assert test_config("Lone Star", 4, ("The", 1), ("State", 1)) == ['The Lone Star State']
     # test invalid n
-    try:
-        test_config("Yours to Discover", "", [0, [1]])
-    except Exception as e:
-        assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e)
-    # test empty n
-    try:
-        test_config("Yours to Discover", "", [])
-    except Exception as e:
-        assert "n needs to be a non-empty list" in str(e)
+    assert "gram[1] with value [1] is not of type (<class 'int'>,)" in test_config("Yours to Discover", [1, [1]])
+    assert "n needs to be a non-empty list" in test_config("Yours to Discover", [])
     # test invalid pad
-    try:
-        test_config("Yours to Discover", "", [1], ("str", -1))
-    except Exception as e:
-        assert "padding width need to be positive numbers" in str(e)
-    # test invalid pad
-    try:
-        test_config("Yours to Discover", "", [1], ("str", "rts"))
-    except Exception as e:
-        assert "pad needs to be a tuple of (str, int)" in str(e)
+    assert "padding width need to be positive numbers" in test_config("Yours to Discover", [1], ("str", -1))
+    assert "pad needs to be a tuple of (str, int)" in test_config("Yours to Discover", [1], ("str", "rts"))
+    # test 0 as invalid input
+    assert "gram_0 must be greater than 0" in test_config("Yours to Discover", 0)
+    assert "gram_0 must be greater than 0" in test_config("Yours to Discover", [0])
+    assert "gram_1 must be greater than 0" in test_config("Yours to Discover", [1, 0])

 if __name__ == '__main__':
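Note: the refactor replaces the repeated try/except blocks with a test_config that returns either the decoded n-gram output or the validator's error text, so success and failure cases collapse into single asserts. The same pattern in isolation (run_or_error is a hypothetical name, not part of the test file):

    def run_or_error(fn, *args, **kwargs):
        # Return the result on success, or the error message on a validation failure.
        try:
            return fn(*args, **kwargs)
        except (ValueError, TypeError) as e:
            return str(e)

    assert run_or_error(int, "7") == 7
    assert "invalid literal" in run_or_error(int, "seven")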
...
@@ -60,6 +60,15 @@ def test_from_dict_tutorial():
         ind += 1

+def test_from_dict_exception():
+    try:
+        vocab = text.Vocab.from_dict({"home": -1, "behind": 0})
+        if not vocab:
+            raise ValueError("Vocab is None")
+    except ValueError as e:
+        assert "is not within the required interval" in str(e)
+
 def test_from_list():
     def gen(texts):
         for word in texts.split(" "):
@@ -74,13 +83,11 @@ def test_from_list():
             for d in data.create_dict_iterator():
                 res.append(d["text"].item())
             return res
-        except ValueError as e:
-            return str(e)
-        except RuntimeError as e:
-            return str(e)
-        except TypeError as e:
+        except (ValueError, RuntimeError, TypeError) as e:
             return str(e)

+    # test basic default config, special_token=None, unknown_token=None
+    assert test_config("w1 w2 w3", ["w1", "w2", "w3"], None, True, None) == [0, 1, 2]
     # test normal operations
     assert test_config("w1 w2 w3 s1 s2 ephemeral", ["w1", "w2", "w3"], ["s1", "s2"], True, "s2") == [2, 3, 4, 0, 1, 1]
     assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False, "s2") == [0, 1, 2, 3, 4]
@@ -129,6 +136,7 @@ def test_from_file():

 if __name__ == '__main__':
+    test_from_dict_exception()
     test_from_list_tutorial()
     test_from_file_tutorial()
     test_from_dict_tutorial()
...