Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b298c515
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b298c515
编写于
5月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
xiefangqi
5月 28, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
!1559 Voc dataset support split ops
Merge pull request !1559 from xiefangqi/xfq_voc_support_split
上级
0f221403
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
78 addition
and
4 deletion
+78
-4
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+7
-0
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
+26
-0
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
+9
-0
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+18
-4
tests/ut/python/dataset/test_datasets_voc.py
tests/ut/python/dataset/test_datasets_voc.py
+18
-0
未找到文件。
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
b298c515
...
...
@@ -202,6 +202,13 @@ void bindDatasetOps(py::module *m) {
return
count
;
});
(
void
)
py
::
class_
<
VOCOp
,
DatasetOp
,
std
::
shared_ptr
<
VOCOp
>>
(
*
m
,
"VOCOp"
)
.
def_static
(
"get_num_rows"
,
[](
const
std
::
string
&
dir
,
const
std
::
string
&
task_type
,
const
std
::
string
&
task_mode
,
const
py
::
dict
&
dict
,
int64_t
numSamples
)
{
int64_t
count
=
0
;
THROW_IF_ERROR
(
VOCOp
::
CountTotalRows
(
dir
,
task_type
,
task_mode
,
dict
,
numSamples
,
&
count
));
return
count
;
})
.
def_static
(
"get_class_indexing"
,
[](
const
std
::
string
&
dir
,
const
std
::
string
&
task_type
,
const
std
::
string
&
task_mode
,
const
py
::
dict
&
dict
,
int64_t
numSamples
)
{
std
::
map
<
std
::
string
,
int32_t
>
output_class_indexing
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
浏览文件 @
b298c515
...
...
@@ -442,6 +442,32 @@ Status VOCOp::GetNumRowsInDataset(int64_t *num) const {
return
Status
::
OK
();
}
Status
VOCOp
::
CountTotalRows
(
const
std
::
string
&
dir
,
const
std
::
string
&
task_type
,
const
std
::
string
&
task_mode
,
const
py
::
dict
&
dict
,
int64_t
numSamples
,
int64_t
*
count
)
{
if
(
task_type
==
"Detection"
)
{
std
::
map
<
std
::
string
,
int32_t
>
input_class_indexing
;
for
(
auto
p
:
dict
)
{
(
void
)
input_class_indexing
.
insert
(
std
::
pair
<
std
::
string
,
int32_t
>
(
py
::
reinterpret_borrow
<
py
::
str
>
(
p
.
first
),
py
::
reinterpret_borrow
<
py
::
int_
>
(
p
.
second
)));
}
std
::
shared_ptr
<
VOCOp
>
op
;
RETURN_IF_NOT_OK
(
Builder
().
SetDir
(
dir
).
SetTask
(
task_type
).
SetMode
(
task_mode
).
SetClassIndex
(
input_class_indexing
).
Build
(
&
op
));
RETURN_IF_NOT_OK
(
op
->
ParseImageIds
());
RETURN_IF_NOT_OK
(
op
->
ParseAnnotationIds
());
*
count
=
static_cast
<
int64_t
>
(
op
->
image_ids_
.
size
());
}
else
if
(
task_type
==
"Segmentation"
)
{
std
::
shared_ptr
<
VOCOp
>
op
;
RETURN_IF_NOT_OK
(
Builder
().
SetDir
(
dir
).
SetTask
(
task_type
).
SetMode
(
task_mode
).
Build
(
&
op
));
RETURN_IF_NOT_OK
(
op
->
ParseImageIds
());
*
count
=
static_cast
<
int64_t
>
(
op
->
image_ids_
.
size
());
}
*
count
=
(
numSamples
==
0
||
*
count
<
numSamples
)
?
*
count
:
numSamples
;
return
Status
::
OK
();
}
Status
VOCOp
::
GetClassIndexing
(
const
std
::
string
&
dir
,
const
std
::
string
&
task_type
,
const
std
::
string
&
task_mode
,
const
py
::
dict
&
dict
,
int64_t
numSamples
,
std
::
map
<
std
::
string
,
int32_t
>
*
output_class_indexing
)
{
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
浏览文件 @
b298c515
...
...
@@ -208,6 +208,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
override
;
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
// @param const py::dict &dict - input dict of class index
// @param int64_t numSamples - samples number of VOCDataset
// @param int64_t *count - output rows number of VOCDataset
static
Status
CountTotalRows
(
const
std
::
string
&
dir
,
const
std
::
string
&
task_type
,
const
std
::
string
&
task_mode
,
const
py
::
dict
&
dict
,
int64_t
numSamples
,
int64_t
*
count
);
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
b298c515
...
...
@@ -1210,8 +1210,10 @@ class MappableDataset(SourceDataset):
>>> new_sampler = ds.DistributedSampler(10, 2)
>>> data.use_sampler(new_sampler)
"""
if
new_sampler
is
not
None
and
not
isinstance
(
new_sampler
,
(
samplers
.
BuiltinSampler
,
samplers
.
Sampler
)):
raise
TypeError
(
"new_sampler is not an instance of a sampler."
)
if
new_sampler
is
None
:
raise
TypeError
(
"Input sampler could not be None."
)
if
not
isinstance
(
new_sampler
,
(
samplers
.
BuiltinSampler
,
samplers
.
Sampler
)):
raise
TypeError
(
"Input sampler is not an instance of a sampler."
)
self
.
sampler
=
self
.
sampler
.
child_sampler
self
.
add_sampler
(
new_sampler
)
...
...
@@ -3914,12 +3916,24 @@ class VOCDataset(MappableDataset):
Return:
Number, number of batches.
"""
if
self
.
num_samples
is
None
:
num_samples
=
0
else
:
num_samples
=
self
.
num_samples
if
self
.
class_indexing
is
None
:
class_indexing
=
dict
()
else
:
class_indexing
=
self
.
class_indexing
num_rows
=
VOCOp
.
get_num_rows
(
self
.
dataset_dir
,
self
.
task
,
self
.
mode
,
class_indexing
,
num_samples
)
rows_per_shard
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
if
rows_from_sampler
is
None
:
return
self
.
num_samples
return
rows_per_shard
return
min
(
rows_from_sampler
,
self
.
num_samples
)
return
min
(
rows_from_sampler
,
rows_per_shard
)
def
get_class_indexing
(
self
):
"""
...
...
tests/ut/python/dataset/test_datasets_voc.py
浏览文件 @
b298c515
...
...
@@ -115,6 +115,23 @@ def test_case_1():
assert
(
num
==
18
)
def
test_case_2
():
data1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"Segmentation"
,
mode
=
"train"
,
decode
=
True
)
sizes
=
[
0.5
,
0.5
]
randomize
=
False
dataset1
,
dataset2
=
data1
.
split
(
sizes
=
sizes
,
randomize
=
randomize
)
num_iter
=
0
for
_
in
dataset1
.
create_dict_iterator
():
num_iter
+=
1
assert
(
num_iter
==
5
)
num_iter
=
0
for
_
in
dataset2
.
create_dict_iterator
():
num_iter
+=
1
assert
(
num_iter
==
5
)
def
test_voc_exception
():
try
:
data1
=
ds
.
VOCDataset
(
DATA_DIR
,
task
=
"InvalidTask"
,
mode
=
"train"
,
decode
=
True
)
...
...
@@ -172,4 +189,5 @@ if __name__ == '__main__':
test_voc_get_class_indexing
()
test_case_0
()
test_case_1
()
test_case_2
()
test_voc_exception
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录