提交 219a716e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3066 fix BatchDataset's get_dataset_size and several text validator inconsistencies

Merge pull request !3066 from ZiruiWu/fix_validator
......@@ -1563,7 +1563,7 @@ class BatchDataset(DatasetOp):
Number, number of batches.
"""
child_size = self.children[0].get_dataset_size()
if child_size is not None:
if child_size is not None and isinstance(self.batch_size, int):
if self.drop_remainder:
return math.floor(child_size / self.batch_size)
return math.ceil(child_size / self.batch_size)
......@@ -3915,7 +3915,6 @@ class RandomDataset(SourceDataset):
return self.sampler.is_sharded()
class Schema:
"""
Class to represent a schema of dataset.
......
......@@ -23,7 +23,8 @@ import mindspore._c_dataengine as cde
from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
INT32_MAX, check_value
INT32_MAX, check_value, check_positive
def check_unique_list_of_words(words, arg_name):
"""Check that words is a list and each element is a str without any duplication"""
......@@ -109,7 +110,7 @@ def check_from_dict(method):
for word, word_id in word_dict.items():
type_check(word, (str,), "word")
type_check(word_id, (int,), "word_id")
check_value(word_id, (-1, INT32_MAX), "word_id")
check_value(word_id, (0, INT32_MAX), "word_id")
return method(self, *args, **kwargs)
return new_method
......@@ -196,7 +197,7 @@ def check_wordpiece_tokenizer(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ =\
[vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
parse_user_args(method, *args, **kwargs)
if vocab is None:
raise ValueError("vocab is not provided.")
......@@ -238,7 +239,7 @@ def check_basic_tokenizer(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ =\
[lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
parse_user_args(method, *args, **kwargs)
if not isinstance(lower_case, bool):
raise TypeError("Wrong input type for lower_case, should be boolean.")
......@@ -317,7 +318,7 @@ def check_from_dataset(method):
type_check(top_k, (int, type(None)), "top_k")
if isinstance(top_k, int):
check_value(top_k, (0, INT32_MAX), "top_k")
check_positive(top_k, "top_k")
type_check(special_first, (bool,), "special_first")
if special_tokens is not None:
......@@ -343,7 +344,7 @@ def check_ngram(method):
for i, gram in enumerate(n):
type_check(gram, (int,), "gram[{0}]".format(i))
check_value(gram, (0, INT32_MAX), "gram_{}".format(i))
check_positive(gram, "gram_{}".format(i))
if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
left_pad[1], int)):
......
......@@ -128,7 +128,7 @@ def test_from_dataset_exceptions():
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
assert isinstance(vocab.text.Vocab)
except (TypeError, ValueError, RuntimeError) as e:
except (TypeError, ValueError) as e:
assert s in str(e), str(e)
test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.")
......@@ -136,8 +136,8 @@ def test_from_dataset_exceptions():
"Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)")
test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)")
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")
test_config("text", (2, 3), 0, "top_k needs to be positive number")
test_config([123], (2, 3), 0, "top_k needs to be positive number")
test_config("text", (2, 3), 0, "top_k must be greater than 0")
test_config([123], (2, 3), -1, "top_k must be greater than 0")
if __name__ == '__main__':
......
......@@ -72,43 +72,36 @@ def test_simple_ngram():
def test_corner_cases():
""" testing various corner cases and exceptions"""
def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
def test_config(input_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
def gen(texts):
yield (np.array(texts.split(" "), dtype='S'),)
dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep))
for data in dataset.create_dict_iterator():
assert [d.decode("utf8") for d in data["text"]] == output_line, output_line
try:
dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep))
for data in dataset.create_dict_iterator():
return [d.decode("utf8") for d in data["text"]]
except (ValueError, TypeError) as e:
return str(e)
# test tensor length smaller than n
test_config("Lone Star", ["Lone Star", "", "", ""], [2, 3, 4, 5])
assert test_config("Lone Star", [2, 3, 4, 5]) == ["Lone Star", "", "", ""]
# test empty separator
test_config("Beautiful British Columbia", ['BeautifulBritish', 'BritishColumbia'], 2, sep="")
assert test_config("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia']
# test separator with longer length
test_config("Beautiful British Columbia", ['Beautiful^-^British^-^Columbia'], 3, sep="^-^")
assert test_config("Beautiful British Columbia", 3, sep="^-^") == ['Beautiful^-^British^-^Columbia']
# test left pad != right pad
test_config("Lone Star", ['The Lone Star State'], 4, ("The", 1), ("State", 1))
assert test_config("Lone Star", 4, ("The", 1), ("State", 1)) == ['The Lone Star State']
# test invalid n
try:
test_config("Yours to Discover", "", [0, [1]])
except Exception as e:
assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e)
# test empty n
try:
test_config("Yours to Discover", "", [])
except Exception as e:
assert "n needs to be a non-empty list" in str(e)
assert "gram[1] with value [1] is not of type (<class 'int'>,)" in test_config("Yours to Discover", [1, [1]])
assert "n needs to be a non-empty list" in test_config("Yours to Discover", [])
# test invalid pad
try:
test_config("Yours to Discover", "", [1], ("str", -1))
except Exception as e:
assert "padding width need to be positive numbers" in str(e)
# test invalid pad
try:
test_config("Yours to Discover", "", [1], ("str", "rts"))
except Exception as e:
assert "pad needs to be a tuple of (str, int)" in str(e)
assert "padding width need to be positive numbers" in test_config("Yours to Discover", [1], ("str", -1))
assert "pad needs to be a tuple of (str, int)" in test_config("Yours to Discover", [1], ("str", "rts"))
# test 0 as in valid input
assert "gram_0 must be greater than 0" in test_config("Yours to Discover", 0)
assert "gram_0 must be greater than 0" in test_config("Yours to Discover", [0])
assert "gram_1 must be greater than 0" in test_config("Yours to Discover", [1, 0])
if __name__ == '__main__':
......
......@@ -60,6 +60,15 @@ def test_from_dict_tutorial():
ind += 1
def test_from_dict_exception():
    """Check that Vocab.from_dict rejects a negative word id.

    word_id is validated against the interval [0, INT32_MAX], so passing
    -1 must raise a ValueError whose message mentions the required interval.
    """
    try:
        vocab = text.Vocab.from_dict({"home": -1, "behind": 0})
        if not vocab:
            raise ValueError("Vocab is None")
    except ValueError as e:
        assert "is not within the required interval" in str(e)
    else:
        # Original test passed silently when no exception was raised;
        # make the missing-validation case an explicit failure.
        raise AssertionError("Vocab.from_dict with word_id=-1 unexpectedly succeeded")
def test_from_list():
def gen(texts):
for word in texts.split(" "):
......@@ -74,13 +83,11 @@ def test_from_list():
for d in data.create_dict_iterator():
res.append(d["text"].item())
return res
except ValueError as e:
return str(e)
except RuntimeError as e:
return str(e)
except TypeError as e:
except (ValueError, RuntimeError, TypeError) as e:
return str(e)
# test basic default config, special_token=None, unknown_token=None
assert test_config("w1 w2 w3", ["w1", "w2", "w3"], None, True, None) == [0, 1, 2]
# test normal operations
assert test_config("w1 w2 w3 s1 s2 ephemeral", ["w1", "w2", "w3"], ["s1", "s2"], True, "s2") == [2, 3, 4, 0, 1, 1]
assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False, "s2") == [0, 1, 2, 3, 4]
......@@ -129,6 +136,7 @@ def test_from_file():
if __name__ == '__main__':
test_from_dict_exception()
test_from_list_tutorial()
test_from_file_tutorial()
test_from_dict_tutorial()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册