Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
eda63a55
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
eda63a55
编写于
4月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!477 Fix VOC dataset test cases
Merge pull request !477 from xiefangqi/xfq_fix_voc
上级
53b3d187
108eeb8e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
19 addition
and
27 deletion
+19
-27
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+15
-23
mindspore/dataset/engine/serializer_deserializer.py
mindspore/dataset/engine/serializer_deserializer.py
+2
-1
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+2
-3
未找到文件。
mindspore/dataset/engine/datasets.py
浏览文件 @
eda63a55
...
...
@@ -3335,14 +3335,17 @@ class VOCDataset(SourceDataset):
decode (bool, optional): Decode the images after reading (default=False).
sampler (Sampler, optional): Object used to choose samples from the dataset
(default=None, expected order behavior shown in the table).
distribution (str, optional): Path to the json distribution file to configure
dataset sharding (default=None). This argument should be specified
only when no 'sampler' is used.
num_shards (int, optional): Number of shards that the dataset should be divided
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
Raises:
RuntimeError: If distribution and sampler are specified at the same time.
RuntimeError: If distribution is failed to read.
RuntimeError: If shuffle and sampler are specified at the same time.
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Examples:
>>> import mindspore.dataset as ds
...
...
@@ -3356,27 +3359,15 @@ class VOCDataset(SourceDataset):
@
check_vocdataset
def
__init__
(
self
,
dataset_dir
,
num_samples
=
None
,
num_parallel_workers
=
None
,
shuffle
=
None
,
decode
=
False
,
sampler
=
None
,
distribution
=
None
):
shuffle
=
None
,
decode
=
False
,
sampler
=
None
,
num_shards
=
None
,
shard_id
=
None
):
super
().
__init__
(
num_parallel_workers
)
self
.
dataset_dir
=
dataset_dir
self
.
sampler
=
sampler
if
distribution
is
not
None
:
if
sampler
is
not
None
:
raise
RuntimeError
(
"Cannot specify distribution and sampler at the same time."
)
try
:
with
open
(
distribution
,
'r'
)
as
load_d
:
json
.
load
(
load_d
)
except
json
.
decoder
.
JSONDecodeError
:
raise
RuntimeError
(
"Json decode error when load distribution file"
)
except
Exception
:
raise
RuntimeError
(
"Distribution file has failed to load."
)
elif
shuffle
is
not
None
:
if
sampler
is
not
None
:
raise
RuntimeError
(
"Cannot specify shuffle and sampler at the same time."
)
self
.
sampler
=
_select_sampler
(
num_samples
,
sampler
,
shuffle
,
num_shards
,
shard_id
)
self
.
num_samples
=
num_samples
self
.
decode
=
decode
self
.
distribution
=
distribution
self
.
shuffle_level
=
shuffle
self
.
num_shards
=
num_shards
self
.
shard_id
=
shard_id
def
get_args
(
self
):
args
=
super
().
get_args
()
...
...
@@ -3385,7 +3376,8 @@ class VOCDataset(SourceDataset):
args
[
"sampler"
]
=
self
.
sampler
args
[
"decode"
]
=
self
.
decode
args
[
"shuffle"
]
=
self
.
shuffle_level
args
[
"distribution"
]
=
self
.
distribution
args
[
"num_shards"
]
=
self
.
num_shards
args
[
"shard_id"
]
=
self
.
shard_id
return
args
def
get_dataset_size
(
self
):
...
...
mindspore/dataset/engine/serializer_deserializer.py
浏览文件 @
eda63a55
...
...
@@ -286,7 +286,8 @@ def create_node(node):
elif
dataset_op
==
'VOCDataset'
:
sampler
=
construct_sampler
(
node
.
get
(
'sampler'
))
pyobj
=
pyclass
(
node
[
'dataset_dir'
],
node
.
get
(
'num_samples'
),
node
.
get
(
'num_parallel_workers'
),
node
.
get
(
'shuffle'
),
node
.
get
(
'decode'
),
sampler
,
node
.
get
(
'distribution'
))
node
.
get
(
'shuffle'
),
node
.
get
(
'decode'
),
sampler
,
node
.
get
(
'num_shards'
),
node
.
get
(
'shard_id'
))
elif
dataset_op
==
'CelebADataset'
:
sampler
=
construct_sampler
(
node
.
get
(
'sampler'
))
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
eda63a55
...
...
@@ -443,9 +443,8 @@ def check_vocdataset(method):
def
new_method
(
*
args
,
**
kwargs
):
param_dict
=
make_param_dict
(
method
,
args
,
kwargs
)
nreq_param_int
=
[
'num_samples'
,
'num_parallel_workers'
]
nreq_param_int
=
[
'num_samples'
,
'num_parallel_workers'
,
'num_shards'
,
'shard_id'
]
nreq_param_bool
=
[
'shuffle'
,
'decode'
]
nreq_param_str
=
[
'distribution'
]
# check dataset_dir; required argument
dataset_dir
=
param_dict
.
get
(
'dataset_dir'
)
...
...
@@ -457,7 +456,7 @@ def check_vocdataset(method):
check_param_type
(
nreq_param_bool
,
param_dict
,
bool
)
check_
param_type
(
nreq_param_str
,
param_dict
,
str
)
check_
sampler_shuffle_shard_options
(
param_dict
)
return
method
(
*
args
,
**
kwargs
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录