Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindspore
提交
ac39c20f
M
mindspore
项目概览
MindSpore
/
mindspore
通知
35
Star
15
Fork
15
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ac39c20f
编写于
8月 31, 2020
作者:
L
liyong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
del finish in FileReader
上级
a9f4a24e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
43 addition
and
78 deletion
+43
-78
mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
+0
-1
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
+0
-4
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
+13
-17
mindspore/mindrecord/filereader.py
mindspore/mindrecord/filereader.py
+0
-9
mindspore/mindrecord/shardreader.py
mindspore/mindrecord/shardreader.py
+1
-18
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
+17
-17
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
+6
-6
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
+5
-5
tests/ut/python/mindrecord/test_mindrecord_base.py
tests/ut/python/mindrecord/test_mindrecord_base.py
+1
-1
未找到文件。
mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
浏览文件 @
ac39c20f
...
...
@@ -93,7 +93,6 @@ void BindShardReader(const py::module *m) {
.
def
(
"get_blob_fields"
,
&
ShardReader
::
GetBlobFields
)
.
def
(
"get_next"
,
(
std
::
vector
<
std
::
tuple
<
std
::
vector
<
std
::
vector
<
uint8_t
>>
,
pybind11
::
object
>>
(
ShardReader
::*
)())
&
ShardReader
::
GetNextPy
)
.
def
(
"finish"
,
&
ShardReader
::
Finish
)
.
def
(
"close"
,
&
ShardReader
::
Close
);
}
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
浏览文件 @
ac39c20f
...
...
@@ -174,10 +174,6 @@ class ShardReader {
ROW_GROUP_BRIEF
ReadRowGroupCriteria
(
int
group_id
,
int
shard_id
,
const
std
::
pair
<
std
::
string
,
std
::
string
>
&
criteria
,
const
std
::
vector
<
std
::
string
>
&
columns
=
std
::
vector
<
std
::
string
>
());
/// \brief join all created threads
/// \return MSRStatus the status of MSRStatus
MSRStatus
Finish
();
/// \brief return a batch, given that one is ready
/// \return a batch of images and image data
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
GetNext
();
...
...
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
浏览文件 @
ac39c20f
...
...
@@ -239,7 +239,19 @@ void ShardReader::FileStreamsOperator() {
ShardReader
::~
ShardReader
()
{
Close
();
}
void
ShardReader
::
Close
()
{
(
void
)
Finish
();
// interrupt reading and stop threads
{
std
::
lock_guard
<
std
::
mutex
>
lck
(
mtx_delivery_
);
interrupt_
=
true
;
// interrupt reading and stop threads
}
cv_delivery_
.
notify_all
();
// Wait for all threads to finish
for
(
auto
&
i_thread
:
thread_set_
)
{
if
(
i_thread
.
joinable
())
{
i_thread
.
join
();
}
}
FileStreamsOperator
();
}
...
...
@@ -759,22 +771,6 @@ bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int,
return
std
::
get
<
1
>
(
a
)
<
std
::
get
<
1
>
(
b
)
||
(
std
::
get
<
1
>
(
a
)
==
std
::
get
<
1
>
(
b
)
&&
std
::
get
<
0
>
(
a
)
<
std
::
get
<
0
>
(
b
));
}
MSRStatus
ShardReader
::
Finish
()
{
{
std
::
lock_guard
<
std
::
mutex
>
lck
(
mtx_delivery_
);
interrupt_
=
true
;
}
cv_delivery_
.
notify_all
();
// Wait for all threads to finish
for
(
auto
&
i_thread
:
thread_set_
)
{
if
(
i_thread
.
joinable
())
{
i_thread
.
join
();
}
}
return
SUCCESS
;
}
int64_t
ShardReader
::
GetNumClasses
(
const
std
::
string
&
category_field
)
{
auto
shard_count
=
file_paths_
.
size
();
auto
index_fields
=
shard_header_
->
GetFields
();
...
...
mindspore/mindrecord/filereader.py
浏览文件 @
ac39c20f
...
...
@@ -83,15 +83,6 @@ class FileReader:
yield
populate_data
(
raw
,
blob
,
self
.
_columns
,
self
.
_header
.
blob_fields
,
self
.
_header
.
schema
)
iterator
=
self
.
_reader
.
get_next
()
def
finish
(
self
):
"""
Stop reader worker.
Raises:
MRMFinishError: If failed to finish worker threads.
"""
return
self
.
_reader
.
finish
()
def
close
(
self
):
"""Stop reader worker and close File."""
return
self
.
_reader
.
close
()
mindspore/mindrecord/shardreader.py
浏览文件 @
ac39c20f
...
...
@@ -17,8 +17,7 @@ This module is to read data from mindrecord.
"""
import
mindspore._c_mindrecord
as
ms
from
mindspore
import
log
as
logger
from
.common.exceptions
import
MRMOpenError
,
MRMLaunchError
,
MRMFinishError
from
.common.exceptions
import
MRMOpenError
,
MRMLaunchError
__all__
=
[
'ShardReader'
]
class
ShardReader
:
...
...
@@ -102,22 +101,6 @@ class ShardReader:
"""
return
self
.
_reader
.
get_header
()
def
finish
(
self
):
"""
stop the worker threads.
Returns:
MSRStatus, SUCCESS or FAILED.
Raises:
MRMFinishError: If failed to finish worker threads.
"""
ret
=
self
.
_reader
.
finish
()
if
ret
!=
ms
.
MSRStatus
.
SUCCESS
:
logger
.
error
(
"Failed to finish worker threads."
)
raise
MRMFinishError
return
ret
def
close
(
self
):
"""close MindRecord File."""
self
.
_reader
.
close
()
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
浏览文件 @
ac39c20f
...
...
@@ -73,7 +73,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
]);
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_TRUE
(
i
<=
kSampleCount
);
}
...
...
@@ -99,7 +99,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
]);
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_TRUE
(
i
<=
5
);
}
...
...
@@ -125,7 +125,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
]);
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_TRUE
(
i
<=
10
);
}
...
...
@@ -151,7 +151,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
MS_LOG
(
INFO
)
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
]);
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_TRUE
(
i
<=
10
);
}
...
...
@@ -176,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
())
<<
std
::
endl
;
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_TRUE
(
i
==
20
);
}
// namespace mindrecord
...
...
@@ -202,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
())
<<
std
::
endl
;
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_TRUE
(
i
==
6
);
}
...
...
@@ -238,7 +238,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
category_no
++
;
category_no
%=
static_cast
<
int
>
(
categories
.
size
());
}
dataset
.
Finish
();
dataset
.
Close
();
}
TEST_F
(
TestShardOperator
,
TestShardShuffle
)
{
...
...
@@ -262,7 +262,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
}
TEST_F
(
TestShardOperator
,
TestShardSampleShuffle
)
{
...
...
@@ -287,7 +287,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_LE
(
i
,
35
);
}
...
...
@@ -314,7 +314,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_TRUE
(
i
<=
kSampleSize
);
}
...
...
@@ -341,7 +341,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
());
i
++
;
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_LE
(
i
,
35
);
}
...
...
@@ -373,8 +373,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
auto
y
=
compare_dataset
.
GetNext
();
if
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
]
!=
(
std
::
get
<
1
>
(
y
[
0
]))[
"file_name"
])
different
=
true
;
}
dataset
.
Finish
();
compare_dataset
.
Finish
();
dataset
.
Close
();
compare_dataset
.
Close
();
ASSERT_TRUE
(
different
);
}
...
...
@@ -409,7 +409,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
category_no
++
;
category_no
%=
static_cast
<
int
>
(
categories
.
size
());
}
dataset
.
Finish
();
dataset
.
Close
();
}
TEST_F
(
TestShardOperator
,
TestShardCategoryShuffle2
)
{
...
...
@@ -442,7 +442,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
category_no
++
;
category_no
%=
static_cast
<
int
>
(
categories
.
size
());
}
dataset
.
Finish
();
dataset
.
Close
();
}
TEST_F
(
TestShardOperator
,
TestShardCategorySample
)
{
...
...
@@ -477,7 +477,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
category_no
++
;
category_no
%=
static_cast
<
int
>
(
categories
.
size
());
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_EQ
(
category_no
,
0
);
ASSERT_TRUE
(
i
<=
kSampleSize
);
}
...
...
@@ -515,7 +515,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
category_no
++
;
category_no
%=
static_cast
<
int
>
(
categories
.
size
());
}
dataset
.
Finish
();
dataset
.
Close
();
ASSERT_EQ
(
category_no
,
0
);
ASSERT_TRUE
(
i
<=
kSampleSize
);
}
...
...
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
浏览文件 @
ac39c20f
...
...
@@ -67,7 +67,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
}
}
}
dataset
.
Finish
();
dataset
.
Close
();
}
TEST_F
(
TestShardReader
,
TestShardReaderSample
)
{
...
...
@@ -90,7 +90,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
}
}
}
dataset
.
Finish
();
dataset
.
Close
();
dataset
.
Close
();
}
...
...
@@ -110,7 +110,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
}
}
}
dataset
.
Finish
();
dataset
.
Close
();
}
TEST_F
(
TestShardReader
,
TestShardReaderColumnNotInIndex
)
{
...
...
@@ -131,7 +131,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
}
}
}
dataset
.
Finish
();
dataset
.
Close
();
}
TEST_F
(
TestShardReader
,
TestShardReaderColumnNotInSchema
)
{
...
...
@@ -161,7 +161,7 @@ TEST_F(TestShardReader, TestShardVersion) {
}
}
}
dataset
.
Finish
();
dataset
.
Close
();
}
TEST_F
(
TestShardReader
,
TestShardReaderDir
)
{
...
...
@@ -192,7 +192,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
}
}
}
dataset
.
Finish
();
dataset
.
Close
();
}
}
// namespace mindrecord
}
// namespace mindspore
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
浏览文件 @
ac39c20f
...
...
@@ -74,7 +74,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
}
}
}
dataset
.
Finish
();
dataset
.
Close
();
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
string
filename
=
std
::
string
(
"./OneSample.shard0"
)
+
std
::
to_string
(
i
);
string
db_name
=
std
::
string
(
"./OneSample.shard0"
)
+
std
::
to_string
(
i
)
+
".db"
;
...
...
@@ -775,7 +775,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
}
}
ASSERT_TRUE
(
count
==
10
);
dataset
.
Finish
();
dataset
.
Close
();
for
(
const
auto
&
filename
:
file_names
)
{
auto
filename_db
=
filename
+
".db"
;
...
...
@@ -858,7 +858,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
}
}
ASSERT_TRUE
(
count
==
10
);
dataset
.
Finish
();
dataset
.
Close
();
for
(
const
auto
&
filename
:
file_names
)
{
auto
filename_db
=
filename
+
".db"
;
remove
(
common
::
SafeCStr
(
filename_db
));
...
...
@@ -952,7 +952,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
}
}
ASSERT_TRUE
(
count
==
10
);
dataset
.
Finish
();
dataset
.
Close
();
for
(
const
auto
&
filename
:
file_names
)
{
auto
filename_db
=
filename
+
".db"
;
remove
(
common
::
SafeCStr
(
filename_db
));
...
...
@@ -1060,7 +1060,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
count
++
;
}
ASSERT_TRUE
(
count
==
10
);
dataset
.
Finish
();
dataset
.
Close
();
for
(
const
auto
&
filename
:
file_names
)
{
auto
filename_db
=
filename
+
".db"
;
remove
(
common
::
SafeCStr
(
filename_db
));
...
...
tests/ut/python/mindrecord/test_mindrecord_base.py
浏览文件 @
ac39c20f
...
...
@@ -260,7 +260,7 @@ def test_cv_file_reader_partial_tutorial():
count
=
count
+
1
logger
.
info
(
"#item{}: {}"
.
format
(
index
,
x
))
if
count
==
5
:
reader
.
finish
()
reader
.
close
()
assert
count
==
5
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录