Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
aa3f89e7
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
aa3f89e7
编写于
5月 08, 2020
作者:
L
liyong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
mindrecord support read file list
上级
a2d5ad5a
变更
27
隐藏空白更改
内联
并排
Showing
27 changed file
with
496 addition
and
173 deletion
+496
-173
mindspore/ccsrc/dataset/api/de_pipeline.cc
mindspore/ccsrc/dataset/api/de_pipeline.cc
+7
-2
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+11
-10
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
...e/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
+20
-13
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
...re/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
+19
-11
mindspore/ccsrc/mindrecord/common/shard_error.cc
mindspore/ccsrc/mindrecord/common/shard_error.cc
+3
-0
mindspore/ccsrc/mindrecord/common/shard_pybind.cc
mindspore/ccsrc/mindrecord/common/shard_pybind.cc
+4
-2
mindspore/ccsrc/mindrecord/include/shard_error.h
mindspore/ccsrc/mindrecord/include/shard_error.h
+2
-1
mindspore/ccsrc/mindrecord/include/shard_header.h
mindspore/ccsrc/mindrecord/include/shard_header.h
+7
-6
mindspore/ccsrc/mindrecord/include/shard_reader.h
mindspore/ccsrc/mindrecord/include/shard_reader.h
+13
-8
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
+16
-2
mindspore/ccsrc/mindrecord/io/shard_reader.cc
mindspore/ccsrc/mindrecord/io/shard_reader.cc
+64
-29
mindspore/ccsrc/mindrecord/io/shard_writer.cc
mindspore/ccsrc/mindrecord/io/shard_writer.cc
+18
-5
mindspore/ccsrc/mindrecord/meta/shard_header.cc
mindspore/ccsrc/mindrecord/meta/shard_header.cc
+31
-25
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+11
-3
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+5
-2
mindspore/mindrecord/filereader.py
mindspore/mindrecord/filereader.py
+6
-3
mindspore/mindrecord/mindpage.py
mindspore/mindrecord/mindpage.py
+6
-3
mindspore/mindrecord/shardreader.py
mindspore/mindrecord/shardreader.py
+7
-2
mindspore/mindrecord/shardsegment.py
mindspore/mindrecord/shardsegment.py
+7
-2
tests/ut/cpp/dataset/mind_record_op_test.cc
tests/ut/cpp/dataset/mind_record_op_test.cc
+14
-7
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
+17
-17
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
+9
-9
tests/ut/cpp/mindrecord/ut_shard_segment_test.cc
tests/ut/cpp/mindrecord/ut_shard_segment_test.cc
+5
-5
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
+5
-5
tests/ut/python/dataset/test_minddataset.py
tests/ut/python/dataset/test_minddataset.py
+122
-1
tests/ut/python/dataset/test_minddataset_exception.py
tests/ut/python/dataset/test_minddataset_exception.py
+57
-0
tests/ut/python/mindrecord/test_mindrecord_base.py
tests/ut/python/mindrecord/test_mindrecord_base.py
+10
-0
未找到文件。
mindspore/ccsrc/dataset/api/de_pipeline.cc
浏览文件 @
aa3f89e7
...
@@ -408,8 +408,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
...
@@ -408,8 +408,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
}
}
std
::
shared_ptr
<
MindRecordOp
::
Builder
>
builder
=
std
::
make_shared
<
MindRecordOp
::
Builder
>
();
std
::
shared_ptr
<
MindRecordOp
::
Builder
>
builder
=
std
::
make_shared
<
MindRecordOp
::
Builder
>
();
(
void
)
builder
->
SetDatasetFile
(
ToString
(
args
[
"dataset_file"
]));
bool
load_dataset
=
ToBool
(
args
[
"load_dataset"
]);
if
(
load_dataset
==
true
)
{
(
void
)
builder
->
SetDatasetFile
({
ToString
(
args
[
"dataset_file"
])});
}
else
{
(
void
)
builder
->
SetDatasetFile
(
ToStringVector
(
args
[
"dataset_file"
]));
}
(
void
)
builder
->
SetLoadDataset
(
load_dataset
);
std
::
vector
<
std
::
string
>
in_col_names
;
std
::
vector
<
std
::
string
>
in_col_names
;
if
(
!
args
[
"columns_list"
].
is_none
())
{
if
(
!
args
[
"columns_list"
].
is_none
())
{
in_col_names
=
ToStringVector
(
args
[
"columns_list"
]);
in_col_names
=
ToStringVector
(
args
[
"columns_list"
]);
...
...
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
aa3f89e7
...
@@ -151,16 +151,17 @@ void bindDatasetOps(py::module *m) {
...
@@ -151,16 +151,17 @@ void bindDatasetOps(py::module *m) {
});
});
(
void
)
py
::
class_
<
MindRecordOp
,
DatasetOp
,
std
::
shared_ptr
<
MindRecordOp
>>
(
*
m
,
"MindRecordOp"
)
(
void
)
py
::
class_
<
MindRecordOp
,
DatasetOp
,
std
::
shared_ptr
<
MindRecordOp
>>
(
*
m
,
"MindRecordOp"
)
.
def_static
(
"get_num_rows"
,
[](
const
std
::
string
&
path
,
const
py
::
object
&
sampler
)
{
.
def_static
(
"get_num_rows"
,
int64_t
count
=
0
;
[](
const
std
::
vector
<
std
::
string
>
&
paths
,
bool
load_dataset
,
const
py
::
object
&
sampler
)
{
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>
op
;
int64_t
count
=
0
;
if
(
py
::
hasattr
(
sampler
,
"_create_for_minddataset"
))
{
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>
op
;
auto
create
=
sampler
.
attr
(
"_create_for_minddataset"
);
if
(
py
::
hasattr
(
sampler
,
"_create_for_minddataset"
))
{
op
=
create
().
cast
<
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>>
();
auto
create
=
sampler
.
attr
(
"_create_for_minddataset"
);
}
op
=
create
().
cast
<
std
::
shared_ptr
<
mindrecord
::
ShardOperator
>>
();
THROW_IF_ERROR
(
MindRecordOp
::
CountTotalRows
(
path
,
op
,
&
count
));
}
return
count
;
THROW_IF_ERROR
(
MindRecordOp
::
CountTotalRows
(
paths
,
load_dataset
,
op
,
&
count
));
});
return
count
;
});
(
void
)
py
::
class_
<
ManifestOp
,
DatasetOp
,
std
::
shared_ptr
<
ManifestOp
>>
(
*
m
,
"ManifestOp"
)
(
void
)
py
::
class_
<
ManifestOp
,
DatasetOp
,
std
::
shared_ptr
<
ManifestOp
>>
(
*
m
,
"ManifestOp"
)
.
def_static
(
"get_num_rows_and_classes"
,
.
def_static
(
"get_num_rows_and_classes"
,
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
浏览文件 @
aa3f89e7
...
@@ -40,7 +40,7 @@ using mindrecord::ShardOperator;
...
@@ -40,7 +40,7 @@ using mindrecord::ShardOperator;
using
mindrecord
::
ShardReader
;
using
mindrecord
::
ShardReader
;
// Builder constructor. Creates the builder object.
// Builder constructor. Creates the builder object.
MindRecordOp
::
Builder
::
Builder
()
:
build_dataset_file_
(
""
)
{
MindRecordOp
::
Builder
::
Builder
()
:
build_dataset_file_
(
{}
)
{
// Some arguments to the MindRecordOp constructor have a default argument that is taken
// Some arguments to the MindRecordOp constructor have a default argument that is taken
// from the client config.
// from the client config.
// The user may choose to change these values for the construction of the StorageOp by
// The user may choose to change these values for the construction of the StorageOp by
...
@@ -63,9 +63,9 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
...
@@ -63,9 +63,9 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
"Building a MindRecordOp that has not provided a file."
);
"Building a MindRecordOp that has not provided a file."
);
}
}
new_mind_record_op
=
std
::
make_shared
<
MindRecordOp
>
(
build_num_mind_record_workers_
,
build_rows_per_buffer_
,
new_mind_record_op
=
std
::
make_shared
<
MindRecordOp
>
(
build_dataset_file_
,
build_op_connector_queue_size
_
,
build_num_mind_record_workers_
,
build_rows_per_buffer_
,
build_dataset_file_
,
build_load_dataset
_
,
build_columns_to_load_
,
build_operators_
,
build_block_reader_
);
build_op_connector_queue_size_
,
build_columns_to_load_
,
build_operators_
,
build_block_reader_
);
RETURN_IF_NOT_OK
(
new_mind_record_op
->
Init
());
RETURN_IF_NOT_OK
(
new_mind_record_op
->
Init
());
...
@@ -76,12 +76,14 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
...
@@ -76,12 +76,14 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
Status
MindRecordOp
::
Builder
::
SanityCheck
()
const
{
return
Status
::
OK
();
}
Status
MindRecordOp
::
Builder
::
SanityCheck
()
const
{
return
Status
::
OK
();
}
// Constructor of the MindRecordOp.
// Constructor of the MindRecordOp.
MindRecordOp
::
MindRecordOp
(
int32_t
num_mind_record_workers
,
int32_t
rows_per_buffer
,
std
::
string
dataset_file
,
MindRecordOp
::
MindRecordOp
(
int32_t
num_mind_record_workers
,
int32_t
rows_per_buffer
,
int32_t
op_connector_queue_size
,
const
std
::
vector
<
std
::
string
>
&
columns_to_load
,
std
::
vector
<
std
::
string
>
dataset_file
,
bool
load_dataset
,
int32_t
op_connector_queue_size
,
const
std
::
vector
<
std
::
string
>
&
columns_to_load
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
const
bool
&
block_reader
)
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
const
bool
&
block_reader
)
:
ParallelOp
(
num_mind_record_workers
,
op_connector_queue_size
),
:
ParallelOp
(
num_mind_record_workers
,
op_connector_queue_size
),
rows_per_buffer_
(
rows_per_buffer
),
rows_per_buffer_
(
rows_per_buffer
),
dataset_file_
(
dataset_file
),
dataset_file_
(
dataset_file
),
load_dataset_
(
load_dataset
),
columns_to_load_
(
columns_to_load
),
columns_to_load_
(
columns_to_load
),
operators_
(
operators
),
operators_
(
operators
),
num_mind_record_workers_
(
num_mind_record_workers
),
num_mind_record_workers_
(
num_mind_record_workers
),
...
@@ -101,9 +103,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
...
@@ -101,9 +103,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
// Private helper method to encapsulate some common construction/reset tasks
// Private helper method to encapsulate some common construction/reset tasks
Status
MindRecordOp
::
Init
()
{
Status
MindRecordOp
::
Init
()
{
shard_reader_
=
std
::
make_unique
<
ShardReader
>
();
shard_reader_
=
std
::
make_unique
<
ShardReader
>
();
auto
rc
=
shard_reader_
->
Open
(
dataset_file_
,
num_mind_record_workers_
,
columns_to_load_
,
operators_
,
block_reader_
);
auto
rc
=
shard_reader_
->
Open
(
dataset_file_
,
load_dataset_
,
num_mind_record_workers_
,
columns_to_load_
,
operators_
,
block_reader_
);
CHECK_FAIL_RETURN_UNEXPECTED
(
rc
!=
MSRStatus
::
FAILED
,
CHECK_FAIL_RETURN_UNEXPECTED
(
rc
==
MSRStatus
::
SUCCESS
,
"MindRecordOp init failed. Error message: "
+
ErrnoToMessage
(
rc
));
"MindRecordOp init failed. Error message: "
+
ErrnoToMessage
(
rc
));
data_schema_
=
std
::
make_unique
<
DataSchema
>
();
data_schema_
=
std
::
make_unique
<
DataSchema
>
();
...
@@ -201,8 +204,12 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
...
@@ -201,8 +204,12 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
// Call the super class for displaying any common detailed info
ParallelOp
::
Print
(
out
,
show_all
);
ParallelOp
::
Print
(
out
,
show_all
);
// Then show any custom derived-internal stuff
// Then show any custom derived-internal stuff
out
<<
"
\n
1 Dataset file : "
<<
dataset_file_
<<
"
\n
Number of rows : "
<<
num_rows_
out
<<
"
\n
Dataset file : "
;
<<
"
\n
Rows per buffer : "
<<
rows_per_buffer_
<<
"
\n
Number of buffers : "
<<
buffers_needed_
for
(
auto
&
file
:
dataset_file_
)
{
out
<<
file
<<
" "
;
}
out
<<
"
\n
Number of rows : "
<<
num_rows_
<<
"
\n
Rows per buffer : "
<<
rows_per_buffer_
<<
"
\n
Number of buffers : "
<<
buffers_needed_
<<
"
\n
Number of ShardReader workers : "
<<
num_mind_record_workers_
<<
"
\n\n
"
;
<<
"
\n
Number of ShardReader workers : "
<<
num_mind_record_workers_
<<
"
\n\n
"
;
}
}
}
}
...
@@ -668,10 +675,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
...
@@ -668,10 +675,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
return
Status
::
OK
();
return
Status
::
OK
();
}
}
Status
MindRecordOp
::
CountTotalRows
(
const
std
::
string
dataset_path
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
Status
MindRecordOp
::
CountTotalRows
(
const
std
::
vector
<
std
::
string
>
dataset_path
,
bool
load_dataset
,
int64_t
*
count
)
{
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
int64_t
*
count
)
{
std
::
unique_ptr
<
ShardReader
>
shard_reader
=
std
::
make_unique
<
ShardReader
>
();
std
::
unique_ptr
<
ShardReader
>
shard_reader
=
std
::
make_unique
<
ShardReader
>
();
MSRStatus
rc
=
shard_reader
->
CountTotalRows
(
dataset_path
,
op
,
count
);
MSRStatus
rc
=
shard_reader
->
CountTotalRows
(
dataset_path
,
load_dataset
,
op
,
count
);
if
(
rc
==
MSRStatus
::
FAILED
)
{
if
(
rc
==
MSRStatus
::
FAILED
)
{
RETURN_STATUS_UNEXPECTED
(
"MindRecordOp count total rows failed."
);
RETURN_STATUS_UNEXPECTED
(
"MindRecordOp count total rows failed."
);
}
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
浏览文件 @
aa3f89e7
...
@@ -77,8 +77,8 @@ class MindRecordOp : public ParallelOp {
...
@@ -77,8 +77,8 @@ class MindRecordOp : public ParallelOp {
return
*
this
;
return
*
this
;
}
}
Builder
&
SetDatasetFile
(
const
std
::
string
&
file
)
{
Builder
&
SetDatasetFile
(
const
std
::
vector
<
std
::
string
>
&
files
)
{
build_dataset_file_
=
file
;
build_dataset_file_
=
file
s
;
return
*
this
;
return
*
this
;
}
}
...
@@ -97,6 +97,11 @@ class MindRecordOp : public ParallelOp {
...
@@ -97,6 +97,11 @@ class MindRecordOp : public ParallelOp {
return
*
this
;
return
*
this
;
}
}
Builder
&
SetLoadDataset
(
bool
load_dataset
)
{
build_load_dataset_
=
load_dataset
;
return
*
this
;
}
Status
SanityCheck
()
const
;
Status
SanityCheck
()
const
;
static
int32_t
num_mind_record_workers
()
{
return
kDefaultMindRecordWorkers
;
}
static
int32_t
num_mind_record_workers
()
{
return
kDefaultMindRecordWorkers
;
}
...
@@ -109,7 +114,8 @@ class MindRecordOp : public ParallelOp {
...
@@ -109,7 +114,8 @@ class MindRecordOp : public ParallelOp {
int32_t
builder_num_workers_
;
int32_t
builder_num_workers_
;
int32_t
build_rows_per_buffer_
;
int32_t
build_rows_per_buffer_
;
int32_t
build_op_connector_queue_size_
;
int32_t
build_op_connector_queue_size_
;
std
::
string
build_dataset_file_
;
std
::
vector
<
std
::
string
>
build_dataset_file_
;
bool
build_load_dataset_
;
std
::
vector
<
std
::
string
>
build_columns_to_load_
;
std
::
vector
<
std
::
string
>
build_columns_to_load_
;
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
build_operators_
;
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
build_operators_
;
bool
build_block_reader_
;
bool
build_block_reader_
;
...
@@ -119,12 +125,12 @@ class MindRecordOp : public ParallelOp {
...
@@ -119,12 +125,12 @@ class MindRecordOp : public ParallelOp {
// @note The builder class should be used to call it
// @note The builder class should be used to call it
// @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
// @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
// @param rows_per_buffer - The requested number of rows per buffer
// @param rows_per_buffer - The requested number of rows per buffer
// @param dataset_file -
A shard file
// @param dataset_file -
dataset files
// @param op_connector_queue_size - The output connector queue size
// @param op_connector_queue_size - The output connector queue size
// @param columns_to_load - The list of columns to use (column name)
// @param columns_to_load - The list of columns to use (column name)
// @param operators - ShardOperators for Shuffle, Category, Sample
// @param operators - ShardOperators for Shuffle, Category, Sample
MindRecordOp
(
int32_t
num_mind_record_workers
,
int32_t
rows_per_buffer
,
std
::
string
dataset_file
,
MindRecordOp
(
int32_t
num_mind_record_workers
,
int32_t
rows_per_buffer
,
std
::
vector
<
std
::
string
>
dataset_file
,
int32_t
op_connector_queue_size
,
const
std
::
vector
<
std
::
string
>
&
columns_to_load
,
bool
load_dataset
,
int32_t
op_connector_queue_size
,
const
std
::
vector
<
std
::
string
>
&
columns_to_load
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
const
bool
&
block_reader
);
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
const
bool
&
block_reader
);
// Destructor
// Destructor
...
@@ -169,21 +175,22 @@ class MindRecordOp : public ParallelOp {
...
@@ -169,21 +175,22 @@ class MindRecordOp : public ParallelOp {
// Getter method
// Getter method
int32_t
num_rows
()
const
{
return
num_rows_
;
}
int32_t
num_rows
()
const
{
return
num_rows_
;
}
// Getter method
static
Status
CountTotalRows
(
const
std
::
vector
<
std
::
string
>
dataset_path
,
bool
load_dataset
,
static
Status
CountTotalRows
(
const
std
::
string
dataset_path
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
int64_t
*
count
);
int64_t
*
count
);
// Getter method
// Getter method
int32_t
rows_per_buffer
()
const
{
return
rows_per_buffer_
;
}
int32_t
rows_per_buffer
()
const
{
return
rows_per_buffer_
;
}
// Getter method
// Getter method
std
::
string
dataset_file
()
const
{
return
dataset_file_
;
}
std
::
vector
<
std
::
string
>
dataset_file
()
const
{
return
dataset_file_
;
}
// Getter method
// Getter method
std
::
vector
<
std
::
string
>
columns_to_load
()
const
{
return
columns_to_load_
;
}
std
::
vector
<
std
::
string
>
columns_to_load
()
const
{
return
columns_to_load_
;
}
bool
block_reader
()
const
{
return
block_reader_
;
}
bool
block_reader
()
const
{
return
block_reader_
;
}
bool
load_dataset
()
const
{
return
load_dataset_
;
}
Status
Init
();
Status
Init
();
Status
SetColumnsBlob
();
Status
SetColumnsBlob
();
...
@@ -246,7 +253,8 @@ class MindRecordOp : public ParallelOp {
...
@@ -246,7 +253,8 @@ class MindRecordOp : public ParallelOp {
Status
FetchBlockBuffer
(
const
int32_t
&
buffer_id
);
Status
FetchBlockBuffer
(
const
int32_t
&
buffer_id
);
int32_t
rows_per_buffer_
;
// The number of requested rows per buffer.
int32_t
rows_per_buffer_
;
// The number of requested rows per buffer.
std
::
string
dataset_file_
;
// A dataset file
std
::
vector
<
std
::
string
>
dataset_file_
;
// dataset files
bool
load_dataset_
;
// load dataset from single file or not
std
::
vector
<
std
::
string
>
columns_to_load_
;
// Columns to load from dataset
std
::
vector
<
std
::
string
>
columns_to_load_
;
// Columns to load from dataset
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
operators_
;
// ShardOperators to use
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
operators_
;
// ShardOperators to use
int32_t
num_mind_record_workers_
;
// number of workers to be spawned by ShardReader
int32_t
num_mind_record_workers_
;
// number of workers to be spawned by ShardReader
...
...
mindspore/ccsrc/mindrecord/common/shard_error.cc
浏览文件 @
aa3f89e7
...
@@ -170,6 +170,9 @@ std::string ErrnoToMessage(MSRStatus status) {
...
@@ -170,6 +170,9 @@ std::string ErrnoToMessage(MSRStatus status) {
case
IO_FAILED
:
case
IO_FAILED
:
return
"io operate failed"
;
return
"io operate failed"
;
break
;
break
;
case
MATCH_HEADER_FAILED
:
return
"match header failed"
;
break
;
default:
default:
return
"invalid error no"
;
return
"invalid error no"
;
}
}
...
...
mindspore/ccsrc/mindrecord/common/shard_pybind.cc
浏览文件 @
aa3f89e7
...
@@ -84,7 +84,8 @@ void BindShardWriter(py::module *m) {
...
@@ -84,7 +84,8 @@ void BindShardWriter(py::module *m) {
void
BindShardReader
(
const
py
::
module
*
m
)
{
void
BindShardReader
(
const
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
ShardReader
,
std
::
shared_ptr
<
ShardReader
>>
(
*
m
,
"ShardReader"
,
py
::
module_local
())
(
void
)
py
::
class_
<
ShardReader
,
std
::
shared_ptr
<
ShardReader
>>
(
*
m
,
"ShardReader"
,
py
::
module_local
())
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
"open"
,
(
MSRStatus
(
ShardReader
::*
)(
const
std
::
string
&
,
const
int
&
,
const
std
::
vector
<
std
::
string
>
&
,
.
def
(
"open"
,
(
MSRStatus
(
ShardReader
::*
)(
const
std
::
vector
<
std
::
string
>
&
,
bool
,
const
int
&
,
const
std
::
vector
<
std
::
string
>
&
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
))
&
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
))
&
ShardReader
::
OpenPy
)
ShardReader
::
OpenPy
)
.
def
(
"launch"
,
&
ShardReader
::
Launch
)
.
def
(
"launch"
,
&
ShardReader
::
Launch
)
...
@@ -106,7 +107,8 @@ void BindShardIndexGenerator(const py::module *m) {
...
@@ -106,7 +107,8 @@ void BindShardIndexGenerator(const py::module *m) {
void
BindShardSegment
(
py
::
module
*
m
)
{
void
BindShardSegment
(
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
ShardSegment
>
(
*
m
,
"ShardSegment"
,
py
::
module_local
())
(
void
)
py
::
class_
<
ShardSegment
>
(
*
m
,
"ShardSegment"
,
py
::
module_local
())
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
"open"
,
(
MSRStatus
(
ShardSegment
::*
)(
const
std
::
string
&
,
const
int
&
,
const
std
::
vector
<
std
::
string
>
&
,
.
def
(
"open"
,
(
MSRStatus
(
ShardSegment
::*
)(
const
std
::
vector
<
std
::
string
>
&
,
bool
,
const
int
&
,
const
std
::
vector
<
std
::
string
>
&
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
))
&
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
))
&
ShardSegment
::
OpenPy
)
ShardSegment
::
OpenPy
)
.
def
(
"get_category_fields"
,
.
def
(
"get_category_fields"
,
...
...
mindspore/ccsrc/mindrecord/include/shard_error.h
浏览文件 @
aa3f89e7
...
@@ -72,7 +72,8 @@ enum MSRStatus {
...
@@ -72,7 +72,8 @@ enum MSRStatus {
ILLEGAL_PARAMETERS
,
ILLEGAL_PARAMETERS
,
GET_PAGE_BY_GROUP_ID_FAILED
,
GET_PAGE_BY_GROUP_ID_FAILED
,
GET_SYSTEM_STATE_FAILED
,
GET_SYSTEM_STATE_FAILED
,
IO_FAILED
IO_FAILED
,
MATCH_HEADER_FAILED
};
};
// convert error no to string message
// convert error no to string message
...
...
mindspore/ccsrc/mindrecord/include/shard_header.h
浏览文件 @
aa3f89e7
...
@@ -35,10 +35,11 @@ class ShardHeader {
...
@@ -35,10 +35,11 @@ class ShardHeader {
public:
public:
ShardHeader
();
ShardHeader
();
MSRStatus
Build
(
const
std
::
string
&
file_path
);
~
ShardHeader
()
=
default
;
~
ShardHeader
()
=
default
;
MSRStatus
BuildDataset
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
=
true
);
static
std
::
pair
<
MSRStatus
,
json
>
BuildSingleHeader
(
const
std
::
string
&
file_path
);
/// \brief add the schema and save it
/// \brief add the schema and save it
/// \param[in] schema the schema needs to be added
/// \param[in] schema the schema needs to be added
/// \return the last schema's id
/// \return the last schema's id
...
@@ -126,7 +127,7 @@ class ShardHeader {
...
@@ -126,7 +127,7 @@ class ShardHeader {
MSRStatus
FileToPages
(
const
std
::
string
dump_file_name
);
MSRStatus
FileToPages
(
const
std
::
string
dump_file_name
);
private:
private:
MSRStatus
InitializeHeader
(
const
std
::
vector
<
json
>
&
headers
);
MSRStatus
InitializeHeader
(
const
std
::
vector
<
json
>
&
headers
,
bool
load_dataset
);
/// \brief get the headers from all the shard data
/// \brief get the headers from all the shard data
/// \param[in] the shard data real path
/// \param[in] the shard data real path
...
@@ -137,9 +138,9 @@ class ShardHeader {
...
@@ -137,9 +138,9 @@ class ShardHeader {
MSRStatus
ValidateField
(
const
std
::
vector
<
std
::
string
>
&
field_name
,
json
schema
,
const
uint64_t
&
schema_id
);
MSRStatus
ValidateField
(
const
std
::
vector
<
std
::
string
>
&
field_name
,
json
schema
,
const
uint64_t
&
schema_id
);
/// \brief check the binary file status
/// \brief check the binary file status
MSRStatus
CheckFileStatus
(
const
std
::
string
&
path
);
static
MSRStatus
CheckFileStatus
(
const
std
::
string
&
path
);
std
::
pair
<
MSRStatus
,
json
>
ValidateHeader
(
const
std
::
string
&
path
);
st
atic
st
d
::
pair
<
MSRStatus
,
json
>
ValidateHeader
(
const
std
::
string
&
path
);
void
ParseHeader
(
const
json
&
header
);
void
ParseHeader
(
const
json
&
header
);
...
@@ -149,7 +150,7 @@ class ShardHeader {
...
@@ -149,7 +150,7 @@ class ShardHeader {
MSRStatus
CheckIndexField
(
const
std
::
string
&
field
,
const
json
&
schema
);
MSRStatus
CheckIndexField
(
const
std
::
string
&
field
,
const
json
&
schema
);
void
ParsePage
(
const
json
&
page
);
void
ParsePage
(
const
json
&
page
,
int
shard_index
,
bool
load_dataset
);
MSRStatus
ParseStatistics
(
const
json
&
statistics
);
MSRStatus
ParseStatistics
(
const
json
&
statistics
);
...
...
mindspore/ccsrc/mindrecord/include/shard_reader.h
浏览文件 @
aa3f89e7
...
@@ -68,23 +68,25 @@ class ShardReader {
...
@@ -68,23 +68,25 @@ class ShardReader {
virtual
~
ShardReader
();
virtual
~
ShardReader
();
/// \brief open files and initialize reader, c++ API
/// \brief open files and initialize reader, c++ API
/// \param[in] file_path the path of ONE file, any file in dataset is fine
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] n_consumer number of threads when reading
/// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated
/// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \param[in] block_reader block-reader mode if true, otherwise row-reader mode
/// \param[in] block_reader block-reader mode if true, otherwise row-reader mode
/// \return MSRStatus the status of MSRStatus
/// \return MSRStatus the status of MSRStatus
MSRStatus
Open
(
const
std
::
string
&
file_path
,
int
n_consumer
=
4
,
MSRStatus
Open
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
,
int
n_consumer
=
4
,
const
std
::
vector
<
std
::
string
>
&
selected_columns
=
{},
const
std
::
vector
<
std
::
string
>
&
selected_columns
=
{},
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
=
{},
const
bool
&
block_reader
=
false
);
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
=
{},
const
bool
&
block_reader
=
false
);
/// \brief open files and initialize reader, python API
/// \brief open files and initialize reader, python API
/// \param[in] file_path the path of ONE file, any file in dataset is fine
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] n_consumer number of threads when reading
/// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated
/// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \return MSRStatus the status of MSRStatus
/// \return MSRStatus the status of MSRStatus
MSRStatus
OpenPy
(
const
std
::
string
&
file_path
,
const
int
&
n_consumer
=
4
,
MSRStatus
OpenPy
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
,
const
int
&
n_consumer
=
4
,
const
std
::
vector
<
std
::
string
>
&
selected_columns
=
{},
const
std
::
vector
<
std
::
string
>
&
selected_columns
=
{},
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
=
{});
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
=
{});
...
@@ -114,11 +116,13 @@ class ShardReader {
...
@@ -114,11 +116,13 @@ class ShardReader {
int
GetShardCount
()
const
;
int
GetShardCount
()
const
;
/// \brief get the number of rows in database
/// \brief get the number of rows in database
/// \param[in] file_path the path of ONE file, any file in dataset is fine
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object
/// \param[out] count # of rows
/// \param[out] count # of rows
/// \return MSRStatus the status of MSRStatus
/// \return MSRStatus the status of MSRStatus
MSRStatus
CountTotalRows
(
const
std
::
string
&
file_path
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
int64_t
*
count
);
MSRStatus
CountTotalRows
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
int64_t
*
count
);
/// \brief shuffle task with incremental seed
/// \brief shuffle task with incremental seed
/// \return void
/// \return void
...
@@ -220,7 +224,7 @@ class ShardReader {
...
@@ -220,7 +224,7 @@ class ShardReader {
std
::
vector
<
std
::
vector
<
json
>>
&
column_values
);
std
::
vector
<
std
::
vector
<
json
>>
&
column_values
);
/// \brief initialize reader
/// \brief initialize reader
MSRStatus
Init
(
const
std
::
string
&
file_path
);
MSRStatus
Init
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
);
/// \brief validate column list
/// \brief validate column list
MSRStatus
CheckColumnList
(
const
std
::
vector
<
std
::
string
>
&
selected_columns
);
MSRStatus
CheckColumnList
(
const
std
::
vector
<
std
::
string
>
&
selected_columns
);
...
@@ -292,8 +296,9 @@ class ShardReader {
...
@@ -292,8 +296,9 @@ class ShardReader {
void
GetClassesInShard
(
sqlite3
*
db
,
int
shard_id
,
const
std
::
string
sql
,
std
::
set
<
std
::
string
>
&
categories
);
void
GetClassesInShard
(
sqlite3
*
db
,
int
shard_id
,
const
std
::
string
sql
,
std
::
set
<
std
::
string
>
&
categories
);
/// \brief get number of classes
/// \brief get number of classes
int64_t
GetNumClasses
(
const
std
::
string
&
file_path
,
const
std
::
string
&
category_field
);
int64_t
GetNumClasses
(
const
std
::
string
&
category_field
);
std
::
pair
<
MSRStatus
,
std
::
vector
<
std
::
string
>>
GetMeta
(
const
std
::
string
&
file_path
,
json
&
meta_data
);
/// \brief get exactly blob fields data by indices
/// \brief get exactly blob fields data by indices
std
::
vector
<
uint8_t
>
ExtractBlobFieldBySelectColumns
(
std
::
vector
<
uint8_t
>
&
blob_fields_bytes
,
std
::
vector
<
uint8_t
>
ExtractBlobFieldBySelectColumns
(
std
::
vector
<
uint8_t
>
&
blob_fields_bytes
,
std
::
vector
<
uint32_t
>
&
ordered_selected_columns_index
);
std
::
vector
<
uint32_t
>
&
ordered_selected_columns_index
);
...
...
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
浏览文件 @
aa3f89e7
...
@@ -36,9 +36,23 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
...
@@ -36,9 +36,23 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
write_success_
(
true
)
{}
write_success_
(
true
)
{}
MSRStatus
ShardIndexGenerator
::
Build
()
{
MSRStatus
ShardIndexGenerator
::
Build
()
{
auto
ret
=
ShardHeader
::
BuildSingleHeader
(
file_path_
);
if
(
ret
.
first
!=
SUCCESS
)
{
return
FAILED
;
}
auto
json_header
=
ret
.
second
;
auto
ret2
=
GetParentDir
(
file_path_
);
if
(
SUCCESS
!=
ret2
.
first
)
{
return
FAILED
;
}
std
::
vector
<
std
::
string
>
real_addresses
;
for
(
const
auto
&
path
:
json_header
[
"shard_addresses"
])
{
std
::
string
abs_path
=
ret2
.
second
+
string
(
path
);
real_addresses
.
emplace_back
(
abs_path
);
}
ShardHeader
header
=
ShardHeader
();
ShardHeader
header
=
ShardHeader
();
if
(
header
.
Build
(
file_path_
)
!=
SUCCESS
)
{
if
(
header
.
BuildDataset
(
real_addresses
)
==
FAILED
)
{
MS_LOG
(
ERROR
)
<<
"Build shard schema failed."
;
return
FAILED
;
return
FAILED
;
}
}
shard_header_
=
header
;
shard_header_
=
header
;
...
...
mindspore/ccsrc/mindrecord/io/shard_reader.cc
浏览文件 @
aa3f89e7
...
@@ -47,20 +47,55 @@ ShardReader::ShardReader() {
...
@@ -47,20 +47,55 @@ ShardReader::ShardReader() {
block_reader_
=
false
;
block_reader_
=
false
;
}
}
MSRStatus
ShardReader
::
Init
(
const
std
::
string
&
file_path
)
{
std
::
pair
<
MSRStatus
,
std
::
vector
<
std
::
string
>>
ShardReader
::
GetMeta
(
const
std
::
string
&
file_path
,
json
&
meta_data
)
{
if
(
!
IsLegalFile
(
file_path
))
{
if
(
!
IsLegalFile
(
file_path
))
{
return
{
FAILED
,
{}};
}
auto
ret
=
ShardHeader
::
BuildSingleHeader
(
file_path
);
if
(
ret
.
first
!=
SUCCESS
)
{
return
{
FAILED
,
{}};
}
auto
header
=
ret
.
second
;
meta_data
=
{{
"header_size"
,
header
[
"header_size"
]},
{
"page_size"
,
header
[
"page_size"
]},
{
"version"
,
header
[
"version"
]},
{
"index_fields"
,
header
[
"index_fields"
]},
{
"schema"
,
header
[
"schema"
]},
{
"blob_fields"
,
header
[
"blob_fields"
]}};
return
{
SUCCESS
,
header
[
"shard_addresses"
]};
}
MSRStatus
ShardReader
::
Init
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
)
{
std
::
string
file_path
=
file_paths
[
0
];
json
first_meta_data
=
json
();
auto
ret
=
GetMeta
(
file_path
,
first_meta_data
);
if
(
ret
.
first
!=
SUCCESS
)
{
return
FAILED
;
return
FAILED
;
}
}
ShardHeader
sh
=
ShardHeader
();
if
(
file_paths
.
size
()
==
1
&&
load_dataset
==
true
)
{
if
(
sh
.
Build
(
file_path
)
==
FAILED
)
{
auto
ret2
=
GetParentDir
(
file_path
);
if
(
SUCCESS
!=
ret2
.
first
)
{
return
FAILED
;
}
std
::
vector
<
std
::
string
>
real_addresses
;
for
(
const
auto
&
path
:
ret
.
second
)
{
std
::
string
abs_path
=
ret2
.
second
+
string
(
path
);
real_addresses
.
emplace_back
(
abs_path
);
}
file_paths_
=
real_addresses
;
}
else
if
(
file_paths
.
size
()
>=
1
&&
load_dataset
==
false
)
{
file_paths_
=
file_paths
;
}
else
{
MS_LOG
(
ERROR
)
<<
"Error in parameter file_path or load_dataset."
;
return
FAILED
;
return
FAILED
;
}
}
shard_header_
=
std
::
make_shared
<
ShardHeader
>
(
sh
);
header_size_
=
shard_header_
->
GetHeaderSize
();
page_size_
=
shard_header_
->
GetPageSize
();
file_paths_
=
shard_header_
->
GetShardAddresses
();
for
(
const
auto
&
file
:
file_paths_
)
{
for
(
const
auto
&
file
:
file_paths_
)
{
json
meta_data
=
json
();
auto
ret1
=
GetMeta
(
file
,
meta_data
);
if
(
ret1
.
first
!=
SUCCESS
)
{
return
FAILED
;
}
if
(
meta_data
!=
first_meta_data
)
{
MS_LOG
(
ERROR
)
<<
"Mindrecord files meta information is different."
;
return
FAILED
;
}
sqlite3
*
db
=
nullptr
;
sqlite3
*
db
=
nullptr
;
// sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
// sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
int
rc
=
sqlite3_open_v2
(
common
::
SafeCStr
(
file
+
".db"
),
&
db
,
SQLITE_OPEN_READONLY
,
nullptr
);
int
rc
=
sqlite3_open_v2
(
common
::
SafeCStr
(
file
+
".db"
),
&
db
,
SQLITE_OPEN_READONLY
,
nullptr
);
...
@@ -91,7 +126,13 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
...
@@ -91,7 +126,13 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
}
}
database_paths_
.
push_back
(
db
);
database_paths_
.
push_back
(
db
);
}
}
ShardHeader
sh
=
ShardHeader
();
if
(
sh
.
BuildDataset
(
file_paths_
,
load_dataset
)
==
FAILED
)
{
return
FAILED
;
}
shard_header_
=
std
::
make_shared
<
ShardHeader
>
(
sh
);
header_size_
=
shard_header_
->
GetHeaderSize
();
page_size_
=
shard_header_
->
GetPageSize
();
num_rows_
=
0
;
num_rows_
=
0
;
auto
row_group_summary
=
ReadRowGroupSummary
();
auto
row_group_summary
=
ReadRowGroupSummary
();
for
(
const
auto
&
rg
:
row_group_summary
)
{
for
(
const
auto
&
rg
:
row_group_summary
)
{
...
@@ -248,7 +289,6 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
...
@@ -248,7 +289,6 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
fs
->
close
();
fs
->
close
();
return
FAILED
;
return
FAILED
;
}
}
json
label_json
=
json
::
from_msgpack
(
label_raw
);
json
label_json
=
json
::
from_msgpack
(
label_raw
);
json
tmp
;
json
tmp
;
if
(
!
columns
.
empty
())
{
if
(
!
columns
.
empty
())
{
...
@@ -713,15 +753,9 @@ MSRStatus ShardReader::Finish() {
...
@@ -713,15 +753,9 @@ MSRStatus ShardReader::Finish() {
return
SUCCESS
;
return
SUCCESS
;
}
}
int64_t
ShardReader
::
GetNumClasses
(
const
std
::
string
&
file_path
,
const
std
::
string
&
category_field
)
{
int64_t
ShardReader
::
GetNumClasses
(
const
std
::
string
&
category_field
)
{
ShardHeader
sh
=
ShardHeader
();
auto
shard_count
=
file_paths_
.
size
();
if
(
sh
.
Build
(
file_path
)
==
FAILED
)
{
auto
index_fields
=
shard_header_
->
GetFields
();
return
-
1
;
}
auto
header
=
std
::
make_shared
<
ShardHeader
>
(
sh
);
auto
file_paths
=
header
->
GetShardAddresses
();
auto
shard_count
=
file_paths
.
size
();
auto
index_fields
=
header
->
GetFields
();
std
::
map
<
std
::
string
,
int64_t
>
map_schema_id_fields
;
std
::
map
<
std
::
string
,
int64_t
>
map_schema_id_fields
;
for
(
auto
&
field
:
index_fields
)
{
for
(
auto
&
field
:
index_fields
)
{
...
@@ -742,7 +776,7 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
...
@@ -742,7 +776,7 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
std
::
set
<
std
::
string
>
categories
;
std
::
set
<
std
::
string
>
categories
;
for
(
int
x
=
0
;
x
<
shard_count
;
x
++
)
{
for
(
int
x
=
0
;
x
<
shard_count
;
x
++
)
{
sqlite3
*
db
=
nullptr
;
sqlite3
*
db
=
nullptr
;
int
rc
=
sqlite3_open_v2
(
common
::
SafeCStr
(
file_paths
[
x
]
+
".db"
),
&
db
,
SQLITE_OPEN_READONLY
,
nullptr
);
int
rc
=
sqlite3_open_v2
(
common
::
SafeCStr
(
file_paths
_
[
x
]
+
".db"
),
&
db
,
SQLITE_OPEN_READONLY
,
nullptr
);
if
(
SQLITE_OK
!=
rc
)
{
if
(
SQLITE_OK
!=
rc
)
{
MS_LOG
(
ERROR
)
<<
"Can't open database, error: "
<<
sqlite3_errmsg
(
db
);
MS_LOG
(
ERROR
)
<<
"Can't open database, error: "
<<
sqlite3_errmsg
(
db
);
return
-
1
;
return
-
1
;
...
@@ -756,16 +790,16 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
...
@@ -756,16 +790,16 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
return
categories
.
size
();
return
categories
.
size
();
}
}
MSRStatus
ShardReader
::
CountTotalRows
(
const
std
::
string
&
file_path
,
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
MSRStatus
ShardReader
::
CountTotalRows
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
,
int64_t
*
count
)
{
const
std
::
shared_ptr
<
ShardOperator
>
&
op
,
int64_t
*
count
)
{
if
(
Init
(
file_path
)
==
FAILED
)
{
if
(
SUCCESS
!=
Init
(
file_paths
,
load_dataset
)
)
{
return
FAILED
;
return
FAILED
;
}
}
int64_t
num_samples
=
num_rows_
;
int64_t
num_samples
=
num_rows_
;
if
(
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
))
{
if
(
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
))
{
auto
category_op
=
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
);
auto
category_op
=
std
::
dynamic_pointer_cast
<
ShardCategory
>
(
op
);
std
::
string
category_field
=
category_op
->
GetCategoryField
();
std
::
string
category_field
=
category_op
->
GetCategoryField
();
auto
num_classes
=
GetNumClasses
(
file_path
,
category_field
);
auto
num_classes
=
GetNumClasses
(
category_field
);
num_samples
=
category_op
->
GetNumSamples
(
num_rows_
,
num_classes
);
num_samples
=
category_op
->
GetNumSamples
(
num_rows_
,
num_classes
);
}
else
if
(
std
::
dynamic_pointer_cast
<
ShardSample
>
(
op
))
{
}
else
if
(
std
::
dynamic_pointer_cast
<
ShardSample
>
(
op
))
{
num_samples
=
op
->
GetNumSamples
(
num_rows_
,
0
);
num_samples
=
op
->
GetNumSamples
(
num_rows_
,
0
);
...
@@ -779,12 +813,13 @@ MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::s
...
@@ -779,12 +813,13 @@ MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::s
return
SUCCESS
;
return
SUCCESS
;
}
}
MSRStatus
ShardReader
::
Open
(
const
std
::
string
&
file_path
,
int
n_consumer
,
MSRStatus
ShardReader
::
Open
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
,
int
n_consumer
,
const
std
::
vector
<
std
::
string
>
&
selected_columns
,
const
std
::
vector
<
std
::
string
>
&
selected_columns
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
const
bool
&
block_reader
)
{
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
,
const
bool
&
block_reader
)
{
// Open file and set header by ShardReader
// Open file and set header by ShardReader
if
(
Init
(
file_path
)
==
FAILED
)
{
auto
ret
=
Init
(
file_paths
,
load_dataset
);
return
FAILED
;
if
(
SUCCESS
!=
ret
)
{
return
ret
;
}
}
auto
thread_limit
=
GetMaxThreadNum
();
auto
thread_limit
=
GetMaxThreadNum
();
if
(
n_consumer
>
thread_limit
)
{
if
(
n_consumer
>
thread_limit
)
{
...
@@ -837,11 +872,11 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
...
@@ -837,11 +872,11 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
return
SUCCESS
;
return
SUCCESS
;
}
}
MSRStatus
ShardReader
::
OpenPy
(
const
std
::
string
&
file_path
,
const
int
&
n_consumer
,
MSRStatus
ShardReader
::
OpenPy
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
,
const
int
&
n_consumer
,
const
std
::
vector
<
std
::
string
>
&
selected_columns
,
const
std
::
vector
<
std
::
string
>
&
selected_columns
,
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
)
{
const
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
&
operators
)
{
// Open file and set header by ShardReader
// Open file and set header by ShardReader
if
(
Init
(
file_path
)
==
FAILED
)
{
if
(
SUCCESS
!=
Init
(
file_paths
,
load_dataset
)
)
{
return
FAILED
;
return
FAILED
;
}
}
// should remove blob field from selected_columns when call from python
// should remove blob field from selected_columns when call from python
...
...
mindspore/ccsrc/mindrecord/io/shard_writer.cc
浏览文件 @
aa3f89e7
...
@@ -174,12 +174,25 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
...
@@ -174,12 +174,25 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if
(
!
IsLegalFile
(
path
))
{
if
(
!
IsLegalFile
(
path
))
{
return
FAILED
;
return
FAILED
;
}
}
ShardHeader
sh
=
ShardHeader
(
);
auto
ret1
=
ShardHeader
::
BuildSingleHeader
(
path
);
if
(
sh
.
Build
(
path
)
==
FAILED
)
{
if
(
ret1
.
first
!=
SUCCESS
)
{
return
FAILED
;
return
FAILED
;
}
}
shard_header_
=
std
::
make_shared
<
ShardHeader
>
(
sh
);
auto
json_header
=
ret1
.
second
;
auto
paths
=
shard_header_
->
GetShardAddresses
();
auto
ret2
=
GetParentDir
(
path
);
if
(
SUCCESS
!=
ret2
.
first
)
{
return
FAILED
;
}
std
::
vector
<
std
::
string
>
real_addresses
;
for
(
const
auto
&
path
:
json_header
[
"shard_addresses"
])
{
std
::
string
abs_path
=
ret2
.
second
+
string
(
path
);
real_addresses
.
emplace_back
(
abs_path
);
}
ShardHeader
header
=
ShardHeader
();
if
(
header
.
BuildDataset
(
real_addresses
)
==
FAILED
)
{
return
FAILED
;
}
shard_header_
=
std
::
make_shared
<
ShardHeader
>
(
header
);
MSRStatus
ret
=
SetHeaderSize
(
shard_header_
->
GetHeaderSize
());
MSRStatus
ret
=
SetHeaderSize
(
shard_header_
->
GetHeaderSize
());
if
(
ret
==
FAILED
)
{
if
(
ret
==
FAILED
)
{
return
FAILED
;
return
FAILED
;
...
@@ -188,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
...
@@ -188,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if
(
ret
==
FAILED
)
{
if
(
ret
==
FAILED
)
{
return
FAILED
;
return
FAILED
;
}
}
ret
=
Open
(
paths
,
true
);
ret
=
Open
(
json_header
[
"shard_addresses"
]
,
true
);
if
(
ret
==
FAILED
)
{
if
(
ret
==
FAILED
)
{
MS_LOG
(
ERROR
)
<<
"Open file failed"
;
MS_LOG
(
ERROR
)
<<
"Open file failed"
;
return
FAILED
;
return
FAILED
;
...
...
mindspore/ccsrc/mindrecord/meta/shard_header.cc
浏览文件 @
aa3f89e7
...
@@ -35,8 +35,9 @@ namespace mindrecord {
...
@@ -35,8 +35,9 @@ namespace mindrecord {
std
::
atomic
<
bool
>
thread_status
(
false
);
std
::
atomic
<
bool
>
thread_status
(
false
);
ShardHeader
::
ShardHeader
()
:
shard_count_
(
0
),
header_size_
(
0
),
page_size_
(
0
)
{
index_
=
std
::
make_shared
<
Index
>
();
}
ShardHeader
::
ShardHeader
()
:
shard_count_
(
0
),
header_size_
(
0
),
page_size_
(
0
)
{
index_
=
std
::
make_shared
<
Index
>
();
}
MSRStatus
ShardHeader
::
InitializeHeader
(
const
std
::
vector
<
json
>
&
headers
)
{
MSRStatus
ShardHeader
::
InitializeHeader
(
const
std
::
vector
<
json
>
&
headers
,
bool
load_dataset
)
{
shard_count_
=
headers
.
size
();
shard_count_
=
headers
.
size
();
int
shard_index
=
0
;
bool
first
=
true
;
bool
first
=
true
;
for
(
const
auto
&
header
:
headers
)
{
for
(
const
auto
&
header
:
headers
)
{
if
(
first
)
{
if
(
first
)
{
...
@@ -54,7 +55,8 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
...
@@ -54,7 +55,8 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
header_size_
=
header
[
"header_size"
].
get
<
uint64_t
>
();
header_size_
=
header
[
"header_size"
].
get
<
uint64_t
>
();
page_size_
=
header
[
"page_size"
].
get
<
uint64_t
>
();
page_size_
=
header
[
"page_size"
].
get
<
uint64_t
>
();
}
}
ParsePage
(
header
[
"page"
]);
ParsePage
(
header
[
"page"
],
shard_index
,
load_dataset
);
shard_index
++
;
}
}
return
SUCCESS
;
return
SUCCESS
;
}
}
...
@@ -136,40 +138,39 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path)
...
@@ -136,40 +138,39 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path)
return
{
SUCCESS
,
json_header
};
return
{
SUCCESS
,
json_header
};
}
}
MSRStatus
ShardHeader
::
Build
(
const
std
::
string
&
file_path
)
{
std
::
pair
<
MSRStatus
,
json
>
ShardHeader
::
BuildSingleHeader
(
const
std
::
string
&
file_path
)
{
auto
ret
=
ValidateHeader
(
file_path
);
auto
ret
=
ValidateHeader
(
file_path
);
if
(
SUCCESS
!=
ret
.
first
)
{
if
(
SUCCESS
!=
ret
.
first
)
{
return
FAILED
;
return
{
FAILED
,
json
()};
}
json
main_header
=
ret
.
second
;
json
addresses
=
main_header
[
"shard_addresses"
];
vector
<
string
>
real_addresses
;
auto
ret1
=
GetParentDir
(
file_path
);
if
(
SUCCESS
!=
ret1
.
first
)
{
return
FAILED
;
}
}
std
::
string
parent_dir
=
ret1
.
second
;
json
raw_header
=
ret
.
second
;
json
header
=
{{
"shard_addresses"
,
raw_header
[
"shard_addresses"
]},
{
"header_size"
,
raw_header
[
"header_size"
]},
{
"page_size"
,
raw_header
[
"page_size"
]},
{
"index_fields"
,
raw_header
[
"index_fields"
]},
{
"blob_fields"
,
raw_header
[
"schema"
][
0
][
"blob_fields"
]},
{
"schema"
,
raw_header
[
"schema"
][
0
][
"schema"
]},
{
"version"
,
raw_header
[
"version"
]}};
return
{
SUCCESS
,
header
};
}
for
(
const
auto
&
addr
:
addresses
)
{
MSRStatus
ShardHeader
::
BuildDataset
(
const
std
::
vector
<
std
::
string
>
&
file_paths
,
bool
load_dataset
)
{
std
::
string
absolute_path
=
parent_dir
+
string
(
addr
);
real_addresses
.
emplace_back
(
absolute_path
);
}
uint32_t
thread_num
=
std
::
thread
::
hardware_concurrency
();
uint32_t
thread_num
=
std
::
thread
::
hardware_concurrency
();
if
(
thread_num
==
0
)
thread_num
=
kThreadNumber
;
if
(
thread_num
==
0
)
thread_num
=
kThreadNumber
;
uint32_t
work_thread_num
=
0
;
uint32_t
work_thread_num
=
0
;
uint32_t
addr_count
=
real_addresse
s
.
size
();
uint32_t
shard_count
=
file_path
s
.
size
();
int
group_num
=
ceil
(
addr
_count
*
1.0
/
thread_num
);
int
group_num
=
ceil
(
shard
_count
*
1.0
/
thread_num
);
std
::
vector
<
std
::
thread
>
thread_set
(
thread_num
);
std
::
vector
<
std
::
thread
>
thread_set
(
thread_num
);
std
::
vector
<
json
>
headers
(
addr
_count
);
std
::
vector
<
json
>
headers
(
shard
_count
);
for
(
uint32_t
x
=
0
;
x
<
thread_num
;
++
x
)
{
for
(
uint32_t
x
=
0
;
x
<
thread_num
;
++
x
)
{
int
start_num
=
x
*
group_num
;
int
start_num
=
x
*
group_num
;
int
end_num
=
((
x
+
1
)
*
group_num
>
addr_count
)
?
addr
_count
:
(
x
+
1
)
*
group_num
;
int
end_num
=
((
x
+
1
)
*
group_num
>
shard_count
)
?
shard
_count
:
(
x
+
1
)
*
group_num
;
if
(
start_num
>=
end_num
)
{
if
(
start_num
>=
end_num
)
{
continue
;
continue
;
}
}
thread_set
[
x
]
=
thread_set
[
x
]
=
std
::
thread
(
&
ShardHeader
::
GetHeadersOneTask
,
this
,
start_num
,
end_num
,
std
::
ref
(
headers
),
real_addresse
s
);
std
::
thread
(
&
ShardHeader
::
GetHeadersOneTask
,
this
,
start_num
,
end_num
,
std
::
ref
(
headers
),
file_path
s
);
work_thread_num
++
;
work_thread_num
++
;
}
}
...
@@ -180,7 +181,7 @@ MSRStatus ShardHeader::Build(const std::string &file_path) {
...
@@ -180,7 +181,7 @@ MSRStatus ShardHeader::Build(const std::string &file_path) {
thread_status
=
false
;
thread_status
=
false
;
return
FAILED
;
return
FAILED
;
}
}
if
(
SUCCESS
!=
InitializeHeader
(
headers
))
{
if
(
SUCCESS
!=
InitializeHeader
(
headers
,
load_dataset
))
{
return
FAILED
;
return
FAILED
;
}
}
return
SUCCESS
;
return
SUCCESS
;
...
@@ -247,7 +248,8 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
...
@@ -247,7 +248,8 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
return
SUCCESS
;
return
SUCCESS
;
}
}
void
ShardHeader
::
ParsePage
(
const
json
&
pages
)
{
void
ShardHeader
::
ParsePage
(
const
json
&
pages
,
int
shard_index
,
bool
load_dataset
)
{
// set shard_index when load_dataset is false
if
(
pages_
.
empty
()
&&
shard_count_
<=
kMaxShardCount
)
{
if
(
pages_
.
empty
()
&&
shard_count_
<=
kMaxShardCount
)
{
pages_
.
resize
(
shard_count_
);
pages_
.
resize
(
shard_count_
);
}
}
...
@@ -267,7 +269,11 @@ void ShardHeader::ParsePage(const json &pages) {
...
@@ -267,7 +269,11 @@ void ShardHeader::ParsePage(const json &pages) {
std
::
shared_ptr
<
Page
>
parsed_page
=
std
::
make_shared
<
Page
>
(
page_id
,
shard_id
,
page_type
,
page_type_id
,
start_row_id
,
std
::
shared_ptr
<
Page
>
parsed_page
=
std
::
make_shared
<
Page
>
(
page_id
,
shard_id
,
page_type
,
page_type_id
,
start_row_id
,
end_row_id
,
row_group_ids
,
page_size
);
end_row_id
,
row_group_ids
,
page_size
);
pages_
[
shard_id
].
push_back
(
std
::
move
(
parsed_page
));
if
(
load_dataset
==
true
)
{
pages_
[
shard_id
].
push_back
(
std
::
move
(
parsed_page
));
}
else
{
pages_
[
shard_index
].
push_back
(
std
::
move
(
parsed_page
));
}
}
}
}
}
...
@@ -709,7 +715,7 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
...
@@ -709,7 +715,7 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
std
::
string
line
;
std
::
string
line
;
while
(
std
::
getline
(
page_in_handle
,
line
))
{
while
(
std
::
getline
(
page_in_handle
,
line
))
{
ParsePage
(
json
::
parse
(
line
));
ParsePage
(
json
::
parse
(
line
)
,
-
1
,
true
);
}
}
page_in_handle
.
close
();
page_in_handle
.
close
();
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
aa3f89e7
...
@@ -2189,7 +2189,7 @@ class MindDataset(SourceDataset):
...
@@ -2189,7 +2189,7 @@ class MindDataset(SourceDataset):
A source dataset that reads from shard files and database.
A source dataset that reads from shard files and database.
Args:
Args:
dataset_file (str
): one of file names
in dataset.
dataset_file (str
, list[str]): One of file names or file list
in dataset.
columns_list (list[str], optional): List of columns to be read (default=None).
columns_list (list[str], optional): List of columns to be read (default=None).
num_parallel_workers (int, optional): The number of readers (default=None).
num_parallel_workers (int, optional): The number of readers (default=None).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
...
@@ -2214,6 +2214,10 @@ class MindDataset(SourceDataset):
...
@@ -2214,6 +2214,10 @@ class MindDataset(SourceDataset):
shuffle
=
None
,
num_shards
=
None
,
shard_id
=
None
,
shuffle
=
None
,
num_shards
=
None
,
shard_id
=
None
,
block_reader
=
False
,
sampler
=
None
):
block_reader
=
False
,
sampler
=
None
):
super
().
__init__
(
num_parallel_workers
)
super
().
__init__
(
num_parallel_workers
)
if
isinstance
(
dataset_file
,
list
):
self
.
load_dataset
=
False
else
:
self
.
load_dataset
=
True
self
.
dataset_file
=
dataset_file
self
.
dataset_file
=
dataset_file
self
.
columns_list
=
columns_list
self
.
columns_list
=
columns_list
self
.
global_shuffle
=
shuffle
self
.
global_shuffle
=
shuffle
...
@@ -2256,6 +2260,7 @@ class MindDataset(SourceDataset):
...
@@ -2256,6 +2260,7 @@ class MindDataset(SourceDataset):
def
get_args
(
self
):
def
get_args
(
self
):
args
=
super
().
get_args
()
args
=
super
().
get_args
()
args
[
"dataset_file"
]
=
self
.
dataset_file
args
[
"dataset_file"
]
=
self
.
dataset_file
args
[
"load_dataset"
]
=
self
.
load_dataset
args
[
"columns_list"
]
=
self
.
columns_list
args
[
"columns_list"
]
=
self
.
columns_list
args
[
"global_shuffle"
]
=
self
.
global_shuffle
args
[
"global_shuffle"
]
=
self
.
global_shuffle
args
[
"partitions"
]
=
self
.
partitions
args
[
"partitions"
]
=
self
.
partitions
...
@@ -2272,8 +2277,11 @@ class MindDataset(SourceDataset):
...
@@ -2272,8 +2277,11 @@ class MindDataset(SourceDataset):
Return:
Return:
Number, number of batches.
Number, number of batches.
"""
"""
if
self
.
load_dataset
:
num_rows
=
MindRecordOp
.
get_num_rows
(
self
.
dataset_file
,
self
.
sampler
)
dataset_file
=
[
self
.
dataset_file
]
else
:
dataset_file
=
self
.
dataset_file
num_rows
=
MindRecordOp
.
get_num_rows
(
dataset_file
,
self
.
load_dataset
,
self
.
sampler
)
if
self
.
partitions
is
not
None
and
self
.
partitions
[
0
]
>
0
:
if
self
.
partitions
is
not
None
and
self
.
partitions
[
0
]
>
0
:
if
num_rows
%
self
.
partitions
[
0
]
==
0
:
if
num_rows
%
self
.
partitions
[
0
]
==
0
:
num_rows
=
num_rows
//
self
.
partitions
[
0
]
num_rows
=
num_rows
//
self
.
partitions
[
0
]
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
aa3f89e7
...
@@ -529,8 +529,11 @@ def check_minddataset(method):
...
@@ -529,8 +529,11 @@ def check_minddataset(method):
dataset_file
=
param_dict
.
get
(
'dataset_file'
)
dataset_file
=
param_dict
.
get
(
'dataset_file'
)
if
dataset_file
is
None
:
if
dataset_file
is
None
:
raise
ValueError
(
"dataset_file is not provided."
)
raise
ValueError
(
"dataset_file is not provided."
)
check_dataset_file
(
dataset_file
)
if
isinstance
(
dataset_file
,
list
):
for
f
in
dataset_file
:
check_dataset_file
(
f
)
else
:
check_dataset_file
(
dataset_file
)
check_param_type
(
nreq_param_int
,
param_dict
,
int
)
check_param_type
(
nreq_param_int
,
param_dict
,
int
)
check_param_type
(
nreq_param_list
,
param_dict
,
list
)
check_param_type
(
nreq_param_list
,
param_dict
,
list
)
...
...
mindspore/mindrecord/filereader.py
浏览文件 @
aa3f89e7
...
@@ -28,7 +28,7 @@ class FileReader:
...
@@ -28,7 +28,7 @@ class FileReader:
Class to read MindRecord File series.
Class to read MindRecord File series.
Args:
Args:
file_name (str
): File name of MindRecord File
.
file_name (str
, list[str]): One of MindRecord File or file list
.
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
It should not be smaller than 1 or larger than the number of CPU.
It should not be smaller than 1 or larger than the number of CPU.
columns (list[str], optional): List of fields which correspond data would be read (default=None).
columns (list[str], optional): List of fields which correspond data would be read (default=None).
...
@@ -38,8 +38,11 @@ class FileReader:
...
@@ -38,8 +38,11 @@ class FileReader:
ParamValueError: If file_name, num_consumer or columns is invalid.
ParamValueError: If file_name, num_consumer or columns is invalid.
"""
"""
def
__init__
(
self
,
file_name
,
num_consumer
=
4
,
columns
=
None
,
operator
=
None
):
def
__init__
(
self
,
file_name
,
num_consumer
=
4
,
columns
=
None
,
operator
=
None
):
check_filename
(
file_name
)
if
isinstance
(
file_name
,
list
):
self
.
_file_name
=
file_name
for
f
in
file_name
:
check_filename
(
f
)
else
:
check_filename
(
file_name
)
if
num_consumer
is
not
None
:
if
num_consumer
is
not
None
:
if
isinstance
(
num_consumer
,
int
):
if
isinstance
(
num_consumer
,
int
):
...
...
mindspore/mindrecord/mindpage.py
浏览文件 @
aa3f89e7
...
@@ -28,7 +28,7 @@ class MindPage:
...
@@ -28,7 +28,7 @@ class MindPage:
Class to read MindRecord File series in pagination.
Class to read MindRecord File series in pagination.
Args:
Args:
file_name (str):
File name of MindRecord File
.
file_name (str):
One of MindRecord File or file list
.
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
It should not be smaller than 1 or larger than the number of CPU.
It should not be smaller than 1 or larger than the number of CPU.
...
@@ -37,8 +37,11 @@ class MindPage:
...
@@ -37,8 +37,11 @@ class MindPage:
MRMInitSegmentError: If failed to initialize ShardSegment.
MRMInitSegmentError: If failed to initialize ShardSegment.
"""
"""
def
__init__
(
self
,
file_name
,
num_consumer
=
4
):
def
__init__
(
self
,
file_name
,
num_consumer
=
4
):
check_filename
(
file_name
)
if
isinstance
(
file_name
,
list
):
self
.
_file_name
=
file_name
for
f
in
file_name
:
check_filename
(
f
)
else
:
check_filename
(
file_name
)
if
num_consumer
is
not
None
:
if
num_consumer
is
not
None
:
if
isinstance
(
num_consumer
,
int
):
if
isinstance
(
num_consumer
,
int
):
...
...
mindspore/mindrecord/shardreader.py
浏览文件 @
aa3f89e7
...
@@ -35,7 +35,7 @@ class ShardReader:
...
@@ -35,7 +35,7 @@ class ShardReader:
Open file and prepare to read MindRecord File.
Open file and prepare to read MindRecord File.
Args:
Args:
file_name (str
): File name
of MindRecord File.
file_name (str
, list[str]): File names
of MindRecord File.
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
columns (list[str]): List of fields which correspond data would be read.
columns (list[str]): List of fields which correspond data would be read.
operator(int): Reserved parameter for operators. Default: None.
operator(int): Reserved parameter for operators. Default: None.
...
@@ -48,7 +48,12 @@ class ShardReader:
...
@@ -48,7 +48,12 @@ class ShardReader:
"""
"""
columns
=
columns
if
columns
else
[]
columns
=
columns
if
columns
else
[]
operator
=
operator
if
operator
else
[]
operator
=
operator
if
operator
else
[]
ret
=
self
.
_reader
.
open
(
file_name
,
num_consumer
,
columns
,
operator
)
if
isinstance
(
file_name
,
list
):
load_dataset
=
False
else
:
load_dataset
=
True
file_name
=
[
file_name
]
ret
=
self
.
_reader
.
open
(
file_name
,
load_dataset
,
num_consumer
,
columns
,
operator
)
if
ret
!=
ms
.
MSRStatus
.
SUCCESS
:
if
ret
!=
ms
.
MSRStatus
.
SUCCESS
:
logger
.
error
(
"Failed to open {}."
.
format
(
file_name
))
logger
.
error
(
"Failed to open {}."
.
format
(
file_name
))
raise
MRMOpenError
raise
MRMOpenError
...
...
mindspore/mindrecord/shardsegment.py
浏览文件 @
aa3f89e7
...
@@ -40,7 +40,7 @@ class ShardSegment:
...
@@ -40,7 +40,7 @@ class ShardSegment:
Initialize the ShardSegment.
Initialize the ShardSegment.
Args:
Args:
file_name (str
): File name
of MindRecord File.
file_name (str
, list[str]): File names
of MindRecord File.
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
columns (list[str]): List of fields which correspond data would be read.
columns (list[str]): List of fields which correspond data would be read.
operator(int): Reserved parameter for operators. Default: None.
operator(int): Reserved parameter for operators. Default: None.
...
@@ -53,7 +53,12 @@ class ShardSegment:
...
@@ -53,7 +53,12 @@ class ShardSegment:
"""
"""
self
.
_columns
=
columns
if
columns
else
[]
self
.
_columns
=
columns
if
columns
else
[]
operator
=
operator
if
operator
else
[]
operator
=
operator
if
operator
else
[]
ret
=
self
.
_segment
.
open
(
file_name
,
num_consumer
,
self
.
_columns
,
operator
)
if
isinstance
(
file_name
,
list
):
load_dataset
=
False
else
:
load_dataset
=
True
file_name
=
[
file_name
]
ret
=
self
.
_segment
.
open
(
file_name
,
load_dataset
,
num_consumer
,
self
.
_columns
,
operator
)
if
ret
!=
SUCCESS
:
if
ret
!=
SUCCESS
:
logger
.
error
(
"Failed to open {}."
.
format
(
file_name
))
logger
.
error
(
"Failed to open {}."
.
format
(
file_name
))
raise
MRMOpenError
raise
MRMOpenError
...
...
tests/ut/cpp/dataset/mind_record_op_test.cc
浏览文件 @
aa3f89e7
...
@@ -62,7 +62,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBasic) {
...
@@ -62,7 +62,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBasic) {
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
MindRecordOp
::
Builder
builder
;
MindRecordOp
::
Builder
builder
;
builder
.
SetDatasetFile
(
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
)
builder
.
SetDatasetFile
({
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
})
.
SetLoadDataset
(
true
)
.
SetRowsPerBuffer
(
3
)
.
SetRowsPerBuffer
(
3
)
.
SetNumMindRecordWorkers
(
4
)
.
SetNumMindRecordWorkers
(
4
)
.
SetColumnsToLoad
(
column_list
);
.
SetColumnsToLoad
(
column_list
);
...
@@ -132,7 +133,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordSample) {
...
@@ -132,7 +133,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordSample) {
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
MindRecordOp
::
Builder
builder
;
MindRecordOp
::
Builder
builder
;
builder
.
SetDatasetFile
(
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
)
builder
.
SetDatasetFile
({
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
})
.
SetLoadDataset
(
true
)
.
SetRowsPerBuffer
(
3
)
.
SetRowsPerBuffer
(
3
)
.
SetNumMindRecordWorkers
(
4
)
.
SetNumMindRecordWorkers
(
4
)
.
SetColumnsToLoad
(
column_list
)
.
SetColumnsToLoad
(
column_list
)
...
@@ -203,7 +205,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordShuffle) {
...
@@ -203,7 +205,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordShuffle) {
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
MindRecordOp
::
Builder
builder
;
MindRecordOp
::
Builder
builder
;
builder
.
SetDatasetFile
(
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
)
builder
.
SetDatasetFile
({
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
})
.
SetLoadDataset
(
true
)
.
SetRowsPerBuffer
(
3
)
.
SetRowsPerBuffer
(
3
)
.
SetNumMindRecordWorkers
(
4
)
.
SetNumMindRecordWorkers
(
4
)
.
SetColumnsToLoad
(
column_list
)
.
SetColumnsToLoad
(
column_list
)
...
@@ -277,7 +280,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordCategory) {
...
@@ -277,7 +280,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordCategory) {
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
MindRecordOp
::
Builder
builder
;
MindRecordOp
::
Builder
builder
;
builder
.
SetDatasetFile
(
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
)
builder
.
SetDatasetFile
({
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
})
.
SetLoadDataset
(
true
)
.
SetRowsPerBuffer
(
3
)
.
SetRowsPerBuffer
(
3
)
.
SetNumMindRecordWorkers
(
4
)
.
SetNumMindRecordWorkers
(
4
)
.
SetColumnsToLoad
(
column_list
)
.
SetColumnsToLoad
(
column_list
)
...
@@ -345,7 +349,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) {
...
@@ -345,7 +349,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) {
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
MindRecordOp
::
Builder
builder
;
MindRecordOp
::
Builder
builder
;
builder
.
SetDatasetFile
(
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
)
builder
.
SetDatasetFile
({
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
})
.
SetLoadDataset
(
true
)
.
SetRowsPerBuffer
(
3
)
.
SetRowsPerBuffer
(
3
)
.
SetNumMindRecordWorkers
(
4
)
.
SetNumMindRecordWorkers
(
4
)
.
SetColumnsToLoad
(
column_list
);
.
SetColumnsToLoad
(
column_list
);
...
@@ -426,7 +431,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) {
...
@@ -426,7 +431,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) {
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
MindRecordOp
::
Builder
builder
;
MindRecordOp
::
Builder
builder
;
builder
.
SetDatasetFile
(
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
)
builder
.
SetDatasetFile
({
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
})
.
SetLoadDataset
(
true
)
.
SetRowsPerBuffer
(
3
)
.
SetRowsPerBuffer
(
3
)
.
SetNumMindRecordWorkers
(
4
)
.
SetNumMindRecordWorkers
(
4
)
.
SetBlockReader
()
.
SetBlockReader
()
...
@@ -507,7 +513,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordInvalidColumnList) {
...
@@ -507,7 +513,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordInvalidColumnList) {
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
std
::
shared_ptr
<
MindRecordOp
>
my_mindrecord_op
;
MindRecordOp
::
Builder
builder
;
MindRecordOp
::
Builder
builder
;
builder
.
SetDatasetFile
(
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
)
builder
.
SetDatasetFile
({
mindrecord_root_path_
+
"/testMindDataSet/testImageNetData/imagenet.mindrecord0"
})
.
SetLoadDataset
(
true
)
.
SetRowsPerBuffer
(
3
)
.
SetRowsPerBuffer
(
3
)
.
SetNumMindRecordWorkers
(
4
)
.
SetNumMindRecordWorkers
(
4
)
.
SetColumnsToLoad
(
column_list
);
.
SetColumnsToLoad
(
column_list
);
...
...
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
浏览文件 @
aa3f89e7
...
@@ -63,7 +63,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
...
@@ -63,7 +63,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
kSampleCount
));
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
kSampleCount
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -89,7 +89,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
...
@@ -89,7 +89,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
kNum
,
kDen
));
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
kNum
,
kDen
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -115,7 +115,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
...
@@ -115,7 +115,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
kNum
,
kDen
));
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
kNum
,
kDen
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -144,7 +144,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
...
@@ -144,7 +144,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
ASSERT_TRUE
(
partitions
.
second
==
2
);
ASSERT_TRUE
(
partitions
.
second
==
2
);
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -168,7 +168,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
...
@@ -168,7 +168,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
));
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -193,7 +193,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
...
@@ -193,7 +193,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
,
3
,
0
));
ops
.
push_back
(
std
::
make_shared
<
ShardPkSample
>
(
"label"
,
2
,
3
,
0
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -223,7 +223,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
...
@@ -223,7 +223,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
ops
.
push_back
(
std
::
make_shared
<
ShardCategory
>
(
categories
));
ops
.
push_back
(
std
::
make_shared
<
ShardCategory
>
(
categories
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -254,7 +254,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
...
@@ -254,7 +254,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
1
));
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
1
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
16
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
16
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -279,7 +279,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
...
@@ -279,7 +279,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
1
));
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
1
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -306,7 +306,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
...
@@ -306,7 +306,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
kSampleSize
));
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
kSampleSize
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_name
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
true
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -333,7 +333,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
...
@@ -333,7 +333,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
35
));
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
35
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -357,11 +357,11 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
...
@@ -357,11 +357,11 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
1
));
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
1
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_name
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
true
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
ShardReader
compare_dataset
;
ShardReader
compare_dataset
;
compare_dataset
.
Open
(
file_nam
e
,
4
,
column_list
);
compare_dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
);
compare_dataset
.
Launch
();
compare_dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -396,7 +396,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
...
@@ -396,7 +396,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
21
));
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
21
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -430,7 +430,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
...
@@ -430,7 +430,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
ops
.
push_back
(
std
::
make_shared
<
ShardCategory
>
(
categories
));
ops
.
push_back
(
std
::
make_shared
<
ShardCategory
>
(
categories
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -464,7 +464,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
...
@@ -464,7 +464,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
ops
.
push_back
(
std
::
make_shared
<
ShardCategory
>
(
categories
));
ops
.
push_back
(
std
::
make_shared
<
ShardCategory
>
(
categories
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
@@ -502,7 +502,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
...
@@ -502,7 +502,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
100
));
ops
.
push_back
(
std
::
make_shared
<
ShardShuffle
>
(
100
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
int
i
=
0
;
int
i
=
0
;
...
...
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
浏览文件 @
aa3f89e7
...
@@ -55,7 +55,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
...
@@ -55,7 +55,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
};
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
);
dataset
.
Launch
();
dataset
.
Launch
();
while
(
true
)
{
while
(
true
)
{
...
@@ -78,7 +78,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
...
@@ -78,7 +78,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
std
::
vector
<
std
::
shared_ptr
<
ShardOperator
>>
ops
;
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
17
));
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
17
));
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
);
dataset
.
Launch
();
dataset
.
Launch
();
while
(
true
)
{
while
(
true
)
{
...
@@ -103,7 +103,7 @@ TEST_F(TestShardReader, TestShardReaderBlock) {
...
@@ -103,7 +103,7 @@ TEST_F(TestShardReader, TestShardReaderBlock) {
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
3
));
ops
.
push_back
(
std
::
make_shared
<
ShardSample
>
(
3
));
ShardReader
dataset
;
ShardReader
dataset
;
const
bool
kBlockReader
=
true
;
const
bool
kBlockReader
=
true
;
dataset
.
Open
(
file_nam
e
,
4
,
column_list
,
ops
,
kBlockReader
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
,
ops
,
kBlockReader
);
dataset
.
Launch
();
dataset
.
Launch
();
while
(
true
)
{
while
(
true
)
{
...
@@ -123,7 +123,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
...
@@ -123,7 +123,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
MS_LOG
(
INFO
)
<<
FormatInfo
(
"Test read imageNet"
);
MS_LOG
(
INFO
)
<<
FormatInfo
(
"Test read imageNet"
);
std
::
string
file_name
=
"./imagenet.shard01"
;
std
::
string
file_name
=
"./imagenet.shard01"
;
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_nam
e
);
dataset
.
Open
(
{
file_name
},
tru
e
);
dataset
.
Launch
();
dataset
.
Launch
();
while
(
true
)
{
while
(
true
)
{
...
@@ -143,7 +143,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
...
@@ -143,7 +143,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
std
::
string
file_name
=
"./imagenet.shard01"
;
std
::
string
file_name
=
"./imagenet.shard01"
;
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"label"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"label"
};
ShardReader
dataset
;
ShardReader
dataset
;
MSRStatus
ret
=
dataset
.
Open
(
file_name
,
4
,
column_list
);
MSRStatus
ret
=
dataset
.
Open
(
{
file_name
},
true
,
4
,
column_list
);
ASSERT_EQ
(
ret
,
SUCCESS
);
ASSERT_EQ
(
ret
,
SUCCESS
);
dataset
.
Launch
();
dataset
.
Launch
();
...
@@ -164,7 +164,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
...
@@ -164,7 +164,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
std
::
string
file_name
=
"./imagenet.shard01"
;
std
::
string
file_name
=
"./imagenet.shard01"
;
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_namex"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_namex"
};
ShardReader
dataset
;
ShardReader
dataset
;
MSRStatus
ret
=
dataset
.
Open
(
file_nam
e
,
4
,
column_list
);
MSRStatus
ret
=
dataset
.
Open
(
{
file_name
},
tru
e
,
4
,
column_list
);
ASSERT_EQ
(
ret
,
ILLEGAL_COLUMN_LIST
);
ASSERT_EQ
(
ret
,
ILLEGAL_COLUMN_LIST
);
}
}
...
@@ -172,7 +172,7 @@ TEST_F(TestShardReader, TestShardVersion) {
...
@@ -172,7 +172,7 @@ TEST_F(TestShardReader, TestShardVersion) {
MS_LOG
(
INFO
)
<<
FormatInfo
(
"Test shard version"
);
MS_LOG
(
INFO
)
<<
FormatInfo
(
"Test shard version"
);
std
::
string
file_name
=
"./imagenet.shard01"
;
std
::
string
file_name
=
"./imagenet.shard01"
;
ShardReader
dataset
;
ShardReader
dataset
;
MSRStatus
ret
=
dataset
.
Open
(
file_name
,
4
);
MSRStatus
ret
=
dataset
.
Open
(
{
file_name
},
true
,
4
);
ASSERT_EQ
(
ret
,
SUCCESS
);
ASSERT_EQ
(
ret
,
SUCCESS
);
dataset
.
Launch
();
dataset
.
Launch
();
...
@@ -195,7 +195,7 @@ TEST_F(TestShardReader, TestShardReaderDir) {
...
@@ -195,7 +195,7 @@ TEST_F(TestShardReader, TestShardReaderDir) {
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
};
ShardReader
dataset
;
ShardReader
dataset
;
MSRStatus
ret
=
dataset
.
Open
(
file_name
,
4
,
column_list
);
MSRStatus
ret
=
dataset
.
Open
(
{
file_name
},
true
,
4
,
column_list
);
ASSERT_EQ
(
ret
,
FAILED
);
ASSERT_EQ
(
ret
,
FAILED
);
}
}
...
@@ -205,7 +205,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
...
@@ -205,7 +205,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"file_name"
};
ShardReader
dataset
;
ShardReader
dataset
;
dataset
.
Open
(
file_name
,
-
481565535
,
column_list
);
dataset
.
Open
(
{
file_name
},
true
,
-
481565535
,
column_list
);
dataset
.
Launch
();
dataset
.
Launch
();
while
(
true
)
{
while
(
true
)
{
...
...
tests/ut/cpp/mindrecord/ut_shard_segment_test.cc
浏览文件 @
aa3f89e7
...
@@ -59,7 +59,7 @@ TEST_F(TestShardSegment, TestShardSegment) {
...
@@ -59,7 +59,7 @@ TEST_F(TestShardSegment, TestShardSegment) {
std
::
string
file_name
=
"./imagenet.shard01"
;
std
::
string
file_name
=
"./imagenet.shard01"
;
ShardSegment
dataset
;
ShardSegment
dataset
;
dataset
.
Open
(
file_nam
e
,
4
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
);
auto
x
=
dataset
.
GetCategoryFields
();
auto
x
=
dataset
.
GetCategoryFields
();
for
(
const
auto
&
fields
:
x
.
second
)
{
for
(
const
auto
&
fields
:
x
.
second
)
{
...
@@ -97,7 +97,7 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
...
@@ -97,7 +97,7 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
std
::
string
file_name
=
"./imagenet.shard01"
;
std
::
string
file_name
=
"./imagenet.shard01"
;
ShardSegment
dataset
;
ShardSegment
dataset
;
dataset
.
Open
(
file_nam
e
,
4
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
);
auto
x
=
dataset
.
GetCategoryFields
();
auto
x
=
dataset
.
GetCategoryFields
();
for
(
const
auto
&
fields
:
x
.
second
)
{
for
(
const
auto
&
fields
:
x
.
second
)
{
...
@@ -121,7 +121,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
...
@@ -121,7 +121,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
std
::
string
file_name
=
"./imagenet.shard01"
;
std
::
string
file_name
=
"./imagenet.shard01"
;
ShardSegment
dataset
;
ShardSegment
dataset
;
dataset
.
Open
(
file_name
,
4
);
dataset
.
Open
(
{
file_name
},
true
,
4
);
auto
x
=
dataset
.
GetCategoryFields
();
auto
x
=
dataset
.
GetCategoryFields
();
for
(
const
auto
&
fields
:
x
.
second
)
{
for
(
const
auto
&
fields
:
x
.
second
)
{
...
@@ -143,7 +143,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
...
@@ -143,7 +143,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
std
::
string
file_name
=
"./imagenet.shard01"
;
std
::
string
file_name
=
"./imagenet.shard01"
;
ShardSegment
dataset
;
ShardSegment
dataset
;
dataset
.
Open
(
file_nam
e
,
4
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
);
auto
x
=
dataset
.
GetCategoryFields
();
auto
x
=
dataset
.
GetCategoryFields
();
for
(
const
auto
&
fields
:
x
.
second
)
{
for
(
const
auto
&
fields
:
x
.
second
)
{
...
@@ -165,7 +165,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
...
@@ -165,7 +165,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
std
::
string
file_name
=
"./imagenet.shard01"
;
std
::
string
file_name
=
"./imagenet.shard01"
;
ShardSegment
dataset
;
ShardSegment
dataset
;
dataset
.
Open
(
file_nam
e
,
4
);
dataset
.
Open
(
{
file_name
},
tru
e
,
4
);
auto
x
=
dataset
.
GetCategoryFields
();
auto
x
=
dataset
.
GetCategoryFields
();
for
(
const
auto
&
fields
:
x
.
second
)
{
for
(
const
auto
&
fields
:
x
.
second
)
{
...
...
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
浏览文件 @
aa3f89e7
...
@@ -60,7 +60,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
...
@@ -60,7 +60,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
std
::
string
filename
=
"./OneSample.shard01"
;
std
::
string
filename
=
"./OneSample.shard01"
;
ShardReader
dataset
;
ShardReader
dataset
;
MSRStatus
ret
=
dataset
.
Open
(
filenam
e
,
4
);
MSRStatus
ret
=
dataset
.
Open
(
{
filename
},
tru
e
,
4
);
ASSERT_EQ
(
ret
,
SUCCESS
);
ASSERT_EQ
(
ret
,
SUCCESS
);
dataset
.
Launch
();
dataset
.
Launch
();
...
@@ -756,7 +756,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
...
@@ -756,7 +756,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
filename
=
"./imagenet.shard01"
;
filename
=
"./imagenet.shard01"
;
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"label"
,
"file_name"
,
"data"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"label"
,
"file_name"
,
"data"
};
ShardReader
dataset
;
ShardReader
dataset
;
MSRStatus
ret
=
dataset
.
Open
(
filenam
e
,
4
,
column_list
);
MSRStatus
ret
=
dataset
.
Open
(
{
filename
},
tru
e
,
4
,
column_list
);
ASSERT_EQ
(
ret
,
SUCCESS
);
ASSERT_EQ
(
ret
,
SUCCESS
);
dataset
.
Launch
();
dataset
.
Launch
();
...
@@ -842,7 +842,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
...
@@ -842,7 +842,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
filename
=
"./imagenet.shard01"
;
filename
=
"./imagenet.shard01"
;
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"label"
,
"file_name"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"label"
,
"file_name"
};
ShardReader
dataset
;
ShardReader
dataset
;
MSRStatus
ret
=
dataset
.
Open
(
filenam
e
,
4
,
column_list
);
MSRStatus
ret
=
dataset
.
Open
(
{
filename
},
tru
e
,
4
,
column_list
);
ASSERT_EQ
(
ret
,
SUCCESS
);
ASSERT_EQ
(
ret
,
SUCCESS
);
dataset
.
Launch
();
dataset
.
Launch
();
...
@@ -936,7 +936,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
...
@@ -936,7 +936,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
filename
=
"./imagenet.shard01"
;
filename
=
"./imagenet.shard01"
;
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"label"
,
"data"
};
auto
column_list
=
std
::
vector
<
std
::
string
>
{
"label"
,
"data"
};
ShardReader
dataset
;
ShardReader
dataset
;
MSRStatus
ret
=
dataset
.
Open
(
filenam
e
,
4
,
column_list
);
MSRStatus
ret
=
dataset
.
Open
(
{
filename
},
tru
e
,
4
,
column_list
);
ASSERT_EQ
(
ret
,
SUCCESS
);
ASSERT_EQ
(
ret
,
SUCCESS
);
dataset
.
Launch
();
dataset
.
Launch
();
...
@@ -1043,7 +1043,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
...
@@ -1043,7 +1043,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
filename
=
"./TenSampleFortyShard.shard01"
;
filename
=
"./TenSampleFortyShard.shard01"
;
ShardReader
dataset
;
ShardReader
dataset
;
MSRStatus
ret
=
dataset
.
Open
(
filenam
e
,
4
);
MSRStatus
ret
=
dataset
.
Open
(
{
filename
},
tru
e
,
4
);
ASSERT_EQ
(
ret
,
SUCCESS
);
ASSERT_EQ
(
ret
,
SUCCESS
);
dataset
.
Launch
();
dataset
.
Launch
();
...
...
tests/ut/python/dataset/test_minddataset.py
浏览文件 @
aa3f89e7
...
@@ -32,6 +32,8 @@ from mindspore.mindrecord import FileWriter
...
@@ -32,6 +32,8 @@ from mindspore.mindrecord import FileWriter
FILES_NUM
=
4
FILES_NUM
=
4
CV_FILE_NAME
=
"../data/mindrecord/imagenet.mindrecord"
CV_FILE_NAME
=
"../data/mindrecord/imagenet.mindrecord"
CV1_FILE_NAME
=
"../data/mindrecord/imagenet1.mindrecord"
CV2_FILE_NAME
=
"../data/mindrecord/imagenet2.mindrecord"
CV_DIR_NAME
=
"../data/mindrecord/testImageNetData"
CV_DIR_NAME
=
"../data/mindrecord/testImageNetData"
NLP_FILE_NAME
=
"../data/mindrecord/aclImdb.mindrecord"
NLP_FILE_NAME
=
"../data/mindrecord/aclImdb.mindrecord"
NLP_FILE_POS
=
"../data/mindrecord/testAclImdbData/pos"
NLP_FILE_POS
=
"../data/mindrecord/testAclImdbData/pos"
...
@@ -111,7 +113,6 @@ def test_cv_minddataset_writer_tutorial():
...
@@ -111,7 +113,6 @@ def test_cv_minddataset_writer_tutorial():
os
.
remove
(
"{}"
.
format
(
x
))
os
.
remove
(
"{}"
.
format
(
x
))
os
.
remove
(
"{}.db"
.
format
(
x
))
os
.
remove
(
"{}.db"
.
format
(
x
))
def
test_cv_minddataset_partition_tutorial
(
add_and_remove_cv_file
):
def
test_cv_minddataset_partition_tutorial
(
add_and_remove_cv_file
):
"""tutorial for cv minddataset."""
"""tutorial for cv minddataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
...
@@ -247,6 +248,126 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem
...
@@ -247,6 +248,126 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem
assert
num_iter
==
20
assert
num_iter
==
20
def
test_cv_minddataset_reader_file_list
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
data_set
=
ds
.
MindDataset
([
CV_FILE_NAME
+
str
(
x
)
for
x
in
range
(
FILES_NUM
)],
columns_list
,
num_readers
)
assert
data_set
.
get_dataset_size
()
==
10
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- len(item[data]): {} ------------------------"
.
format
(
len
(
item
[
"data"
])))
logger
.
info
(
"-------------- item[data]: {} -----------------------------"
.
format
(
item
[
"data"
]))
logger
.
info
(
"-------------- item[file_name]: {} ------------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
assert
num_iter
==
10
def
test_cv_minddataset_reader_one_partition
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
data_set
=
ds
.
MindDataset
([
CV_FILE_NAME
+
"0"
],
columns_list
,
num_readers
)
assert
data_set
.
get_dataset_size
()
<
10
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- len(item[data]): {} ------------------------"
.
format
(
len
(
item
[
"data"
])))
logger
.
info
(
"-------------- item[data]: {} -----------------------------"
.
format
(
item
[
"data"
]))
logger
.
info
(
"-------------- item[file_name]: {} ------------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
assert
num_iter
<
10
def
test_cv_minddataset_reader_two_dataset
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
if
os
.
path
.
exists
(
CV1_FILE_NAME
):
os
.
remove
(
CV1_FILE_NAME
)
if
os
.
path
.
exists
(
"{}.db"
.
format
(
CV1_FILE_NAME
)):
os
.
remove
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
if
os
.
path
.
exists
(
CV2_FILE_NAME
):
os
.
remove
(
CV2_FILE_NAME
)
if
os
.
path
.
exists
(
"{}.db"
.
format
(
CV2_FILE_NAME
)):
os
.
remove
(
"{}.db"
.
format
(
CV2_FILE_NAME
))
writer
=
FileWriter
(
CV1_FILE_NAME
,
1
)
data
=
get_data
(
CV_DIR_NAME
)
cv_schema_json
=
{
"id"
:
{
"type"
:
"int32"
},
"file_name"
:
{
"type"
:
"string"
},
"label"
:
{
"type"
:
"int32"
},
"data"
:
{
"type"
:
"bytes"
}}
writer
.
add_schema
(
cv_schema_json
,
"CV1_schema"
)
writer
.
add_index
([
"file_name"
,
"label"
])
writer
.
write_raw_data
(
data
)
writer
.
commit
()
writer
=
FileWriter
(
CV2_FILE_NAME
,
1
)
data
=
get_data
(
CV_DIR_NAME
)
cv_schema_json
=
{
"id"
:
{
"type"
:
"int32"
},
"file_name"
:
{
"type"
:
"string"
},
"label"
:
{
"type"
:
"int32"
},
"data"
:
{
"type"
:
"bytes"
}}
writer
.
add_schema
(
cv_schema_json
,
"CV2_schema"
)
writer
.
add_index
([
"file_name"
,
"label"
])
writer
.
write_raw_data
(
data
)
writer
.
commit
()
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
data_set
=
ds
.
MindDataset
([
CV_FILE_NAME
+
str
(
x
)
for
x
in
range
(
FILES_NUM
)]
+
[
CV1_FILE_NAME
,
CV2_FILE_NAME
],
columns_list
,
num_readers
)
assert
data_set
.
get_dataset_size
()
==
30
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- len(item[data]): {} ------------------------"
.
format
(
len
(
item
[
"data"
])))
logger
.
info
(
"-------------- item[data]: {} -----------------------------"
.
format
(
item
[
"data"
]))
logger
.
info
(
"-------------- item[file_name]: {} ------------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
assert
num_iter
==
30
if
os
.
path
.
exists
(
CV1_FILE_NAME
):
os
.
remove
(
CV1_FILE_NAME
)
if
os
.
path
.
exists
(
"{}.db"
.
format
(
CV1_FILE_NAME
)):
os
.
remove
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
if
os
.
path
.
exists
(
CV2_FILE_NAME
):
os
.
remove
(
CV2_FILE_NAME
)
if
os
.
path
.
exists
(
"{}.db"
.
format
(
CV2_FILE_NAME
)):
os
.
remove
(
"{}.db"
.
format
(
CV2_FILE_NAME
))
def
test_cv_minddataset_reader_two_dataset_partition
(
add_and_remove_cv_file
):
paths
=
[
"{}{}"
.
format
(
CV1_FILE_NAME
,
str
(
x
).
rjust
(
1
,
'0'
))
for
x
in
range
(
FILES_NUM
)]
for
x
in
paths
:
os
.
remove
(
"{}"
.
format
(
x
))
if
os
.
path
.
exists
(
"{}"
.
format
(
x
))
else
None
os
.
remove
(
"{}.db"
.
format
(
x
))
if
os
.
path
.
exists
(
"{}.db"
.
format
(
x
))
else
None
writer
=
FileWriter
(
CV1_FILE_NAME
,
FILES_NUM
)
data
=
get_data
(
CV_DIR_NAME
)
cv_schema_json
=
{
"id"
:
{
"type"
:
"int32"
},
"file_name"
:
{
"type"
:
"string"
},
"label"
:
{
"type"
:
"int32"
},
"data"
:
{
"type"
:
"bytes"
}}
writer
.
add_schema
(
cv_schema_json
,
"CV1_schema"
)
writer
.
add_index
([
"file_name"
,
"label"
])
writer
.
write_raw_data
(
data
)
writer
.
commit
()
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
data_set
=
ds
.
MindDataset
([
CV_FILE_NAME
+
str
(
x
)
for
x
in
range
(
2
)]
+
[
CV1_FILE_NAME
+
str
(
x
)
for
x
in
range
(
2
,
4
)],
columns_list
,
num_readers
)
assert
data_set
.
get_dataset_size
()
<
20
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- cv reader basic: {} ------------------------"
.
format
(
num_iter
))
logger
.
info
(
"-------------- len(item[data]): {} ------------------------"
.
format
(
len
(
item
[
"data"
])))
logger
.
info
(
"-------------- item[data]: {} -----------------------------"
.
format
(
item
[
"data"
]))
logger
.
info
(
"-------------- item[file_name]: {} ------------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} ----------------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
assert
num_iter
<
20
for
x
in
paths
:
os
.
remove
(
"{}"
.
format
(
x
))
os
.
remove
(
"{}.db"
.
format
(
x
))
def
test_cv_minddataset_reader_basic_tutorial
(
add_and_remove_cv_file
):
def
test_cv_minddataset_reader_basic_tutorial
(
add_and_remove_cv_file
):
"""tutorial for cv minderdataset."""
"""tutorial for cv minderdataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
...
...
tests/ut/python/dataset/test_minddataset_exception.py
浏览文件 @
aa3f89e7
...
@@ -22,6 +22,7 @@ import mindspore.dataset as ds
...
@@ -22,6 +22,7 @@ import mindspore.dataset as ds
from
mindspore.mindrecord
import
FileWriter
from
mindspore.mindrecord
import
FileWriter
CV_FILE_NAME
=
"./imagenet.mindrecord"
CV_FILE_NAME
=
"./imagenet.mindrecord"
CV1_FILE_NAME
=
"./imagenet1.mindrecord"
def
create_cv_mindrecord
(
files_num
):
def
create_cv_mindrecord
(
files_num
):
...
@@ -37,6 +38,31 @@ def create_cv_mindrecord(files_num):
...
@@ -37,6 +38,31 @@ def create_cv_mindrecord(files_num):
writer
.
commit
()
writer
.
commit
()
def
create_diff_schema_cv_mindrecord
(
files_num
):
"""tutorial for cv dataset writer."""
os
.
remove
(
CV1_FILE_NAME
)
if
os
.
path
.
exists
(
CV1_FILE_NAME
)
else
None
os
.
remove
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
if
os
.
path
.
exists
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
else
None
writer
=
FileWriter
(
CV1_FILE_NAME
,
files_num
)
cv_schema_json
=
{
"file_name_1"
:
{
"type"
:
"string"
},
"label"
:
{
"type"
:
"int32"
},
"data"
:
{
"type"
:
"bytes"
}}
data
=
[{
"file_name_1"
:
"001.jpg"
,
"label"
:
43
,
"data"
:
bytes
(
'0xffsafdafda'
,
encoding
=
'utf-8'
)}]
writer
.
add_schema
(
cv_schema_json
,
"img_schema"
)
writer
.
add_index
([
"file_name_1"
,
"label"
])
writer
.
write_raw_data
(
data
)
writer
.
commit
()
def
create_diff_page_size_cv_mindrecord
(
files_num
):
"""tutorial for cv dataset writer."""
os
.
remove
(
CV1_FILE_NAME
)
if
os
.
path
.
exists
(
CV1_FILE_NAME
)
else
None
os
.
remove
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
if
os
.
path
.
exists
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
else
None
writer
=
FileWriter
(
CV1_FILE_NAME
,
files_num
)
writer
.
set_page_size
(
1
<<
26
)
#64MB
cv_schema_json
=
{
"file_name"
:
{
"type"
:
"string"
},
"label"
:
{
"type"
:
"int32"
},
"data"
:
{
"type"
:
"bytes"
}}
data
=
[{
"file_name"
:
"001.jpg"
,
"label"
:
43
,
"data"
:
bytes
(
'0xffsafdafda'
,
encoding
=
'utf-8'
)}]
writer
.
add_schema
(
cv_schema_json
,
"img_schema"
)
writer
.
add_index
([
"file_name"
,
"label"
])
writer
.
write_raw_data
(
data
)
writer
.
commit
()
def
test_cv_lack_json
():
def
test_cv_lack_json
():
"""tutorial for cv minderdataset."""
"""tutorial for cv minderdataset."""
create_cv_mindrecord
(
1
)
create_cv_mindrecord
(
1
)
...
@@ -111,3 +137,34 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
...
@@ -111,3 +137,34 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
def
test_cv_minddataset_reader_different_schema
():
create_cv_mindrecord
(
1
)
create_diff_schema_cv_mindrecord
(
1
)
columns_list
=
[
"data"
,
"label"
]
num_readers
=
4
with
pytest
.
raises
(
Exception
,
match
=
"MindRecordOp init failed"
):
data_set
=
ds
.
MindDataset
([
CV_FILE_NAME
,
CV1_FILE_NAME
],
columns_list
,
num_readers
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
os
.
remove
(
CV1_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
def
test_cv_minddataset_reader_different_page_size
():
create_cv_mindrecord
(
1
)
create_diff_page_size_cv_mindrecord
(
1
)
columns_list
=
[
"data"
,
"label"
]
num_readers
=
4
with
pytest
.
raises
(
Exception
,
match
=
"MindRecordOp init failed"
):
data_set
=
ds
.
MindDataset
([
CV_FILE_NAME
,
CV1_FILE_NAME
],
columns_list
,
num_readers
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
os
.
remove
(
CV1_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV1_FILE_NAME
))
tests/ut/python/mindrecord/test_mindrecord_base.py
浏览文件 @
aa3f89e7
...
@@ -202,6 +202,16 @@ def test_cv_file_reader_tutorial():
...
@@ -202,6 +202,16 @@ def test_cv_file_reader_tutorial():
assert
count
==
10
assert
count
==
10
reader
.
close
()
reader
.
close
()
def
test_cv_file_reader_file_list
():
"""tutorial for cv file partial reader."""
reader
=
FileReader
([
CV_FILE_NAME
+
str
(
x
)
for
x
in
range
(
FILES_NUM
)])
count
=
0
for
index
,
x
in
enumerate
(
reader
.
get_next
()):
assert
len
(
x
)
==
3
count
=
count
+
1
logger
.
info
(
"#item{}: {}"
.
format
(
index
,
x
))
assert
count
==
10
def
test_cv_file_reader_partial_tutorial
():
def
test_cv_file_reader_partial_tutorial
():
"""tutorial for cv file partial reader."""
"""tutorial for cv file partial reader."""
reader
=
FileReader
(
CV_FILE_NAME
+
"0"
)
reader
=
FileReader
(
CV_FILE_NAME
+
"0"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录