Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
be2e7531
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
be2e7531
编写于
5月 14, 2020
作者:
J
jonyguo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: MindDataset parameter shard_id & num_shards check
上级
94883f9b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
47 addition
and
1 deletion
+47
-1
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+3
-0
tests/ut/python/dataset/test_minddataset_exception.py
tests/ut/python/dataset/test_minddataset_exception.py
+44
-1
未找到文件。
mindspore/dataset/engine/validators.py
浏览文件 @
be2e7531
...
@@ -534,6 +534,7 @@ def check_minddataset(method):
...
@@ -534,6 +534,7 @@ def check_minddataset(method):
check_dataset_file
(
f
)
check_dataset_file
(
f
)
else
:
else
:
check_dataset_file
(
dataset_file
)
check_dataset_file
(
dataset_file
)
check_param_type
(
nreq_param_int
,
param_dict
,
int
)
check_param_type
(
nreq_param_int
,
param_dict
,
int
)
check_param_type
(
nreq_param_list
,
param_dict
,
list
)
check_param_type
(
nreq_param_list
,
param_dict
,
list
)
...
@@ -544,6 +545,8 @@ def check_minddataset(method):
...
@@ -544,6 +545,8 @@ def check_minddataset(method):
if
(
num_shards
is
not
None
and
shard_id
is
None
)
or
(
num_shards
is
None
and
shard_id
is
not
None
):
if
(
num_shards
is
not
None
and
shard_id
is
None
)
or
(
num_shards
is
None
and
shard_id
is
not
None
):
raise
ValueError
(
"num_shards and shard_id need to be set or not set at the same time"
)
raise
ValueError
(
"num_shards and shard_id need to be set or not set at the same time"
)
check_sampler_shuffle_shard_options
(
param_dict
)
return
method
(
*
args
,
**
kwargs
)
return
method
(
*
args
,
**
kwargs
)
return
new_method
return
new_method
...
...
tests/ut/python/dataset/test_minddataset_exception.py
浏览文件 @
be2e7531
...
@@ -128,7 +128,7 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
...
@@ -128,7 +128,7 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
num_readers
=
4
sampler
=
ds
.
PKSampler
(
2
)
sampler
=
ds
.
PKSampler
(
2
)
with
pytest
.
raises
(
Exception
,
match
=
"s
huffle not allowed when use sampler
"
):
with
pytest
.
raises
(
Exception
,
match
=
"s
ampler and shuffle cannot be specified at the same time.
"
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
sampler
=
sampler
,
shuffle
=
False
)
sampler
=
sampler
,
shuffle
=
False
)
num_iter
=
0
num_iter
=
0
...
@@ -168,3 +168,46 @@ def test_cv_minddataset_reader_different_page_size():
...
@@ -168,3 +168,46 @@ def test_cv_minddataset_reader_different_page_size():
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
os
.
remove
(
CV1_FILE_NAME
)
os
.
remove
(
CV1_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
os
.
remove
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
def
test_minddataset_invalidate_num_shards
():
create_cv_mindrecord
(
1
)
columns_list
=
[
"data"
,
"label"
]
num_readers
=
4
with
pytest
.
raises
(
Exception
,
match
=
"shard_id is invalid, "
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
True
,
0
,
1
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
def
test_minddataset_invalidate_shard_id
():
create_cv_mindrecord
(
1
)
columns_list
=
[
"data"
,
"label"
]
num_readers
=
4
with
pytest
.
raises
(
Exception
,
match
=
"shard_id is invalid, "
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
True
,
1
,
-
1
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
def
test_minddataset_shard_id_bigger_than_num_shard
():
create_cv_mindrecord
(
1
)
columns_list
=
[
"data"
,
"label"
]
num_readers
=
4
with
pytest
.
raises
(
Exception
,
match
=
"shard_id is invalid, "
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
True
,
2
,
2
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
with
pytest
.
raises
(
Exception
,
match
=
"shard_id is invalid, "
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
True
,
2
,
5
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录