Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1f222ddb
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1f222ddb
编写于
4月 23, 2020
作者:
L
liyong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix mindrecord c ut
上级
ebc3f12b
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
447 addition
and
445 deletion
+447
-445
mindspore/ccsrc/mindrecord/io/shard_reader.cc
mindspore/ccsrc/mindrecord/io/shard_reader.cc
+2
-1
tests/ut/cpp/mindrecord/ut_common.cc
tests/ut/cpp/mindrecord/ut_common.cc
+350
-19
tests/ut/cpp/mindrecord/ut_common.h
tests/ut/cpp/mindrecord/ut_common.h
+21
-4
tests/ut/cpp/mindrecord/ut_shard.cc
tests/ut/cpp/mindrecord/ut_shard.cc
+7
-3
tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc
tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc
+0
-32
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
+13
-2
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
+24
-14
tests/ut/cpp/mindrecord/ut_shard_segment_test.cc
tests/ut/cpp/mindrecord/ut_shard_segment_test.cc
+12
-2
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
+18
-342
tests/ut/cpp/mindrecord/ut_shard_writer_test.h
tests/ut/cpp/mindrecord/ut_shard_writer_test.h
+0
-26
未找到文件。
mindspore/ccsrc/mindrecord/io/shard_reader.cc
浏览文件 @
1f222ddb
...
...
@@ -346,7 +346,8 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string
MS_LOG
(
ERROR
)
<<
"Error in select sql statement, sql:"
<<
common
::
SafeCStr
(
sql
)
<<
", error: "
<<
errmsg
;
return
;
}
MS_LOG
(
INFO
)
<<
"Get"
<<
static_cast
<
int
>
(
columns
.
size
())
<<
" records from shard "
<<
shard_id
<<
" index."
;
MS_LOG
(
INFO
)
<<
"Get "
<<
static_cast
<
int
>
(
columns
.
size
())
<<
" records from shard "
<<
shard_id
<<
" index."
;
std
::
lock_guard
<
std
::
mutex
>
lck
(
shard_locker_
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
columns
.
size
());
++
i
)
{
categories
.
emplace
(
columns
[
i
][
0
]);
}
...
...
tests/ut/cpp/mindrecord/ut_common.cc
浏览文件 @
1f222ddb
...
...
@@ -16,9 +16,9 @@
#include "ut_common.h"
using
mindspore
::
MsLogLevel
::
ERROR
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
MsLogLevel
::
ERROR
;
namespace
mindspore
{
namespace
mindrecord
{
...
...
@@ -33,23 +33,6 @@ void Common::SetUp() {}
void
Common
::
TearDown
()
{}
void
Common
::
LoadData
(
const
std
::
string
&
directory
,
std
::
vector
<
json
>
&
json_buffer
,
const
int
max_num
)
{
int
count
=
0
;
string
input_path
=
directory
;
ifstream
infile
(
input_path
);
if
(
!
infile
.
is_open
())
{
MS_LOG
(
ERROR
)
<<
"can not open the file "
;
return
;
}
string
temp
;
while
(
getline
(
infile
,
temp
)
&&
count
!=
max_num
)
{
count
++
;
json
j
=
json
::
parse
(
temp
);
json_buffer
.
push_back
(
j
);
}
infile
.
close
();
}
#ifdef __cplusplus
#if __cplusplus
}
...
...
@@ -70,5 +53,353 @@ const std::string FormatInfo(const std::string &message, uint32_t message_total_
std
::
string
right_padding
(
static_cast
<
uint64_t
>
(
floor
(
padding_length
/
2.0
)),
'='
);
return
left_padding
+
part_message
+
right_padding
;
}
void
LoadData
(
const
std
::
string
&
directory
,
std
::
vector
<
json
>
&
json_buffer
,
const
int
max_num
)
{
int
count
=
0
;
string
input_path
=
directory
;
ifstream
infile
(
input_path
);
if
(
!
infile
.
is_open
())
{
MS_LOG
(
ERROR
)
<<
"can not open the file "
;
return
;
}
string
temp
;
while
(
getline
(
infile
,
temp
)
&&
count
!=
max_num
)
{
count
++
;
json
j
=
json
::
parse
(
temp
);
json_buffer
.
push_back
(
j
);
}
infile
.
close
();
}
void
LoadDataFromImageNet
(
const
std
::
string
&
directory
,
std
::
vector
<
json
>
&
json_buffer
,
const
int
max_num
)
{
int
count
=
0
;
string
input_path
=
directory
;
ifstream
infile
(
input_path
);
if
(
!
infile
.
is_open
())
{
MS_LOG
(
ERROR
)
<<
"can not open the file "
;
return
;
}
string
temp
;
string
filename
;
string
label
;
json
j
;
while
(
getline
(
infile
,
temp
)
&&
count
!=
max_num
)
{
count
++
;
std
::
size_t
pos
=
temp
.
find
(
","
,
0
);
if
(
pos
!=
std
::
string
::
npos
)
{
j
[
"file_name"
]
=
temp
.
substr
(
0
,
pos
);
j
[
"label"
]
=
atoi
(
common
::
SafeCStr
(
temp
.
substr
(
pos
+
1
,
temp
.
length
())));
json_buffer
.
push_back
(
j
);
}
}
infile
.
close
();
}
int
Img2DataUint8
(
const
std
::
vector
<
std
::
string
>
&
img_absolute_path
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
&
bin_data
)
{
for
(
auto
&
file
:
img_absolute_path
)
{
// read image file
std
::
ifstream
in
(
common
::
SafeCStr
(
file
),
std
::
ios
::
in
|
std
::
ios
::
binary
|
std
::
ios
::
ate
);
if
(
!
in
)
{
MS_LOG
(
ERROR
)
<<
common
::
SafeCStr
(
file
)
<<
" is not a directory or not exist!"
;
return
-
1
;
}
// get the file size
uint64_t
size
=
in
.
tellg
();
in
.
seekg
(
0
,
std
::
ios
::
beg
);
std
::
vector
<
uint8_t
>
file_data
(
size
);
in
.
read
(
reinterpret_cast
<
char
*>
(
&
file_data
[
0
]),
size
);
in
.
close
();
bin_data
.
push_back
(
file_data
);
}
return
0
;
}
int
GetAbsoluteFiles
(
std
::
string
directory
,
std
::
vector
<
std
::
string
>
&
files_absolute_path
)
{
DIR
*
dir
=
opendir
(
common
::
SafeCStr
(
directory
));
if
(
dir
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
common
::
SafeCStr
(
directory
)
<<
" is not a directory or not exist!"
;
return
-
1
;
}
struct
dirent
*
d_ent
=
nullptr
;
char
dot
[
3
]
=
"."
;
char
dotdot
[
6
]
=
".."
;
while
((
d_ent
=
readdir
(
dir
))
!=
nullptr
)
{
if
((
strcmp
(
d_ent
->
d_name
,
dot
)
!=
0
)
&&
(
strcmp
(
d_ent
->
d_name
,
dotdot
)
!=
0
))
{
if
(
d_ent
->
d_type
==
DT_DIR
)
{
std
::
string
new_directory
=
directory
+
std
::
string
(
"/"
)
+
std
::
string
(
d_ent
->
d_name
);
if
(
directory
[
directory
.
length
()
-
1
]
==
'/'
)
{
new_directory
=
directory
+
string
(
d_ent
->
d_name
);
}
if
(
-
1
==
GetAbsoluteFiles
(
new_directory
,
files_absolute_path
))
{
closedir
(
dir
);
return
-
1
;
}
}
else
{
std
::
string
absolute_path
=
directory
+
std
::
string
(
"/"
)
+
std
::
string
(
d_ent
->
d_name
);
if
(
directory
[
directory
.
length
()
-
1
]
==
'/'
)
{
absolute_path
=
directory
+
std
::
string
(
d_ent
->
d_name
);
}
files_absolute_path
.
push_back
(
absolute_path
);
}
}
}
closedir
(
dir
);
return
0
;
}
void
ShardWriterImageNet
()
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Write imageNet"
));
// load binary data
std
::
vector
<
std
::
vector
<
uint8_t
>>
bin_data
;
std
::
vector
<
std
::
string
>
filenames
;
if
(
-
1
==
mindrecord
::
GetAbsoluteFiles
(
"./data/mindrecord/testImageNetData/images"
,
filenames
))
{
MS_LOG
(
INFO
)
<<
"-- ATTN -- Missed data directory. Skip this case. -----------------"
;
return
;
}
mindrecord
::
Img2DataUint8
(
filenames
,
bin_data
);
// init shardHeader
ShardHeader
header_data
;
MS_LOG
(
INFO
)
<<
"Init ShardHeader Already."
;
// create schema
json
anno_schema_json
=
R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"
_json
;
std
::
shared_ptr
<
mindrecord
::
Schema
>
anno_schema
=
mindrecord
::
Schema
::
Build
(
"annotation"
,
anno_schema_json
);
if
(
anno_schema
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Build annotation schema failed"
;
return
;
}
// add schema to shardHeader
int
anno_schema_id
=
header_data
.
AddSchema
(
anno_schema
);
MS_LOG
(
INFO
)
<<
"Init Schema Already."
;
// create index
std
::
pair
<
uint64_t
,
std
::
string
>
index_field1
(
anno_schema_id
,
"file_name"
);
std
::
pair
<
uint64_t
,
std
::
string
>
index_field2
(
anno_schema_id
,
"label"
);
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
fields
;
fields
.
push_back
(
index_field1
);
fields
.
push_back
(
index_field2
);
// add index to shardHeader
header_data
.
AddIndexFields
(
fields
);
MS_LOG
(
INFO
)
<<
"Init Index Fields Already."
;
// load meta data
std
::
vector
<
json
>
annotations
;
LoadDataFromImageNet
(
"./data/mindrecord/testImageNetData/annotation.txt"
,
annotations
,
10
);
// add data
std
::
map
<
std
::
uint64_t
,
std
::
vector
<
json
>>
rawdatas
;
rawdatas
.
insert
(
pair
<
uint64_t
,
vector
<
json
>>
(
anno_schema_id
,
annotations
));
MS_LOG
(
INFO
)
<<
"Init Images Already."
;
// init file_writer
std
::
vector
<
std
::
string
>
file_names
;
int
file_count
=
4
;
for
(
int
i
=
1
;
i
<=
file_count
;
i
++
)
{
file_names
.
emplace_back
(
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
));
MS_LOG
(
INFO
)
<<
"shard name is: "
<<
common
::
SafeCStr
(
file_names
[
i
-
1
]);
}
MS_LOG
(
INFO
)
<<
"Init Output Files Already."
;
{
ShardWriter
fw_init
;
fw_init
.
Open
(
file_names
);
// set shardHeader
fw_init
.
SetShardHeader
(
std
::
make_shared
<
mindrecord
::
ShardHeader
>
(
header_data
));
// close file_writer
fw_init
.
Commit
();
}
std
::
string
filename
=
"./imagenet.shard01"
;
{
MS_LOG
(
INFO
)
<<
"=============== images "
<<
bin_data
.
size
()
<<
" ============================"
;
mindrecord
::
ShardWriter
fw
;
fw
.
OpenForAppend
(
filename
);
fw
.
WriteRawData
(
rawdatas
,
bin_data
);
fw
.
Commit
();
}
mindrecord
::
ShardIndexGenerator
sg
{
filename
};
sg
.
Build
();
sg
.
WriteToDatabase
();
MS_LOG
(
INFO
)
<<
"Done create index"
;
}
void
ShardWriterImageNetOneSample
()
{
// load binary data
std
::
vector
<
std
::
vector
<
uint8_t
>>
bin_data
;
std
::
vector
<
std
::
string
>
filenames
;
if
(
-
1
==
mindrecord
::
GetAbsoluteFiles
(
"./data/mindrecord/testImageNetData/images"
,
filenames
))
{
MS_LOG
(
INFO
)
<<
"-- ATTN -- Missed data directory. Skip this case. -----------------"
;
return
;
}
mindrecord
::
Img2DataUint8
(
filenames
,
bin_data
);
// init shardHeader
mindrecord
::
ShardHeader
header_data
;
MS_LOG
(
INFO
)
<<
"Init ShardHeader Already."
;
// create schema
json
anno_schema_json
=
R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"
_json
;
std
::
shared_ptr
<
mindrecord
::
Schema
>
anno_schema
=
mindrecord
::
Schema
::
Build
(
"annotation"
,
anno_schema_json
);
if
(
anno_schema
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Build annotation schema failed"
;
return
;
}
// add schema to shardHeader
int
anno_schema_id
=
header_data
.
AddSchema
(
anno_schema
);
MS_LOG
(
INFO
)
<<
"Init Schema Already."
;
// create index
std
::
pair
<
uint64_t
,
std
::
string
>
index_field1
(
anno_schema_id
,
"file_name"
);
std
::
pair
<
uint64_t
,
std
::
string
>
index_field2
(
anno_schema_id
,
"label"
);
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
fields
;
fields
.
push_back
(
index_field1
);
fields
.
push_back
(
index_field2
);
// add index to shardHeader
header_data
.
AddIndexFields
(
fields
);
MS_LOG
(
INFO
)
<<
"Init Index Fields Already."
;
// load meta data
std
::
vector
<
json
>
annotations
;
LoadDataFromImageNet
(
"./data/mindrecord/testImageNetData/annotation.txt"
,
annotations
,
1
);
// add data
std
::
map
<
std
::
uint64_t
,
std
::
vector
<
json
>>
rawdatas
;
rawdatas
.
insert
(
pair
<
uint64_t
,
vector
<
json
>>
(
anno_schema_id
,
annotations
));
MS_LOG
(
INFO
)
<<
"Init Images Already."
;
// init file_writer
std
::
vector
<
std
::
string
>
file_names
;
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
file_names
.
emplace_back
(
std
::
string
(
"./OneSample.shard0"
)
+
std
::
to_string
(
i
));
MS_LOG
(
INFO
)
<<
"shard name is: "
<<
common
::
SafeCStr
(
file_names
[
i
-
1
]);
}
MS_LOG
(
INFO
)
<<
"Init Output Files Already."
;
{
mindrecord
::
ShardWriter
fw_init
;
fw_init
.
Open
(
file_names
);
// set shardHeader
fw_init
.
SetShardHeader
(
std
::
make_shared
<
mindrecord
::
ShardHeader
>
(
header_data
));
// close file_writer
fw_init
.
Commit
();
}
std
::
string
filename
=
"./OneSample.shard01"
;
{
MS_LOG
(
INFO
)
<<
"=============== images "
<<
bin_data
.
size
()
<<
" ============================"
;
mindrecord
::
ShardWriter
fw
;
fw
.
OpenForAppend
(
filename
);
bin_data
=
std
::
vector
<
std
::
vector
<
uint8_t
>>
(
bin_data
.
begin
(),
bin_data
.
begin
()
+
1
);
fw
.
WriteRawData
(
rawdatas
,
bin_data
);
fw
.
Commit
();
}
mindrecord
::
ShardIndexGenerator
sg
{
filename
};
sg
.
Build
();
sg
.
WriteToDatabase
();
MS_LOG
(
INFO
)
<<
"Done create index"
;
}
void
ShardWriterImageNetOpenForAppend
(
string
filename
)
{
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
string
filename
=
std
::
string
(
"./OpenForAppendSample.shard0"
)
+
std
::
to_string
(
i
);
string
db_name
=
std
::
string
(
"./OpenForAppendSample.shard0"
)
+
std
::
to_string
(
i
)
+
".db"
;
remove
(
common
::
SafeCStr
(
filename
));
remove
(
common
::
SafeCStr
(
db_name
));
}
// load binary data
std
::
vector
<
std
::
vector
<
uint8_t
>>
bin_data
;
std
::
vector
<
std
::
string
>
filenames
;
if
(
-
1
==
mindrecord
::
GetAbsoluteFiles
(
"./data/mindrecord/testImageNetData/images"
,
filenames
))
{
MS_LOG
(
INFO
)
<<
"-- ATTN -- Missed data directory. Skip this case. -----------------"
;
return
;
}
mindrecord
::
Img2DataUint8
(
filenames
,
bin_data
);
// init shardHeader
mindrecord
::
ShardHeader
header_data
;
MS_LOG
(
INFO
)
<<
"Init ShardHeader Already."
;
// create schema
json
anno_schema_json
=
R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"
_json
;
std
::
shared_ptr
<
mindrecord
::
Schema
>
anno_schema
=
mindrecord
::
Schema
::
Build
(
"annotation"
,
anno_schema_json
);
if
(
anno_schema
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Build annotation schema failed"
;
return
;
}
// add schema to shardHeader
int
anno_schema_id
=
header_data
.
AddSchema
(
anno_schema
);
MS_LOG
(
INFO
)
<<
"Init Schema Already."
;
// create index
std
::
pair
<
uint64_t
,
std
::
string
>
index_field1
(
anno_schema_id
,
"file_name"
);
std
::
pair
<
uint64_t
,
std
::
string
>
index_field2
(
anno_schema_id
,
"label"
);
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
fields
;
fields
.
push_back
(
index_field1
);
fields
.
push_back
(
index_field2
);
// add index to shardHeader
header_data
.
AddIndexFields
(
fields
);
MS_LOG
(
INFO
)
<<
"Init Index Fields Already."
;
// load meta data
std
::
vector
<
json
>
annotations
;
LoadDataFromImageNet
(
"./data/mindrecord/testImageNetData/annotation.txt"
,
annotations
,
1
);
// add data
std
::
map
<
std
::
uint64_t
,
std
::
vector
<
json
>>
rawdatas
;
rawdatas
.
insert
(
pair
<
uint64_t
,
vector
<
json
>>
(
anno_schema_id
,
annotations
));
MS_LOG
(
INFO
)
<<
"Init Images Already."
;
// init file_writer
std
::
vector
<
std
::
string
>
file_names
;
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
file_names
.
emplace_back
(
std
::
string
(
"./OpenForAppendSample.shard0"
)
+
std
::
to_string
(
i
));
MS_LOG
(
INFO
)
<<
"shard name is: "
<<
common
::
SafeCStr
(
file_names
[
i
-
1
]);
}
MS_LOG
(
INFO
)
<<
"Init Output Files Already."
;
{
mindrecord
::
ShardWriter
fw_init
;
fw_init
.
Open
(
file_names
);
// set shardHeader
fw_init
.
SetShardHeader
(
std
::
make_shared
<
mindrecord
::
ShardHeader
>
(
header_data
));
// close file_writer
fw_init
.
Commit
();
}
{
MS_LOG
(
INFO
)
<<
"=============== images "
<<
bin_data
.
size
()
<<
" ============================"
;
mindrecord
::
ShardWriter
fw
;
auto
ret
=
fw
.
OpenForAppend
(
filename
);
if
(
ret
==
FAILED
)
{
return
;
}
bin_data
=
std
::
vector
<
std
::
vector
<
uint8_t
>>
(
bin_data
.
begin
(),
bin_data
.
begin
()
+
1
);
fw
.
WriteRawData
(
rawdatas
,
bin_data
);
fw
.
Commit
();
}
ShardIndexGenerator
sg
{
filename
};
sg
.
Build
();
sg
.
WriteToDatabase
();
MS_LOG
(
INFO
)
<<
"Done create index"
;
}
}
// namespace mindrecord
}
// namespace mindspore
tests/ut/cpp/mindrecord/ut_common.h
浏览文件 @
1f222ddb
...
...
@@ -17,6 +17,7 @@
#ifndef TESTS_MINDRECORD_UT_UT_COMMON_H_
#define TESTS_MINDRECORD_UT_UT_COMMON_H_
#include <dirent.h>
#include <fstream>
#include <string>
#include <vector>
...
...
@@ -25,7 +26,9 @@
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "mindrecord/include/shard_index.h"
#include "mindrecord/include/shard_header.h"
#include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_writer.h"
using
json
=
nlohmann
::
json
;
using
std
::
ifstream
;
using
std
::
pair
;
...
...
@@ -40,11 +43,10 @@ class Common : public testing::Test {
std
::
string
install_root
;
// every TEST_F macro will enter one
void
SetUp
();
v
irtual
v
oid
SetUp
();
void
TearDown
();
v
irtual
v
oid
TearDown
();
static
void
LoadData
(
const
std
::
string
&
directory
,
std
::
vector
<
json
>
&
json_buffer
,
const
int
max_num
);
};
}
// namespace UT
...
...
@@ -55,6 +57,21 @@ class Common : public testing::Test {
///
/// return the formatted string
const
std
::
string
FormatInfo
(
const
std
::
string
&
message
,
uint32_t
message_total_length
=
128
);
void
LoadData
(
const
std
::
string
&
directory
,
std
::
vector
<
json
>
&
json_buffer
,
const
int
max_num
);
void
LoadDataFromImageNet
(
const
std
::
string
&
directory
,
std
::
vector
<
json
>
&
json_buffer
,
const
int
max_num
);
int
Img2DataUint8
(
const
std
::
vector
<
std
::
string
>
&
img_absolute_path
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
&
bin_data
);
int
GetAbsoluteFiles
(
std
::
string
directory
,
std
::
vector
<
std
::
string
>
&
files_absolute_path
);
void
ShardWriterImageNet
();
void
ShardWriterImageNetOneSample
();
void
ShardWriterImageNetOpenForAppend
(
string
filename
);
}
// namespace mindrecord
}
// namespace mindspore
#endif // TESTS_MINDRECORD_UT_UT_COMMON_H_
tests/ut/cpp/mindrecord/ut_shard.cc
浏览文件 @
1f222ddb
...
...
@@ -29,7 +29,6 @@
#include "mindrecord/include/shard_statistics.h"
#include "securec.h"
#include "ut_common.h"
#include "ut_shard_writer_test.h"
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
...
...
@@ -43,7 +42,7 @@ class TestShard : public UT::Common {
};
TEST_F
(
TestShard
,
TestShardSchemaPart
)
{
Test
ShardWriterImageNet
();
ShardWriterImageNet
();
MS_LOG
(
INFO
)
<<
FormatInfo
(
"Test schema"
);
...
...
@@ -55,6 +54,12 @@ TEST_F(TestShard, TestShardSchemaPart) {
ASSERT_TRUE
(
schema
!=
nullptr
);
MS_LOG
(
INFO
)
<<
"schema description: "
<<
schema
->
get_desc
()
<<
", schema: "
<<
common
::
SafeCStr
(
schema
->
GetSchema
().
dump
());
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
string
filename
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
);
string
db_name
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
)
+
".db"
;
remove
(
common
::
SafeCStr
(
filename
));
remove
(
common
::
SafeCStr
(
db_name
));
}
}
TEST_F
(
TestShard
,
TestStatisticPart
)
{
...
...
@@ -128,6 +133,5 @@ TEST_F(TestShard, TestShardHeaderPart) {
ASSERT_EQ
(
resFields
,
fields
);
}
TEST_F
(
TestShard
,
TestShardWriteImage
)
{
MS_LOG
(
INFO
)
<<
FormatInfo
(
"Test writer"
);
}
}
// namespace mindrecord
}
// namespace mindspore
tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc
浏览文件 @
1f222ddb
...
...
@@ -53,38 +53,6 @@ class TestShardIndexGenerator : public UT::Common {
TestShardIndexGenerator
()
{}
};
/*
TEST_F(TestShardIndexGenerator, GetField) {
MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field");
int max_num = 1;
string input_path1 = install_root + "/test/testCBGData/data/annotation.data";
std::vector<json> json_buffer1; // store the image_raw_meta.data
Common::LoadData(input_path1, json_buffer1, max_num);
MS_LOG(INFO) << "Fetch fields: ";
for (auto &j : json_buffer1) {
auto v_name = ShardIndexGenerator::GetField("anno_tool", j);
auto v_attr_name = ShardIndexGenerator::GetField("entity_instances.attributes.attr_name", j);
auto v_entity_name = ShardIndexGenerator::GetField("entity_instances.entity_name", j);
vector<string> names = {"\"CVAT\""};
for (unsigned int i = 0; i != names.size(); i++) {
ASSERT_EQ(names[i], v_name[i]);
}
vector<string> attr_names = {"\"脸部评分\"", "\"特征点\"", "\"points_example\"", "\"polyline_example\"",
"\"polyline_example\""};
for (unsigned int i = 0; i != attr_names.size(); i++) {
ASSERT_EQ(attr_names[i], v_attr_name[i]);
}
vector<string> entity_names = {"\"276点人脸\"", "\"points_example\"", "\"polyline_example\"",
"\"polyline_example\""};
for (unsigned int i = 0; i != entity_names.size(); i++) {
ASSERT_EQ(entity_names[i], v_entity_name[i]);
}
}
}
*/
TEST_F
(
TestShardIndexGenerator
,
TakeFieldType
)
{
MS_LOG
(
INFO
)
<<
FormatInfo
(
"Test ShardSchema: take field Type"
);
...
...
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
浏览文件 @
1f222ddb
...
...
@@ -40,6 +40,17 @@ namespace mindrecord {
class
TestShardOperator
:
public
UT
::
Common
{
public:
TestShardOperator
()
{}
void
SetUp
()
override
{
ShardWriterImageNet
();
}
void
TearDown
()
override
{
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
string
filename
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
);
string
db_name
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
)
+
".db"
;
remove
(
common
::
SafeCStr
(
filename
));
remove
(
common
::
SafeCStr
(
db_name
));
}
}
};
TEST_F
(
TestShardOperator
,
TestShardSampleBasic
)
{
...
...
@@ -165,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
std
::
cout
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
())
<<
std
::
endl
;
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
())
<<
std
::
endl
;
i
++
;
}
dataset
.
Finish
();
...
...
@@ -191,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
if
(
x
.
empty
())
break
;
std
::
cout
<<
"index: "
<<
i
<<
", filename: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"file_name"
])
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
())
<<
std
::
endl
;
<<
", label: "
<<
common
::
SafeCStr
((
std
::
get
<
1
>
(
x
[
0
]))[
"label"
].
dump
())
<<
std
::
endl
;
i
++
;
}
dataset
.
Finish
();
...
...
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
浏览文件 @
1f222ddb
...
...
@@ -37,6 +37,16 @@ namespace mindrecord {
class
TestShardReader
:
public
UT
::
Common
{
public:
TestShardReader
()
{}
void
SetUp
()
override
{
ShardWriterImageNet
();
}
void
TearDown
()
override
{
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
string
filename
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
);
string
db_name
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
)
+
".db"
;
remove
(
common
::
SafeCStr
(
filename
));
remove
(
common
::
SafeCStr
(
db_name
));
}
}
};
TEST_F
(
TestShardReader
,
TestShardReaderGeneral
)
{
...
...
@@ -51,8 +61,8 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
MS_LOG
(
INFO
)
<<
"key: "
<<
item
.
key
()
<<
", value: "
<<
item
.
value
().
dump
();
}
}
...
...
@@ -74,8 +84,8 @@ TEST_F(TestShardReader, TestShardReaderSample) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
MS_LOG
(
INFO
)
<<
"key: "
<<
item
.
key
()
<<
", value: "
<<
item
.
value
().
dump
();
}
}
...
...
@@ -99,8 +109,8 @@ TEST_F(TestShardReader, TestShardReaderBlock) {
while
(
true
)
{
auto
x
=
dataset
.
GetBlockNext
();
if
(
x
.
empty
())
break
;
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
MS_LOG
(
INFO
)
<<
"key: "
<<
item
.
key
()
<<
", value: "
<<
item
.
value
().
dump
();
}
}
...
...
@@ -119,8 +129,8 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
MS_LOG
(
INFO
)
<<
"key: "
<<
item
.
key
()
<<
", value: "
<<
item
.
value
().
dump
();
}
}
...
...
@@ -140,8 +150,8 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
MS_LOG
(
INFO
)
<<
"key: "
<<
item
.
key
()
<<
", value: "
<<
item
.
value
().
dump
();
}
}
...
...
@@ -169,9 +179,9 @@ TEST_F(TestShardReader, TestShardVersion) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
for
(
auto
&
j
:
x
)
{
for
(
auto
&
j
:
x
)
{
MS_LOG
(
INFO
)
<<
"result size: "
<<
std
::
get
<
0
>
(
j
).
size
();
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
MS_LOG
(
INFO
)
<<
"key: "
<<
common
::
SafeCStr
(
item
.
key
())
<<
", value: "
<<
common
::
SafeCStr
(
item
.
value
().
dump
());
}
}
...
...
@@ -201,8 +211,8 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
while
(
true
)
{
auto
x
=
dataset
.
GetNext
();
if
(
x
.
empty
())
break
;
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
for
(
auto
&
j
:
x
)
{
for
(
auto
&
item
:
std
::
get
<
1
>
(
j
).
items
())
{
MS_LOG
(
INFO
)
<<
"key: "
<<
common
::
SafeCStr
(
item
.
key
())
<<
", value: "
<<
common
::
SafeCStr
(
item
.
value
().
dump
());
}
}
...
...
tests/ut/cpp/mindrecord/ut_shard_segment_test.cc
浏览文件 @
1f222ddb
...
...
@@ -33,15 +33,25 @@
#include "mindrecord/include/shard_segment.h"
#include "ut_common.h"
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
MsLogLevel
::
INFO
;
namespace
mindspore
{
namespace
mindrecord
{
class
TestShardSegment
:
public
UT
::
Common
{
public:
TestShardSegment
()
{}
void
SetUp
()
override
{
ShardWriterImageNet
();
}
void
TearDown
()
override
{
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
string
filename
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
);
string
db_name
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
)
+
".db"
;
remove
(
common
::
SafeCStr
(
filename
));
remove
(
common
::
SafeCStr
(
db_name
));
}
}
};
TEST_F
(
TestShardSegment
,
TestShardSegment
)
{
...
...
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
浏览文件 @
1f222ddb
...
...
@@ -16,7 +16,6 @@
#include <chrono>
#include <cstring>
#include <dirent.h>
#include <iostream>
#include <memory>
#include <string>
...
...
@@ -30,7 +29,6 @@
#include "mindrecord/include/shard_index_generator.h"
#include "securec.h"
#include "ut_common.h"
#include "ut_shard_writer_test.h"
using
mindspore
::
LogStream
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
...
...
@@ -44,249 +42,10 @@ class TestShardWriter : public UT::Common {
TestShardWriter
()
{}
};
void
LoadDataFromImageNet
(
const
std
::
string
&
directory
,
std
::
vector
<
json
>
&
json_buffer
,
const
int
max_num
)
{
int
count
=
0
;
string
input_path
=
directory
;
ifstream
infile
(
input_path
);
if
(
!
infile
.
is_open
())
{
MS_LOG
(
ERROR
)
<<
"can not open the file "
;
return
;
}
string
temp
;
string
filename
;
string
label
;
json
j
;
while
(
getline
(
infile
,
temp
)
&&
count
!=
max_num
)
{
count
++
;
std
::
size_t
pos
=
temp
.
find
(
","
,
0
);
if
(
pos
!=
std
::
string
::
npos
)
{
j
[
"file_name"
]
=
temp
.
substr
(
0
,
pos
);
j
[
"label"
]
=
atoi
(
common
::
SafeCStr
(
temp
.
substr
(
pos
+
1
,
temp
.
length
())));
json_buffer
.
push_back
(
j
);
}
}
infile
.
close
();
}
int
Img2DataUint8
(
const
std
::
vector
<
std
::
string
>
&
img_absolute_path
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
&
bin_data
)
{
for
(
auto
&
file
:
img_absolute_path
)
{
// read image file
std
::
ifstream
in
(
common
::
SafeCStr
(
file
),
std
::
ios
::
in
|
std
::
ios
::
binary
|
std
::
ios
::
ate
);
if
(
!
in
)
{
MS_LOG
(
ERROR
)
<<
common
::
SafeCStr
(
file
)
<<
" is not a directory or not exist!"
;
return
-
1
;
}
// get the file size
uint64_t
size
=
in
.
tellg
();
in
.
seekg
(
0
,
std
::
ios
::
beg
);
std
::
vector
<
uint8_t
>
file_data
(
size
);
in
.
read
(
reinterpret_cast
<
char
*>
(
&
file_data
[
0
]),
size
);
in
.
close
();
bin_data
.
push_back
(
file_data
);
}
return
0
;
}
int
GetAbsoluteFiles
(
std
::
string
directory
,
std
::
vector
<
std
::
string
>
&
files_absolute_path
)
{
DIR
*
dir
=
opendir
(
common
::
SafeCStr
(
directory
));
if
(
dir
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
common
::
SafeCStr
(
directory
)
<<
" is not a directory or not exist!"
;
return
-
1
;
}
struct
dirent
*
d_ent
=
nullptr
;
char
dot
[
3
]
=
"."
;
char
dotdot
[
6
]
=
".."
;
while
((
d_ent
=
readdir
(
dir
))
!=
nullptr
)
{
if
((
strcmp
(
d_ent
->
d_name
,
dot
)
!=
0
)
&&
(
strcmp
(
d_ent
->
d_name
,
dotdot
)
!=
0
))
{
if
(
d_ent
->
d_type
==
DT_DIR
)
{
std
::
string
new_directory
=
directory
+
std
::
string
(
"/"
)
+
std
::
string
(
d_ent
->
d_name
);
if
(
directory
[
directory
.
length
()
-
1
]
==
'/'
)
{
new_directory
=
directory
+
string
(
d_ent
->
d_name
);
}
if
(
-
1
==
GetAbsoluteFiles
(
new_directory
,
files_absolute_path
))
{
closedir
(
dir
);
return
-
1
;
}
}
else
{
std
::
string
absolute_path
=
directory
+
std
::
string
(
"/"
)
+
std
::
string
(
d_ent
->
d_name
);
if
(
directory
[
directory
.
length
()
-
1
]
==
'/'
)
{
absolute_path
=
directory
+
std
::
string
(
d_ent
->
d_name
);
}
files_absolute_path
.
push_back
(
absolute_path
);
}
}
}
closedir
(
dir
);
return
0
;
}
void
TestShardWriterImageNet
()
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Write imageNet"
));
// load binary data
std
::
vector
<
std
::
vector
<
uint8_t
>>
bin_data
;
std
::
vector
<
std
::
string
>
filenames
;
if
(
-
1
==
mindrecord
::
GetAbsoluteFiles
(
"./data/mindrecord/testImageNetData/images"
,
filenames
))
{
MS_LOG
(
INFO
)
<<
"-- ATTN -- Missed data directory. Skip this case. -----------------"
;
return
;
}
mindrecord
::
Img2DataUint8
(
filenames
,
bin_data
);
// init shardHeader
mindrecord
::
ShardHeader
header_data
;
MS_LOG
(
INFO
)
<<
"Init ShardHeader Already."
;
// create schema
json
anno_schema_json
=
R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"
_json
;
std
::
shared_ptr
<
mindrecord
::
Schema
>
anno_schema
=
mindrecord
::
Schema
::
Build
(
"annotation"
,
anno_schema_json
);
if
(
anno_schema
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Build annotation schema failed"
;
return
;
}
// add schema to shardHeader
int
anno_schema_id
=
header_data
.
AddSchema
(
anno_schema
);
MS_LOG
(
INFO
)
<<
"Init Schema Already."
;
// create index
std
::
pair
<
uint64_t
,
std
::
string
>
index_field1
(
anno_schema_id
,
"file_name"
);
std
::
pair
<
uint64_t
,
std
::
string
>
index_field2
(
anno_schema_id
,
"label"
);
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
fields
;
fields
.
push_back
(
index_field1
);
fields
.
push_back
(
index_field2
);
// add index to shardHeader
header_data
.
AddIndexFields
(
fields
);
MS_LOG
(
INFO
)
<<
"Init Index Fields Already."
;
// load meta data
std
::
vector
<
json
>
annotations
;
LoadDataFromImageNet
(
"./data/mindrecord/testImageNetData/annotation.txt"
,
annotations
,
10
);
// add data
std
::
map
<
std
::
uint64_t
,
std
::
vector
<
json
>>
rawdatas
;
rawdatas
.
insert
(
pair
<
uint64_t
,
vector
<
json
>>
(
anno_schema_id
,
annotations
));
MS_LOG
(
INFO
)
<<
"Init Images Already."
;
// init file_writer
std
::
vector
<
std
::
string
>
file_names
;
int
file_count
=
4
;
for
(
int
i
=
1
;
i
<=
file_count
;
i
++
)
{
file_names
.
emplace_back
(
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
));
MS_LOG
(
INFO
)
<<
"shard name is: "
<<
common
::
SafeCStr
(
file_names
[
i
-
1
]);
}
MS_LOG
(
INFO
)
<<
"Init Output Files Already."
;
{
mindrecord
::
ShardWriter
fw_init
;
fw_init
.
Open
(
file_names
);
// set shardHeader
fw_init
.
SetShardHeader
(
std
::
make_shared
<
mindrecord
::
ShardHeader
>
(
header_data
));
// close file_writer
fw_init
.
Commit
();
}
std
::
string
filename
=
"./imagenet.shard01"
;
{
MS_LOG
(
INFO
)
<<
"=============== images "
<<
bin_data
.
size
()
<<
" ============================"
;
mindrecord
::
ShardWriter
fw
;
fw
.
OpenForAppend
(
filename
);
fw
.
WriteRawData
(
rawdatas
,
bin_data
);
fw
.
Commit
();
}
mindrecord
::
ShardIndexGenerator
sg
{
filename
};
sg
.
Build
();
sg
.
WriteToDatabase
();
MS_LOG
(
INFO
)
<<
"Done create index"
;
}
void
TestShardWriterImageNetOneSample
()
{
// load binary data
std
::
vector
<
std
::
vector
<
uint8_t
>>
bin_data
;
std
::
vector
<
std
::
string
>
filenames
;
if
(
-
1
==
mindrecord
::
GetAbsoluteFiles
(
"./data/mindrecord/testImageNetData/images"
,
filenames
))
{
MS_LOG
(
INFO
)
<<
"-- ATTN -- Missed data directory. Skip this case. -----------------"
;
return
;
}
mindrecord
::
Img2DataUint8
(
filenames
,
bin_data
);
// init shardHeader
mindrecord
::
ShardHeader
header_data
;
MS_LOG
(
INFO
)
<<
"Init ShardHeader Already."
;
// create schema
json
anno_schema_json
=
R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"
_json
;
std
::
shared_ptr
<
mindrecord
::
Schema
>
anno_schema
=
mindrecord
::
Schema
::
Build
(
"annotation"
,
anno_schema_json
);
if
(
anno_schema
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Build annotation schema failed"
;
return
;
}
// add schema to shardHeader
int
anno_schema_id
=
header_data
.
AddSchema
(
anno_schema
);
MS_LOG
(
INFO
)
<<
"Init Schema Already."
;
// create index
std
::
pair
<
uint64_t
,
std
::
string
>
index_field1
(
anno_schema_id
,
"file_name"
);
std
::
pair
<
uint64_t
,
std
::
string
>
index_field2
(
anno_schema_id
,
"label"
);
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
fields
;
fields
.
push_back
(
index_field1
);
fields
.
push_back
(
index_field2
);
// add index to shardHeader
header_data
.
AddIndexFields
(
fields
);
MS_LOG
(
INFO
)
<<
"Init Index Fields Already."
;
// load meta data
std
::
vector
<
json
>
annotations
;
LoadDataFromImageNet
(
"./data/mindrecord/testImageNetData/annotation.txt"
,
annotations
,
1
);
// add data
std
::
map
<
std
::
uint64_t
,
std
::
vector
<
json
>>
rawdatas
;
rawdatas
.
insert
(
pair
<
uint64_t
,
vector
<
json
>>
(
anno_schema_id
,
annotations
));
MS_LOG
(
INFO
)
<<
"Init Images Already."
;
// init file_writer
std
::
vector
<
std
::
string
>
file_names
;
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
file_names
.
emplace_back
(
std
::
string
(
"./OneSample.shard0"
)
+
std
::
to_string
(
i
));
MS_LOG
(
INFO
)
<<
"shard name is: "
<<
common
::
SafeCStr
(
file_names
[
i
-
1
]);
}
MS_LOG
(
INFO
)
<<
"Init Output Files Already."
;
{
mindrecord
::
ShardWriter
fw_init
;
fw_init
.
Open
(
file_names
);
// set shardHeader
fw_init
.
SetShardHeader
(
std
::
make_shared
<
mindrecord
::
ShardHeader
>
(
header_data
));
// close file_writer
fw_init
.
Commit
();
}
std
::
string
filename
=
"./OneSample.shard01"
;
{
MS_LOG
(
INFO
)
<<
"=============== images "
<<
bin_data
.
size
()
<<
" ============================"
;
mindrecord
::
ShardWriter
fw
;
fw
.
OpenForAppend
(
filename
);
bin_data
=
std
::
vector
<
std
::
vector
<
uint8_t
>>
(
bin_data
.
begin
(),
bin_data
.
begin
()
+
1
);
fw
.
WriteRawData
(
rawdatas
,
bin_data
);
fw
.
Commit
();
}
mindrecord
::
ShardIndexGenerator
sg
{
filename
};
sg
.
Build
();
sg
.
WriteToDatabase
();
MS_LOG
(
INFO
)
<<
"Done create index"
;
}
TEST_F
(
TestShardWriter
,
TestShardWriterBench
)
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test write imageNet"
));
Test
ShardWriterImageNet
();
ShardWriterImageNet
();
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
string
filename
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
);
string
db_name
=
std
::
string
(
"./imagenet.shard0"
)
+
std
::
to_string
(
i
)
+
".db"
;
...
...
@@ -297,7 +56,7 @@ TEST_F(TestShardWriter, TestShardWriterBench) {
TEST_F
(
TestShardWriter
,
TestShardWriterOneSample
)
{
MS_LOG
(
INFO
)
<<
common
::
SafeCStr
(
FormatInfo
(
"Test write imageNet int32 of sample less than num of shards"
));
Test
ShardWriterImageNetOneSample
();
ShardWriterImageNetOneSample
();
std
::
string
filename
=
"./OneSample.shard01"
;
ShardReader
dataset
;
...
...
@@ -342,7 +101,7 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) {
std
::
vector
<
std
::
string
>
image_filenames
;
// save all files' path within path_dir
// read image_raw_meta.data
Common
::
LoadData
(
input_path1
,
json_buffer1
,
kMaxNum
);
LoadData
(
input_path1
,
json_buffer1
,
kMaxNum
);
MS_LOG
(
INFO
)
<<
"Load Meta Data Already."
;
// get files' pathes stored in vector<string> image_filenames
...
...
@@ -375,7 +134,7 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) {
MS_LOG
(
INFO
)
<<
"Init Schema Already."
;
// create/init statistics
Common
::
LoadData
(
input_path3
,
json_buffer4
,
2
);
LoadData
(
input_path3
,
json_buffer4
,
2
);
json
static1_json
=
json_buffer4
[
0
];
json
static2_json
=
json_buffer4
[
1
];
MS_LOG
(
INFO
)
<<
"Initial statistics 1 is: "
<<
common
::
SafeCStr
(
static1_json
.
dump
());
...
...
@@ -474,7 +233,7 @@ TEST_F(TestShardWriter, TestShardWriterTrial) {
std
::
vector
<
std
::
string
>
image_filenames
;
// save all files' path within path_dir
// read image_raw_meta.data
Common
::
LoadData
(
input_path1
,
json_buffer1
,
kMaxNum
);
LoadData
(
input_path1
,
json_buffer1
,
kMaxNum
);
MS_LOG
(
INFO
)
<<
"Load Meta Data Already."
;
// get files' pathes stored in vector<string> image_filenames
...
...
@@ -508,7 +267,7 @@ TEST_F(TestShardWriter, TestShardWriterTrial) {
MS_LOG
(
INFO
)
<<
"Init Schema Already."
;
// create/init statistics
Common
::
LoadData
(
input_path3
,
json_buffer4
,
2
);
LoadData
(
input_path3
,
json_buffer4
,
2
);
json
static1_json
=
json_buffer4
[
0
];
json
static2_json
=
json_buffer4
[
1
];
MS_LOG
(
INFO
)
<<
"Initial statistics 1 is: "
<<
common
::
SafeCStr
(
static1_json
.
dump
());
...
...
@@ -613,7 +372,7 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) {
std
::
vector
<
std
::
string
>
image_filenames
;
// save all files' path within path_dir
// read image_raw_meta.data
Common
::
LoadData
(
input_path1
,
json_buffer1
,
kMaxNum
);
LoadData
(
input_path1
,
json_buffer1
,
kMaxNum
);
MS_LOG
(
INFO
)
<<
"Load Meta Data Already."
;
// get files' pathes stored in vector<string> image_filenames
...
...
@@ -644,7 +403,7 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) {
MS_LOG
(
INFO
)
<<
"Init Schema Already."
;
// create/init statistics
Common
::
LoadData
(
input_path3
,
json_buffer4
,
2
);
LoadData
(
input_path3
,
json_buffer4
,
2
);
json
static1_json
=
json_buffer4
[
0
];
json
static2_json
=
json_buffer4
[
1
];
MS_LOG
(
INFO
)
<<
"Initial statistics 1 is: "
<<
common
::
SafeCStr
(
static1_json
.
dump
());
...
...
@@ -1357,107 +1116,24 @@ TEST_F(TestShardWriter, TestWriteOpenFileName) {
}
}
void
TestShardWriterImageNetOpenForAppend
(
string
filename
)
{
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
string
filename
=
std
::
string
(
"./OpenForAppendSample.shard0"
)
+
std
::
to_string
(
i
);
string
db_name
=
std
::
string
(
"./OpenForAppendSample.shard0"
)
+
std
::
to_string
(
i
)
+
".db"
;
remove
(
common
::
SafeCStr
(
filename
));
remove
(
common
::
SafeCStr
(
db_name
));
}
// load binary data
std
::
vector
<
std
::
vector
<
uint8_t
>>
bin_data
;
std
::
vector
<
std
::
string
>
filenames
;
if
(
-
1
==
mindrecord
::
GetAbsoluteFiles
(
"./data/mindrecord/testImageNetData/images"
,
filenames
))
{
MS_LOG
(
INFO
)
<<
"-- ATTN -- Missed data directory. Skip this case. -----------------"
;
return
;
}
mindrecord
::
Img2DataUint8
(
filenames
,
bin_data
);
// init shardHeader
mindrecord
::
ShardHeader
header_data
;
MS_LOG
(
INFO
)
<<
"Init ShardHeader Already."
;
// create schema
json
anno_schema_json
=
R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"
_json
;
std
::
shared_ptr
<
mindrecord
::
Schema
>
anno_schema
=
mindrecord
::
Schema
::
Build
(
"annotation"
,
anno_schema_json
);
if
(
anno_schema
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Build annotation schema failed"
;
return
;
}
// add schema to shardHeader
int
anno_schema_id
=
header_data
.
AddSchema
(
anno_schema
);
MS_LOG
(
INFO
)
<<
"Init Schema Already."
;
// create index
std
::
pair
<
uint64_t
,
std
::
string
>
index_field1
(
anno_schema_id
,
"file_name"
);
std
::
pair
<
uint64_t
,
std
::
string
>
index_field2
(
anno_schema_id
,
"label"
);
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
fields
;
fields
.
push_back
(
index_field1
);
fields
.
push_back
(
index_field2
);
// add index to shardHeader
header_data
.
AddIndexFields
(
fields
);
MS_LOG
(
INFO
)
<<
"Init Index Fields Already."
;
// load meta data
std
::
vector
<
json
>
annotations
;
LoadDataFromImageNet
(
"./data/mindrecord/testImageNetData/annotation.txt"
,
annotations
,
1
);
// add data
std
::
map
<
std
::
uint64_t
,
std
::
vector
<
json
>>
rawdatas
;
rawdatas
.
insert
(
pair
<
uint64_t
,
vector
<
json
>>
(
anno_schema_id
,
annotations
));
MS_LOG
(
INFO
)
<<
"Init Images Already."
;
// init file_writer
std
::
vector
<
std
::
string
>
file_names
;
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
file_names
.
emplace_back
(
std
::
string
(
"./OpenForAppendSample.shard0"
)
+
std
::
to_string
(
i
));
MS_LOG
(
INFO
)
<<
"shard name is: "
<<
common
::
SafeCStr
(
file_names
[
i
-
1
]);
}
MS_LOG
(
INFO
)
<<
"Init Output Files Already."
;
{
mindrecord
::
ShardWriter
fw_init
;
fw_init
.
Open
(
file_names
);
// set shardHeader
fw_init
.
SetShardHeader
(
std
::
make_shared
<
mindrecord
::
ShardHeader
>
(
header_data
));
// close file_writer
fw_init
.
Commit
();
}
{
MS_LOG
(
INFO
)
<<
"=============== images "
<<
bin_data
.
size
()
<<
" ============================"
;
mindrecord
::
ShardWriter
fw
;
auto
ret
=
fw
.
OpenForAppend
(
filename
);
if
(
ret
==
FAILED
)
{
return
;
}
bin_data
=
std
::
vector
<
std
::
vector
<
uint8_t
>>
(
bin_data
.
begin
(),
bin_data
.
begin
()
+
1
);
fw
.
WriteRawData
(
rawdatas
,
bin_data
);
fw
.
Commit
();
}
mindrecord
::
ShardIndexGenerator
sg
{
filename
};
sg
.
Build
();
sg
.
WriteToDatabase
();
MS_LOG
(
INFO
)
<<
"Done create index"
;
}
TEST_F
(
TestShardWriter
,
TestOpenForAppend
)
{
MS_LOG
(
INFO
)
<<
"start ---- TestOpenForAppend
\n
"
;
string
filename
=
"./"
;
Test
ShardWriterImageNetOpenForAppend
(
filename
);
ShardWriterImageNetOpenForAppend
(
filename
);
string
filename1
=
"./▒AppendSample.shard01"
;
Test
ShardWriterImageNetOpenForAppend
(
filename1
);
ShardWriterImageNetOpenForAppend
(
filename1
);
string
filename2
=
"./ä
\xA9
ü"
;
TestShardWriterImageNetOpenForAppend
(
filename2
);
ShardWriterImageNetOpenForAppend
(
filename2
);
MS_LOG
(
INFO
)
<<
"end ---- TestOpenForAppend
\n
"
;
for
(
int
i
=
1
;
i
<=
4
;
i
++
)
{
string
filename
=
std
::
string
(
"./OpenForAppendSample.shard0"
)
+
std
::
to_string
(
i
);
string
db_name
=
std
::
string
(
"./OpenForAppendSample.shard0"
)
+
std
::
to_string
(
i
)
+
".db"
;
remove
(
common
::
SafeCStr
(
filename
));
remove
(
common
::
SafeCStr
(
db_name
));
}
}
}
// namespace mindrecord
...
...
tests/ut/cpp/mindrecord/ut_shard_writer_test.h
已删除
100644 → 0
浏览文件 @
ebc3f12b
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TESTS_MINDRECORD_UT_SHARDWRITER_H
#define TESTS_MINDRECORD_UT_SHARDWRITER_H
namespace
mindspore
{
namespace
mindrecord
{
void
TestShardWriterImageNet
();
}
// namespace mindrecord
}
// namespace mindspore
#endif // TESTS_MINDRECORD_UT_SHARDWRITER_H
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录