magicwindyyd / mindspore, a fork of MindSpore / mindspore (in sync with the upstream project)
Commit 219a716e
Authored on Jul 15, 2020 by mindspore-ci-bot; committed via Gitee on Jul 15, 2020

!3066 fix some batch's get_dataset_size and some text validator inconsistency

Merge pull request !3066 from ZiruiWu/fix_validator

Parents: bed93a9e, 63185cb2
Showing 5 changed files with 44 additions and 43 deletions.
mindspore/dataset/engine/datasets.py          +1   -2
mindspore/dataset/text/validators.py          +7   -6
tests/ut/python/dataset/test_from_dataset.py  +3   -3
tests/ut/python/dataset/test_ngram_op.py      +20  -27
tests/ut/python/dataset/test_vocab.py         +13  -5
mindspore/dataset/engine/datasets.py

@@ -1563,7 +1563,7 @@ class BatchDataset(DatasetOp):
             Number, number of batches.
         """
         child_size = self.children[0].get_dataset_size()
-        if child_size is not None:
+        if child_size is not None and isinstance(self.batch_size, int):
             if self.drop_remainder:
                 return math.floor(child_size / self.batch_size)
             return math.ceil(child_size / self.batch_size)

@@ -3915,7 +3915,6 @@ class RandomDataset(SourceDataset):
         return self.sampler.is_sharded()
-

 class Schema:
     """
     Class to represent a schema of dataset.
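The added isinstance guard means BatchDataset.get_dataset_size() only reports a batch count when batch_size is a plain int; when batch_size is something else (for example a callable evaluated per batch), the child size alone is not enough to derive it. A minimal standalone sketch of the guarded arithmetic, using a hypothetical helper name and example values:

import math

def batched_dataset_size(child_size, batch_size, drop_remainder):
    # Mirrors the patched logic: only compute a batch count when the child
    # size is known and batch_size is an integer.
    if child_size is not None and isinstance(batch_size, int):
        if drop_remainder:
            return math.floor(child_size / batch_size)
        return math.ceil(child_size / batch_size)
    return None  # e.g. batch_size is a callable decided per batch

assert batched_dataset_size(10, 3, drop_remainder=True) == 3
assert batched_dataset_size(10, 3, drop_remainder=False) == 4
assert batched_dataset_size(10, lambda info: 2, drop_remainder=False) is None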
mindspore/dataset/text/validators.py

@@ -23,7 +23,8 @@ import mindspore._c_dataengine as cde
 from mindspore._c_expression import typing
 
 from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
-    INT32_MAX, check_value
+    INT32_MAX, check_value, check_positive
 
 
 def check_unique_list_of_words(words, arg_name):
     """Check that words is a list and each element is a str without any duplication"""

@@ -109,7 +110,7 @@ def check_from_dict(method):
         for word, word_id in word_dict.items():
             type_check(word, (str,), "word")
             type_check(word_id, (int,), "word_id")
-            check_value(word_id, (-1, INT32_MAX), "word_id")
+            check_value(word_id, (0, INT32_MAX), "word_id")
         return method(self, *args, **kwargs)

     return new_method

@@ -196,7 +197,7 @@ def check_wordpiece_tokenizer(method):
     @wraps(method)
     def new_method(self, *args, **kwargs):
-        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
+        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
             parse_user_args(method, *args, **kwargs)

         if vocab is None:
             raise ValueError("vocab is not provided.")

@@ -238,7 +239,7 @@ def check_basic_tokenizer(method):
     @wraps(method)
     def new_method(self, *args, **kwargs):
-        [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
+        [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
             parse_user_args(method, *args, **kwargs)

         if not isinstance(lower_case, bool):
             raise TypeError("Wrong input type for lower_case, should be boolean.")

@@ -317,7 +318,7 @@ def check_from_dataset(method):
         type_check(top_k, (int, type(None)), "top_k")

         if isinstance(top_k, int):
-            check_value(top_k, (0, INT32_MAX), "top_k")
+            check_positive(top_k, "top_k")
         type_check(special_first, (bool,), "special_first")

         if special_tokens is not None:

@@ -343,7 +344,7 @@ def check_ngram(method):
         for i, gram in enumerate(n):
             type_check(gram, (int,), "gram[{0}]".format(i))
-            check_value(gram, (0, INT32_MAX), "gram_{}".format(i))
+            check_positive(gram, "gram_{}".format(i))

         if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str)
                 and isinstance(left_pad[1], int)):
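Switching from check_value(x, (0, INT32_MAX), name) to check_positive(x, name) tightens the checks on top_k and the n-gram sizes so that 0 is rejected along with negatives, and check_from_dict now requires word ids to start at 0 rather than -1. A hypothetical sketch of what check_positive in validator_helpers presumably does, with the message wording inferred from the updated tests below:

def check_positive(value, arg_name=""):
    # Hypothetical reconstruction, not the actual validator_helpers source:
    # unlike check_value(value, (0, INT32_MAX), arg_name), zero is also rejected.
    if value <= 0:
        raise ValueError("{0} must be greater than 0.".format(arg_name))

check_positive(5, "top_k")     # passes silently
# check_positive(0, "top_k")   # would raise: "top_k must be greater than 0."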
tests/ut/python/dataset/test_from_dataset.py

@@ -128,7 +128,7 @@ def test_from_dataset_exceptions():
             data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
             vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
             assert isinstance(vocab.text.Vocab)
-        except (TypeError, ValueError, RuntimeError) as e:
+        except (TypeError, ValueError) as e:
             assert s in str(e), str(e)

     test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.")

@@ -136,8 +136,8 @@ def test_from_dataset_exceptions():
                 "Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)")
     test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)")
     test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")
-    test_config("text", (2, 3), 0, "top_k needs to be positive number")
-    test_config([123], (2, 3), 0, "top_k needs to be positive number")
+    test_config("text", (2, 3), 0, "top_k must be greater than 0")
+    test_config([123], (2, 3), -1, "top_k must be greater than 0")


 if __name__ == '__main__':
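The expected messages change accordingly: with check_positive in place, both top_k = 0 and a negative top_k are reported as "top_k must be greater than 0" instead of "top_k needs to be positive number". A small illustration of the same call outside the test harness, using the same relative data path and argument order as the test (only valid inside the MindSpore test tree; error type per the except clause above):

import mindspore.dataset as ds
import mindspore.dataset.text as text

# Assumes the repository's test data layout for the words.txt sample file.
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
try:
    # top_k == 0 is now rejected by the check_from_dataset validator.
    text.Vocab.from_dataset(data, "text", (2, 3), 0)
except ValueError as e:
    assert "top_k must be greater than 0" in str(e)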
tests/ut/python/dataset/test_ngram_op.py

@@ -72,43 +72,36 @@ def test_simple_ngram():
 def test_corner_cases():
     """ testing various corner cases and exceptions"""

-    def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
+    def test_config(input_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
         def gen(texts):
             yield (np.array(texts.split(" "), dtype='S'),)

-        dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
-        dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep))
-        for data in dataset.create_dict_iterator():
-            assert [d.decode("utf8") for d in data["text"]] == output_line, output_line
+        try:
+            dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
+            dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep))
+            for data in dataset.create_dict_iterator():
+                return [d.decode("utf8") for d in data["text"]]
+        except (ValueError, TypeError) as e:
+            return str(e)

     # test tensor length smaller than n
-    test_config("Lone Star", ["Lone Star", "", "", ""], [2, 3, 4, 5])
+    assert test_config("Lone Star", [2, 3, 4, 5]) == ["Lone Star", "", "", ""]
     # test empty separator
-    test_config("Beautiful British Columbia", ['BeautifulBritish', 'BritishColumbia'], 2, sep="")
+    assert test_config("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia']
     # test separator with longer length
-    test_config("Beautiful British Columbia", ['Beautiful^-^British^-^Columbia'], 3, sep="^-^")
+    assert test_config("Beautiful British Columbia", 3, sep="^-^") == ['Beautiful^-^British^-^Columbia']
     # test left pad != right pad
-    test_config("Lone Star", ['The Lone Star State'], 4, ("The", 1), ("State", 1))
+    assert test_config("Lone Star", 4, ("The", 1), ("State", 1)) == ['The Lone Star State']
     # test invalid n
-    try:
-        test_config("Yours to Discover", "", [0, [1]])
-    except Exception as e:
-        assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e)
-    # test empty n
-    try:
-        test_config("Yours to Discover", "", [])
-    except Exception as e:
-        assert "n needs to be a non-empty list" in str(e)
+    assert "gram[1] with value [1] is not of type (<class 'int'>,)" in test_config("Yours to Discover", [1, [1]])
+    assert "n needs to be a non-empty list" in test_config("Yours to Discover", [])
     # test invalid pad
-    try:
-        test_config("Yours to Discover", "", [1], ("str", -1))
-    except Exception as e:
-        assert "padding width need to be positive numbers" in str(e)
-    # test invalid pad
-    try:
-        test_config("Yours to Discover", "", [1], ("str", "rts"))
-    except Exception as e:
-        assert "pad needs to be a tuple of (str, int)" in str(e)
+    assert "padding width need to be positive numbers" in test_config("Yours to Discover", [1], ("str", -1))
+    assert "pad needs to be a tuple of (str, int)" in test_config("Yours to Discover", [1], ("str", "rts"))
+    # test 0 as in valid input
+    assert "gram_0 must be greater than 0" in test_config("Yours to Discover", 0)
+    assert "gram_0 must be greater than 0" in test_config("Yours to Discover", [0])
+    assert "gram_1 must be greater than 0" in test_config("Yours to Discover", [1, 0])


 if __name__ == '__main__':
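The refactored test_config returns either the decoded n-grams or the raised error text, so each corner case becomes a single assert. For reference, the same text.Ngram pattern as a standalone sketch; the API calls mirror the test above, while the helper name is only illustrative:

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text

def ngrams(line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
    # One-sample dataset built from the whitespace-split line, as in test_corner_cases.
    def gen(texts):
        yield (np.array(texts.split(" "), dtype='S'),)

    data = ds.GeneratorDataset(gen(line), column_names=["text"])
    data = data.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep))
    for row in data.create_dict_iterator():
        return [d.decode("utf8") for d in row["text"]]

# Mirrors one of the assertions above:
# ngrams("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia']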
tests/ut/python/dataset/test_vocab.py

@@ -60,6 +60,15 @@ def test_from_dict_tutorial():
         ind += 1


+def test_from_dict_exception():
+    try:
+        vocab = text.Vocab.from_dict({"home": -1, "behind": 0})
+        if not vocab:
+            raise ValueError("Vocab is None")
+    except ValueError as e:
+        assert "is not within the required interval" in str(e)
+
+
 def test_from_list():
     def gen(texts):
         for word in texts.split(" "):

@@ -74,13 +83,11 @@ def test_from_list():
             for d in data.create_dict_iterator():
                 res.append(d["text"].item())
             return res
-        except ValueError as e:
-            return str(e)
-        except RuntimeError as e:
-            return str(e)
-        except TypeError as e:
+        except (ValueError, RuntimeError, TypeError) as e:
             return str(e)

     # test basic default config, special_token=None, unknown_token=None
     assert test_config("w1 w2 w3", ["w1", "w2", "w3"], None, True, None) == [0, 1, 2]
     # test normal operations
     assert test_config("w1 w2 w3 s1 s2 ephemeral", ["w1", "w2", "w3"], ["s1", "s2"], True, "s2") == [2, 3, 4, 0, 1, 1]
     assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False, "s2") == [0, 1, 2, 3, 4]

@@ -129,6 +136,7 @@ def test_from_file():
 if __name__ == '__main__':
+    test_from_dict_exception()
     test_from_list_tutorial()
     test_from_file_tutorial()
     test_from_dict_tutorial()
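The new test_from_dict_exception pins down the behaviour that follows from the check_from_dict change: word ids passed to Vocab.from_dict must now fall within (0, INT32_MAX), so a negative id raises a ValueError whose message contains "is not within the required interval". A minimal sketch of the same behaviour; the accepted ids below are arbitrary example values:

import mindspore.dataset.text as text

# Non-negative ids are accepted.
vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 0, "world": 1})

# A negative id is rejected by the tightened validator.
try:
    text.Vocab.from_dict({"home": -1, "behind": 0})
except ValueError as e:
    assert "is not within the required interval" in str(e)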