magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit d949c17a, authored Apr 12, 2020 by mindspore-ci-bot; committed by Gitee on Apr 12, 2020.
!228 [MD] add subset random sampler in minddataset
Merge pull request !228 from liyong126/mindrecord_subsetrandom_sampler
Parents: 40d41167, 11403492
Showing 7 changed files with 336 additions and 33 deletions (+336 -33).
mindspore/ccsrc/dataset/api/de_pipeline.cc               +31  -0
mindspore/ccsrc/dataset/api/de_pipeline.h                 +3  -0
mindspore/ccsrc/mindrecord/include/common/shard_utils.h   +2  -0
mindspore/ccsrc/mindrecord/include/shard_sample.h         +6  -0
mindspore/ccsrc/mindrecord/meta/shard_sample.cc          +45  -29
mindspore/dataset/engine/datasets.py                     +27  -4
tests/ut/python/dataset/test_minddataset_sampler.py     +222  -0
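In user-facing terms, the change lets MindDataset take a sampler and read back only the records named by an index list. A minimal usage sketch, assuming a MindRecord file laid out like the one the new unit test writes (file name, columns, and index values here are illustrative, not part of the diff):

import mindspore.dataset as ds

indices = [1, 2, 3, 5, 7]                  # record ids to draw
sampler = ds.SubsetRandomSampler(indices)  # the only sampler MindDataset accepts so far
data_set = ds.MindDataset("imagenet.mindrecord0",
                          columns_list=["data", "file_name", "label"],
                          num_parallel_workers=4,
                          sampler=sampler)  # exclusive with shuffle and block_reader
for item in data_set.create_dict_iterator():
    print(item["file_name"], item["label"])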
mindspore/ccsrc/dataset/api/de_pipeline.cc
@@ -391,6 +391,30 @@ Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vecto
  return Status::OK();
}

+Status DEPipeline::GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
+                                        std::shared_ptr<mindrecord::ShardOperator> *ptr) {
+  std::vector<int> indices;
+  for (auto &arg : args) {
+    std::string key = py::str(arg.first);
+    py::handle value = arg.second;
+    if (!value.is_none()) {
+      if (key == "indices") {
+        indices = ToIntVector(value);
+      } else {
+        std::string err_msg = "ERROR: parameter " + key + " is invalid.";
+        RETURN_STATUS_UNEXPECTED(err_msg);
+      }
+    }
+  }
+  if (sampler_name == "SubsetRandomSampler") {
+    *ptr = std::make_shared<mindrecord::ShardSample>(indices);
+  } else {
+    std::string err_msg = "ERROR: parameter sampler_name is invalid.";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  if (args["dataset_file"].is_none()) {
    std::string err_msg = "Error: at least one of dataset_files is missing";
@@ -422,6 +446,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
    } else if (key == "global_shuffle" && ToBool(value) == true) {
      uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0;
      operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
+    } else if (key == "sampler_name") {
+      std::shared_ptr<mindrecord::ShardOperator> sample_op;
+      auto ret = GetMindrecordSampler(ToString(value), args["sampler_params"], &sample_op);
+      if (Status::OK() != ret) {
+        return ret;
+      }
+      operators.push_back(sample_op);
    }
  }
}
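For context, the sampler_name string and args dict that GetMindrecordSampler parses are built on the Python side (see the datasets.py hunk further down) from the sampler's class name and instance __dict__. A sketch of that payload for a SubsetRandomSampler, assuming the sampler stores its index list in an attribute literally named indices, which is what the "indices" key lookup above requires:

# Python-side sketch of what reaches GetMindrecordSampler
args = {
    "sampler_name": "SubsetRandomSampler",           # sampler.__class__.__name__
    "sampler_params": {"indices": [1, 2, 3, 5, 7]},  # sampler.__dict__ (attribute name assumed)
}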
mindspore/ccsrc/dataset/api/de_pipeline.h
@@ -145,6 +145,9 @@ class DEPipeline {
  Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

+  Status GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
+                              std::shared_ptr<mindrecord::ShardOperator> *ptr);
+
 private:
  // Execution tree that links the dataset operators.
  std::shared_ptr<ExecutionTree> tree_;
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
@@ -68,6 +68,8 @@ enum ShardType {
  kCV = 1,
};

+enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler };
+
const double kEpsilon = 1e-7;

const int kThreadNumber = 14;
mindspore/ccsrc/mindrecord/include/shard_sample.h
@@ -17,7 +17,9 @@
#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_

#include <string>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_operator.h"

namespace mindspore {
@@ -30,6 +32,8 @@ class ShardSample : public ShardOperator {
  ShardSample(int num, int den, int par);

+  explicit ShardSample(const std::vector<int> &indices);
+
  ~ShardSample() override{};

  const std::pair<int, int> get_partitions() const;
@@ -41,6 +45,8 @@ class ShardSample : public ShardOperator {
  int denominator_;
  int no_of_samples_;
  int partition_id_;
+  std::vector<int> indices_;
+  SamplerType sampler_type_;
};
}  // namespace mindrecord
}  // namespace mindspore
mindspore/ccsrc/mindrecord/meta/shard_sample.cc
@@ -22,33 +22,37 @@ using mindspore::MsLogLevel::ERROR;
namespace mindspore {
namespace mindrecord {
-ShardSample::ShardSample(int n) {
-  numerator_ = 0;
-  denominator_ = 0;
-  no_of_samples_ = n;
-  partition_id_ = 0;
-}
+ShardSample::ShardSample(int n)
+    : numerator_(0), denominator_(0), no_of_samples_(n), partition_id_(0), indices_({}),
+      sampler_type_(kCustomTopNSampler) {}

-ShardSample::ShardSample(int num, int den) {
-  if (num < 0 || den <= 0 || num > den) {
-    no_of_samples_ = 5;
-    numerator_ = 0;
-    denominator_ = 0;
-    partition_id_ = 0;
-    return;
-  }
-  numerator_ = num;
-  denominator_ = den;
-  no_of_samples_ = 0;
-  partition_id_ = 0;
-}
+ShardSample::ShardSample(int num, int den)
+    : numerator_(num), denominator_(den), no_of_samples_(0), partition_id_(0), indices_({}),
+      sampler_type_(kCustomTopPercentSampler) {}

-ShardSample::ShardSample(int num, int den, int par) {
-  numerator_ = num;
-  denominator_ = den;
-  no_of_samples_ = 0;
-  partition_id_ = par;
-}
+ShardSample::ShardSample(int num, int den, int par)
+    : numerator_(num), denominator_(den), no_of_samples_(0), partition_id_(par), indices_({}),
+      sampler_type_(kCustomTopPercentSampler) {}
+
+ShardSample::ShardSample(const std::vector<int> &indices)
+    : numerator_(0), denominator_(0), no_of_samples_(0), partition_id_(0), indices_(indices),
+      sampler_type_(kSubsetRandomSampler) {}

const std::pair<int, int> ShardSample::get_partitions() const {
  if (numerator_ == 1 && denominator_ > 1) {
@@ -62,10 +66,15 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) {
  int total_no = static_cast<int>(tasks.Size());
  int taking = 0;
-  if (no_of_samples_ > 0) {  // non sharding case constructor #1
+  if (sampler_type_ == kCustomTopNSampler) {  // non sharding case constructor #1
    no_of_samples_ = std::min(no_of_samples_, total_no);
    taking = no_of_samples_ - no_of_samples_ % no_of_categories;
-  } else {  // constructor #2 & #3
+  } else if (sampler_type_ == kSubsetRandomSampler) {
+    if (indices_.size() > total_no) {
+      MS_LOG(ERROR) << "parameter indices's size is greater than dataset size.";
+      return FAILED;
+    }
+  } else {  // constructor TopPercent
    if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
      if (numerator_ == 1 && denominator_ > 1) {  // sharding
        taking = (total_no / denominator_) + (total_no % denominator_ == 0 ? 0 : 1);
@@ -82,8 +91,15 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) {
  if (tasks.permutation_.empty()) {
    ShardTask new_tasks;
    total_no = static_cast<int>(tasks.Size());
-    for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
-      new_tasks.InsertTask(tasks.get_task_by_id(i % total_no));  // rounding up. if overflow, go back to start
+    if (sampler_type_ == kSubsetRandomSampler) {
+      for (int i = 0; i < indices_.size(); ++i) {
+        int index = ((indices_[i] % total_no) + total_no) % total_no;
+        new_tasks.InsertTask(tasks.get_task_by_id(index));  // different mod result between c and python
+      }
+    } else {
+      for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
+        new_tasks.InsertTask(tasks.get_task_by_id(i % total_no));  // rounding up. if overflow, go back to start
+      }
    }
    std::swap(tasks, new_tasks);
  } else {
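The double mod in the kSubsetRandomSampler branch, ((indices_[i] % total_no) + total_no) % total_no, is what the "different mod result between c and python" comment refers to: C++ % truncates toward zero, so -1 % 10 is -1 in C++, while Python floors, giving 9. The extra + total_no and second % make the C++ result agree with Python's, which is what lets the negative-index test below pass. A short Python illustration:

def cpp_mod(a, n):
    # emulate C++ '%', which truncates toward zero
    return a - int(a / n) * n

def normalized(a, n):
    # the ((i % n) + n) % n trick from shard_sample.cc
    return (cpp_mod(a, n) + n) % n

assert cpp_mod(-1, 10) == -1     # C++ behaviour
assert (-1) % 10 == 9            # Python behaviour
assert normalized(-1, 10) == 9   # normalized C++ matches Python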
mindspore/dataset/engine/datasets.py
@@ -1363,7 +1363,6 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
        return samplers.SequentialSampler()

-
class ImageFolderDatasetV2(SourceDataset):
    """
    A source dataset that reads images from a tree of directories.
@@ -1621,6 +1620,9 @@ class MindDataset(SourceDataset):
        shard_id (int, optional): The shard ID within num_shards (default=None). This
            argument should be specified only when num_shards is also specified.
        block_reader (bool, optional): Whether read data by block mode (default=False).
+        sampler (Sampler, optional): Object used to choose samples from the
+            dataset (default=None, sampler is exclusive
+            with shuffle and block_reader). Support list: SubsetRandomSampler.

    Raises:
        ValueError: If num_shards is specified but shard_id is None.
@@ -1630,14 +1632,16 @@ class MindDataset(SourceDataset):
    @check_minddataset
    def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
-                 shuffle=None, num_shards=None, shard_id=None, block_reader=False):
+                 shuffle=None, num_shards=None, shard_id=None, block_reader=False, sampler=None):
        super().__init__(num_parallel_workers)
        self.dataset_file = dataset_file
        self.columns_list = columns_list
-        self.global_shuffle = not bool(shuffle is False)
+        self.global_shuffle = shuffle
        self.distribution = ""
+        self.sampler = sampler

-        if num_shards is None:
+        if num_shards is None or shard_id is None:
            self.partitions = None
        else:
            self.partitions = [num_shards, shard_id]
@@ -1645,9 +1649,25 @@
        if block_reader is True and self.partitions is not None:
            raise ValueError("block reader not allowed true when use partitions")

        if block_reader is True and shuffle is True:
            raise ValueError("block reader not allowed true when use shuffle")

        if block_reader is True:
            logger.warning("WARN: global shuffle is not used.")

+        if sampler is not None and isinstance(sampler, samplers.SubsetRandomSampler) is False:
+            raise ValueError("the sampler is not supported yet.")
+
+        # sampler exclusive
+        if block_reader is True and sampler is not None:
+            raise ValueError("block reader not allowed true when use sampler")
+
+        if shuffle is True and sampler is not None:
+            raise ValueError("shuffle not allowed true when use sampler")
+
+        if block_reader is False and sampler is None:
+            self.global_shuffle = not bool(shuffle is False)
+
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.block_reader = block_reader
@@ -1661,6 +1681,9 @@
        args["block_reader"] = self.block_reader
        args["num_shards"] = self.num_shards
        args["shard_id"] = self.shard_id
+        if self.sampler:
+            args["sampler_name"] = self.sampler.__class__.__name__
+            args["sampler_params"] = self.sampler.__dict__
        return args

    def get_dataset_size(self):
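Taken together, the new __init__ checks make sampler mutually exclusive with shuffle and block_reader and reject any sampler other than SubsetRandomSampler. A sketch of calls the checks above would now reject (file name illustrative; each call shown independently):

import mindspore.dataset as ds

sampler = ds.SubsetRandomSampler([1, 2, 3])
ds.MindDataset("imagenet.mindrecord0", shuffle=True, sampler=sampler)       # ValueError: shuffle vs sampler
ds.MindDataset("imagenet.mindrecord0", block_reader=True, sampler=sampler)  # ValueError: block_reader vs sampler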
tests/ut/python/dataset/test_minddataset_sampler.py
new file mode 100644
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This is the test module for mindrecord
"""
import collections
import json
import os
import re
import string

import mindspore.dataset.transforms.vision.c_transforms as vision
import numpy as np
import pytest
from mindspore.dataset.transforms.vision import Inter
from mindspore import log as logger

import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"


@pytest.fixture
def add_and_remove_cv_file():
    """add/remove cv file"""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"id": {"type": "int32"},
                      "file_name": {"type": "string"},
                      "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_cv_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    data = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert data[indices[num_iter]]['file_name'] == "".join([chr(x) for x in item['file_name']])
        num_iter += 1
    assert num_iter == 5


def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 2, 5, 7, 9]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    data = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert data[indices[num_iter]]['file_name'] == "".join([chr(x) for x in item['file_name']])
        num_iter += 1
    assert num_iter == 6


def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = []
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    data = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert data[indices[num_iter]]['file_name'] == "".join([chr(x) for x in item['file_name']])
        num_iter += 1
    assert num_iter == 0


def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 4, 11, 13]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    data = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert data[indices[num_iter] % len(data)]['file_name'] == "".join([chr(x) for x in item['file_name']])
        num_iter += 1
    assert num_iter == 5


def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 4, -1, -2]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    data = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert data[indices[num_iter] % len(data)]['file_name'] == "".join([chr(x) for x in item['file_name']])
        num_iter += 1
    assert num_iter == 5


def get_data(dir_name):
    """
    usage: get data from imagenet dataset
    params:
    dir_name: directory containing folder images and annotation information
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} not exists".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list