magicwindyyd / mindspore
Forked from MindSpore / mindspore (in sync with the fork source)
Commit 7fa0d9e7
Authored June 29, 2020 by ms_yan
Parent: 2f1b0dc5

add paramter check for numpyslices and num_shards

Showing 3 changed files with 10 additions and 3 deletions (+10 -3)
mindspore/dataset/engine/datasets.py  +7 -2
mindspore/dataset/engine/validators.py  +2 -0
tests/ut/python/dataset/test_minddataset_exception.py  +1 -1
mindspore/dataset/engine/datasets.py @ 7fa0d9e7
@@ -3069,7 +3069,7 @@ class GeneratorDataset(MappableDataset):
         sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
             required (default=None, expected order behavior shown in the table).
         num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
-            This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
+            When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
         shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
             when num_shards is also specified. Random accessible input is required.
@@ -4878,6 +4878,11 @@ class _NumpySlicesDataset:
         else:
             self.data = (np.array(data),)
 
+        # check whether the data length in each column is equal
+        data_len = [len(data_item) for data_item in self.data]
+        if data_len[1:] != data_len[:-1]:
+            raise ValueError("Data length in each column is not equal.")
+
         # Init column_name
         if column_list is not None:
             self.column_list = column_list
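The added check compares the list of column lengths with itself shifted by one position, which is true only when every adjacent pair of lengths matches. A minimal standalone sketch of the same idea follows; the function name `check_equal_column_lengths` and the sample data are hypothetical, and plain lists stand in for the NumPy arrays held in `self.data`.

```python
def check_equal_column_lengths(columns):
    """Return the shared row count, or raise if the columns disagree."""
    data_len = [len(column) for column in columns]
    # data_len[1:] equals data_len[:-1] only when every adjacent pair of
    # lengths matches, which means all lengths are identical.
    if data_len[1:] != data_len[:-1]:
        raise ValueError("Data length in each column is not equal.")
    return data_len[0] if data_len else 0


# Hypothetical sample data: two columns with three rows each.
features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
labels = [0, 1, 0]
row_count = check_equal_column_lengths((features, labels))

# Dropping one label makes the columns disagree, so the check raises.
try:
    check_equal_column_lengths((features, labels[:2]))
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

The shifted-slice comparison avoids an explicit loop and also passes trivially for a single column, where both slices are empty.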
@@ -4966,7 +4971,7 @@ class NumpySlicesDataset(GeneratorDataset):
         sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
             required (default=None, expected order behavior shown in the table).
         num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
-            This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
+            When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
         shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
             when num_shards is also specified. Random accessible input is required.
mindspore/dataset/engine/validators.py @ 7fa0d9e7
@@ -153,6 +153,7 @@ def check_sampler_shuffle_shard_options(param_dict):
             raise RuntimeError("sampler and sharding cannot be specified at the same time.")
 
     if num_shards is not None:
+        check_positive_int32(num_shards, "num_shards")
         if shard_id is None:
             raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
         if shard_id < 0 or shard_id >= num_shards:
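To illustrate the order of checks in this hunk, here is a rough standalone sketch. `positive_int32` is a hypothetical stand-in for `check_positive_int32`, whose body is not shown in the diff, so the exception types and messages it uses are assumptions. The message for an out-of-range `shard_id` is also invented, because the diff cuts off before that line.

```python
INT32_MAX = 2**31 - 1


def positive_int32(value, name):
    """Hypothetical stand-in for check_positive_int32; not MindSpore's code."""
    # bool is a subclass of int, so reject it explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise TypeError(f"{name} must be an int, got {type(value).__name__}.")
    if not 1 <= value <= INT32_MAX:
        raise ValueError(f"{name} must be in [1, {INT32_MAX}], got {value}.")


def check_shard_options(num_shards=None, shard_id=None):
    """Apply the checks in the same order as the hunk above."""
    if num_shards is None:
        return
    positive_int32(num_shards, "num_shards")
    if shard_id is None:
        raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
    if shard_id < 0 or shard_id >= num_shards:
        # Hypothetical message; the real one is not visible in the diff.
        raise ValueError(f"shard_id {shard_id} is out of range for num_shards {num_shards}.")


def error_name(**kwargs):
    """Return the name of the exception raised, or None if the options pass."""
    try:
        check_shard_options(**kwargs)
    except (TypeError, ValueError, RuntimeError) as err:
        return type(err).__name__
    return None
```

In this sketch, validating `num_shards` first means a value such as `0` is reported as a `num_shards` problem, instead of surfacing indirectly through the `shard_id >= num_shards` comparison.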
@@ -529,6 +530,7 @@ def check_generatordataset(method):
             # These two parameters appear together.
             raise ValueError("num_shards and shard_id need to be passed in together")
         if num_shards is not None:
+            check_positive_int32(num_shards, "num_shards")
             if shard_id >= num_shards:
                 raise ValueError("shard_id should be less than num_shards")
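`check_generatordataset(method)` takes the method it validates, so the check presumably runs as a decorator before the wrapped constructor. The sketch below shows that general pattern under stated assumptions: `check_sharding_args` and `ToyDataset` are hypothetical names, the `(num_shards is None) != (shard_id is None)` condition is a guess at the line just above the hunk, and the inline positive-integer test is a simplified stand-in for `check_positive_int32`.

```python
import functools
import inspect


def check_sharding_args(method):
    """Hypothetical decorator: validate num_shards/shard_id before calling method."""
    signature = inspect.signature(method)

    @functools.wraps(method)
    def new_method(*args, **kwargs):
        # Bind positional and keyword arguments to parameter names.
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        num_shards = bound.arguments.get("num_shards")
        shard_id = bound.arguments.get("shard_id")

        # Assumed condition: the diff shows only the comment and the raise.
        if (num_shards is None) != (shard_id is None):
            raise ValueError("num_shards and shard_id need to be passed in together")
        if num_shards is not None:
            # Simplified stand-in for check_positive_int32(num_shards, "num_shards").
            if not isinstance(num_shards, int) or isinstance(num_shards, bool) or num_shards < 1:
                raise ValueError("num_shards must be a positive int")
            if shard_id >= num_shards:
                raise ValueError("shard_id should be less than num_shards")
        return method(*args, **kwargs)

    return new_method


class ToyDataset:
    """Hypothetical class used only to exercise the decorator."""

    @check_sharding_args
    def __init__(self, source, num_shards=None, shard_id=None):
        self.source = source
        self.num_shards = num_shards
        self.shard_id = shard_id
```

For example, `ToyDataset([1, 2, 3], num_shards=2, shard_id=1)` constructs normally, while `ToyDataset([1], num_shards=0, shard_id=0)` is rejected before `__init__` runs.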
tests/ut/python/dataset/test_minddataset_exception.py @ 7fa0d9e7
@@ -185,7 +185,7 @@ def test_minddataset_invalidate_num_shards():
     columns_list = ["data", "label"]
     num_readers = 4
     with pytest.raises(Exception, match="shard_id is invalid, "):
-        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1)
+        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
         num_iter = 0
         for _ in data_set.create_dict_iterator():
             num_iter += 1
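One plausible reading of this test change: if the last two positional arguments are `num_shards` and `shard_id`, the old call passed `num_shards=0`, which the new positive-integer check would reject before the `shard_id` error could be raised. The sketch below models that with a hypothetical stand-in. `toy_mind_dataset`, its parameter order, the file name, and every message except the `"shard_id is invalid, "` prefix are assumptions, not MindSpore's code.

```python
import re


def toy_mind_dataset(dataset_file, columns_list=None, num_parallel_workers=None,
                     shuffle=None, num_shards=None, shard_id=None):
    """Hypothetical stand-in; the parameter order is an assumption."""
    if num_shards is not None:
        if not isinstance(num_shards, int) or num_shards < 1:
            # Hypothetical message for the new positive-integer check.
            raise ValueError("num_shards must be a positive int")
        if shard_id is None or not 0 <= shard_id < num_shards:
            # Only the prefix comes from the test; the suffix is invented.
            raise ValueError(f"shard_id is invalid, shard_id: {shard_id}")
    return {"num_shards": num_shards, "shard_id": shard_id}


def raised_message(*args):
    """Return the error message raised by the stand-in, or None."""
    try:
        toy_mind_dataset(*args)
    except ValueError as err:
        return str(err)
    return None


# Arguments from the old and new versions of the test, with a placeholder path.
old_call = raised_message("toy.mindrecord", ["data", "label"], 4, True, 0, 1)
new_call = raised_message("toy.mindrecord", ["data", "label"], 4, True, 1, 2)

# pytest.raises(match=...) applies re.search to the exception text.
pattern = "shard_id is invalid, "
old_matches = re.search(pattern, old_call) is not None
new_matches = re.search(pattern, new_call) is not None
```

Under these assumptions, the old arguments fail on `num_shards` and no longer match the expected pattern, while the new arguments (`num_shards=1`, `shard_id=2`) still reach the `shard_id` error.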