Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
decf12cd
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
decf12cd
编写于
5月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1317 [MD]add compress for nlp data in mindrecord
Merge pull request !1317 from liyong126/mindrecord_compress
上级
274f6f01
bb51bb88
变更
27
展开全部
隐藏空白更改
内联
并排
Showing
27 changed file
with
1228 addition
and
746 deletion
+1228
-746
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
...e/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
+54
-317
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
...re/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
+5
-50
mindspore/ccsrc/mindrecord/common/shard_pybind.cc
mindspore/ccsrc/mindrecord/common/shard_pybind.cc
+2
-2
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
+3
-0
mindspore/ccsrc/mindrecord/include/shard_column.h
mindspore/ccsrc/mindrecord/include/shard_column.h
+163
-0
mindspore/ccsrc/mindrecord/include/shard_header.h
mindspore/ccsrc/mindrecord/include/shard_header.h
+0
-3
mindspore/ccsrc/mindrecord/include/shard_reader.h
mindspore/ccsrc/mindrecord/include/shard_reader.h
+11
-4
mindspore/ccsrc/mindrecord/include/shard_writer.h
mindspore/ccsrc/mindrecord/include/shard_writer.h
+3
-1
mindspore/ccsrc/mindrecord/io/shard_reader.cc
mindspore/ccsrc/mindrecord/io/shard_reader.cc
+41
-68
mindspore/ccsrc/mindrecord/io/shard_writer.cc
mindspore/ccsrc/mindrecord/io/shard_writer.cc
+10
-0
mindspore/ccsrc/mindrecord/meta/shard_column.cc
mindspore/ccsrc/mindrecord/meta/shard_column.cc
+473
-0
mindspore/ccsrc/mindrecord/meta/shard_header.cc
mindspore/ccsrc/mindrecord/meta/shard_header.cc
+3
-3
mindspore/mindrecord/shardutils.py
mindspore/mindrecord/shardutils.py
+7
-28
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0
+0
-0
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0.db
.../ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0.db
+0
-0
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1
+0
-0
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1.db
.../ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1.db
+0
-0
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2
+0
-0
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2.db
.../ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2.db
+0
-0
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3
+0
-0
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3.db
.../ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3.db
+0
-0
tests/ut/python/dataset/test_minddataset.py
tests/ut/python/dataset/test_minddataset.py
+288
-96
tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py
tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py
+24
-29
tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py
tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py
+42
-32
tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py
tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py
+25
-15
tests/ut/python/mindrecord/test_mindrecord_exception.py
tests/ut/python/mindrecord/test_mindrecord_exception.py
+42
-78
tests/ut/python/mindrecord/test_mnist_to_mr.py
tests/ut/python/mindrecord/test_mnist_to_mr.py
+32
-20
未找到文件。
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
浏览文件 @
decf12cd
...
@@ -112,25 +112,26 @@ Status MindRecordOp::Init() {
...
@@ -112,25 +112,26 @@ Status MindRecordOp::Init() {
data_schema_
=
std
::
make_unique
<
DataSchema
>
();
data_schema_
=
std
::
make_unique
<
DataSchema
>
();
std
::
vector
<
std
::
s
hared_ptr
<
Schema
>>
schema_vec
=
shard_reader_
->
GetShardHeader
()
->
GetSchemas
();
std
::
vector
<
std
::
s
tring
>
col_names
=
shard_reader_
->
get_shard_column
()
->
GetColumnName
();
// check whether schema exists, if so use the first one
CHECK_FAIL_RETURN_UNEXPECTED
(
!
col_names
.
empty
(),
"No schema found"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
!
schema_vec
.
empty
(),
"No schema found"
);
std
::
vector
<
mindrecord
::
ColumnDataType
>
col_data_types
=
shard_reader_
->
get_shard_column
()
->
GeColumnDataType
(
);
mindrecord
::
json
mr_schema
=
schema_vec
[
0
]
->
GetSchema
()[
"schema"
]
;
std
::
vector
<
std
::
vector
<
int64_t
>>
col_shapes
=
shard_reader_
->
get_shard_column
()
->
GetColumnShape
()
;
bool
load_all_cols
=
columns_to_load_
.
empty
();
// if columns_to_load_ is empty it means load everything
bool
load_all_cols
=
columns_to_load_
.
empty
();
// if columns_to_load_ is empty it means load everything
std
::
map
<
std
::
string
,
int32_t
>
colname_to_ind
;
std
::
map
<
std
::
string
,
int32_t
>
colname_to_ind
;
for
(
mindrecord
::
json
::
iterator
it
=
mr_schema
.
begin
();
it
!=
mr_schema
.
end
();
++
it
)
{
for
(
uint32_t
i
=
0
;
i
<
col_names
.
size
();
i
++
)
{
std
::
string
colname
=
it
.
key
();
// key of the json, column name
std
::
string
colname
=
col_names
[
i
];
mindrecord
::
json
it_value
=
it
.
value
();
// value, which contains type info and may contain shape
ColDescriptor
col_desc
;
ColDescriptor
col_desc
;
TensorShape
t_shape
=
TensorShape
::
CreateUnknownRankShape
();
// shape of tensor, default unknown
TensorShape
t_shape
=
TensorShape
::
CreateUnknownRankShape
();
// shape of tensor, default unknown
std
::
string
type_str
=
(
it_value
[
"type"
]
==
"bytes"
||
it_value
[
"type"
]
==
"string"
)
?
"uint8"
:
it_value
[
"type"
];
std
::
string
type_str
=
mindrecord
::
ColumnDataTypeNameNormalized
[
col_data_types
[
i
]
];
DataType
t_dtype
=
DataType
(
type_str
);
// valid types: {"bytes", "string", "int32", "int64", "float32", "float64"}
DataType
t_dtype
=
DataType
(
type_str
);
// valid types: {"bytes", "string", "int32", "int64", "float32", "float64"}
if
(
it_value
[
"type"
]
==
"bytes"
)
{
// rank = 1
if
(
col_data_types
[
i
]
==
mindrecord
::
ColumnBytes
||
col_data_types
[
i
]
==
mindrecord
::
ColumnString
)
{
// rank = 1
col_desc
=
ColDescriptor
(
colname
,
t_dtype
,
TensorImpl
::
kFlexible
,
1
);
col_desc
=
ColDescriptor
(
colname
,
t_dtype
,
TensorImpl
::
kFlexible
,
1
);
}
else
if
(
it_value
.
find
(
"shape"
)
!=
it_value
.
end
()
)
{
}
else
if
(
col_shapes
[
i
].
size
()
>
0
)
{
std
::
vector
<
dsize_t
>
vec
(
it_value
[
"shape"
].
size
());
// temporary vector to hold shape
std
::
vector
<
dsize_t
>
vec
(
col_shapes
[
i
].
size
());
// temporary vector to hold shape
(
void
)
std
::
copy
(
it_value
[
"shape"
].
begin
(),
it_value
[
"shape"
].
end
(),
vec
.
begin
());
(
void
)
std
::
copy
(
col_shapes
[
i
].
begin
(),
col_shapes
[
i
].
end
(),
vec
.
begin
());
t_shape
=
TensorShape
(
vec
);
t_shape
=
TensorShape
(
vec
);
col_desc
=
ColDescriptor
(
colname
,
t_dtype
,
TensorImpl
::
kFlexible
,
t_shape
.
Rank
(),
&
t_shape
);
col_desc
=
ColDescriptor
(
colname
,
t_dtype
,
TensorImpl
::
kFlexible
,
t_shape
.
Rank
(),
&
t_shape
);
}
else
{
// unknown shape
}
else
{
// unknown shape
...
@@ -162,30 +163,7 @@ Status MindRecordOp::Init() {
...
@@ -162,30 +163,7 @@ Status MindRecordOp::Init() {
num_rows_
=
shard_reader_
->
GetNumRows
();
num_rows_
=
shard_reader_
->
GetNumRows
();
// Compute how many buffers we would need to accomplish rowsPerBuffer
// Compute how many buffers we would need to accomplish rowsPerBuffer
buffers_needed_
=
(
num_rows_
+
rows_per_buffer_
-
1
)
/
rows_per_buffer_
;
buffers_needed_
=
(
num_rows_
+
rows_per_buffer_
-
1
)
/
rows_per_buffer_
;
RETURN_IF_NOT_OK
(
SetColumnsBlob
());
return
Status
::
OK
();
}
Status
MindRecordOp
::
SetColumnsBlob
()
{
columns_blob_
=
shard_reader_
->
GetBlobFields
().
second
;
// get the exactly blob fields by columns_to_load_
std
::
vector
<
std
::
string
>
columns_blob_exact
;
for
(
auto
&
blob_field
:
columns_blob_
)
{
for
(
auto
&
column
:
columns_to_load_
)
{
if
(
column
.
compare
(
blob_field
)
==
0
)
{
columns_blob_exact
.
push_back
(
blob_field
);
break
;
}
}
}
columns_blob_index_
=
std
::
vector
<
int32_t
>
(
columns_to_load_
.
size
(),
-
1
);
int32_t
iBlob
=
0
;
for
(
auto
&
blob_exact
:
columns_blob_exact
)
{
columns_blob_index_
[
column_name_id_map_
[
blob_exact
]]
=
iBlob
++
;
}
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
@@ -215,248 +193,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
...
@@ -215,248 +193,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
}
}
}
}
template
<
typename
T
>
Status
MindRecordOp
::
LoadFeature
(
std
::
shared_ptr
<
Tensor
>
*
tensor
,
int32_t
i_col
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
mindrecord
::
json
&
columns_json
)
const
{
TensorShape
new_shape
=
TensorShape
::
CreateUnknownRankShape
();
const
unsigned
char
*
data
=
nullptr
;
std
::
unique_ptr
<
T
[]
>
array_data
;
std
::
string
string_data
;
const
ColDescriptor
&
cur_column
=
data_schema_
->
column
(
i_col
);
std
::
string
column_name
=
columns_to_load_
[
i_col
];
DataType
type
=
cur_column
.
type
();
// load blob column
if
(
columns_blob_index_
[
i_col
]
>=
0
&&
columns_blob
.
size
()
>
0
)
{
int32_t
pos
=
columns_blob_
.
size
()
==
1
?
-
1
:
columns_blob_index_
[
i_col
];
RETURN_IF_NOT_OK
(
LoadBlob
(
&
new_shape
,
&
data
,
columns_blob
,
pos
,
cur_column
));
}
else
{
switch
(
type
.
value
())
{
case
DataType
::
DE_UINT8
:
{
// For strings (Assume DE_UINT8 is reserved for strings)
RETURN_IF_NOT_OK
(
LoadByte
(
&
new_shape
,
&
string_data
,
column_name
,
columns_json
));
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
common
::
SafeCStr
(
string_data
));
break
;
}
case
DataType
::
DE_FLOAT32
:
{
// For both float scalars and arrays
RETURN_IF_NOT_OK
(
LoadFloat
(
&
new_shape
,
&
array_data
,
column_name
,
columns_json
,
cur_column
,
false
));
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
array_data
.
get
());
break
;
}
case
DataType
::
DE_FLOAT64
:
{
// For both double scalars and arrays
RETURN_IF_NOT_OK
(
LoadFloat
(
&
new_shape
,
&
array_data
,
column_name
,
columns_json
,
cur_column
,
true
));
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
array_data
.
get
());
break
;
}
default:
{
// For both integers scalars and arrays
RETURN_IF_NOT_OK
(
LoadInt
(
&
new_shape
,
&
array_data
,
column_name
,
columns_json
,
cur_column
));
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
array_data
.
get
());
break
;
}
}
}
// Create Tensor with given details
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
tensor
,
cur_column
.
tensorImpl
(),
new_shape
,
type
,
data
));
return
Status
::
OK
();
}
Status
MindRecordOp
::
LoadBlob
(
TensorShape
*
new_shape
,
const
unsigned
char
**
data
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
int32_t
pos
,
const
ColDescriptor
&
column
)
{
const
auto
kColumnSize
=
column
.
type
().
SizeInBytes
();
if
(
kColumnSize
==
0
)
{
RETURN_STATUS_UNEXPECTED
(
"column size is null"
);
}
if
(
pos
==
-
1
)
{
if
(
column
.
hasShape
())
{
*
new_shape
=
TensorShape
::
CreateUnknownRankShape
();
RETURN_IF_NOT_OK
(
column
.
MaterializeTensorShape
(
static_cast
<
int32_t
>
(
columns_blob
.
size
()
/
kColumnSize
),
new_shape
));
}
else
{
std
::
vector
<
dsize_t
>
shapeDetails
=
{
static_cast
<
dsize_t
>
(
columns_blob
.
size
()
/
kColumnSize
)};
*
new_shape
=
TensorShape
(
shapeDetails
);
}
*
data
=
reinterpret_cast
<
const
uint8_t
*>
(
&
(
columns_blob
[
0
]));
return
Status
::
OK
();
}
auto
uint64_from_bytes
=
[
&
](
int64_t
pos
)
{
uint64_t
result
=
0
;
for
(
uint64_t
n
=
0
;
n
<
kInt64Len
;
n
++
)
{
result
=
(
result
<<
8
)
+
columns_blob
[
pos
+
n
];
}
return
result
;
};
uint64_t
iStart
=
0
;
for
(
int32_t
i
=
0
;
i
<
pos
;
i
++
)
{
uint64_t
num_bytes
=
uint64_from_bytes
(
iStart
);
iStart
+=
kInt64Len
+
num_bytes
;
}
uint64_t
num_bytes
=
uint64_from_bytes
(
iStart
);
iStart
+=
kInt64Len
;
if
(
column
.
hasShape
())
{
*
new_shape
=
TensorShape
::
CreateUnknownRankShape
();
RETURN_IF_NOT_OK
(
column
.
MaterializeTensorShape
(
static_cast
<
int32_t
>
(
num_bytes
/
kColumnSize
),
new_shape
));
}
else
{
std
::
vector
<
dsize_t
>
shapeDetails
=
{
static_cast
<
dsize_t
>
(
num_bytes
/
kColumnSize
)};
*
new_shape
=
TensorShape
(
shapeDetails
);
}
*
data
=
reinterpret_cast
<
const
uint8_t
*>
(
&
(
columns_blob
[
iStart
]));
return
Status
::
OK
();
}
template
<
typename
T
>
Status
MindRecordOp
::
LoadFloat
(
TensorShape
*
new_shape
,
std
::
unique_ptr
<
T
[]
>
*
array_data
,
const
std
::
string
&
column_name
,
const
mindrecord
::
json
&
columns_json
,
const
ColDescriptor
&
column
,
bool
use_double
)
{
if
(
!
columns_json
[
column_name
].
is_array
())
{
T
value
=
0
;
RETURN_IF_NOT_OK
(
GetFloat
(
&
value
,
columns_json
[
column_name
],
use_double
));
*
new_shape
=
TensorShape
::
CreateScalar
();
*
array_data
=
std
::
make_unique
<
T
[]
>
(
1
);
(
*
array_data
)[
0
]
=
value
;
}
else
{
if
(
column
.
hasShape
())
{
*
new_shape
=
TensorShape
(
column
.
shape
());
}
else
{
std
::
vector
<
dsize_t
>
shapeDetails
=
{
static_cast
<
dsize_t
>
(
columns_json
[
column_name
].
size
())};
*
new_shape
=
TensorShape
(
shapeDetails
);
}
int
idx
=
0
;
*
array_data
=
std
::
make_unique
<
T
[]
>
(
new_shape
->
NumOfElements
());
for
(
auto
&
element
:
columns_json
[
column_name
])
{
T
value
=
0
;
RETURN_IF_NOT_OK
(
GetFloat
(
&
value
,
element
,
use_double
));
(
*
array_data
)[
idx
++
]
=
value
;
}
}
return
Status
::
OK
();
}
template
<
typename
T
>
Status
MindRecordOp
::
GetFloat
(
T
*
value
,
const
mindrecord
::
json
&
data
,
bool
use_double
)
{
if
(
data
.
is_number
())
{
*
value
=
data
;
}
else
if
(
data
.
is_string
())
{
try
{
if
(
use_double
)
{
*
value
=
data
.
get
<
double
>
();
}
else
{
*
value
=
data
.
get
<
float
>
();
}
}
catch
(
mindrecord
::
json
::
exception
&
e
)
{
RETURN_STATUS_UNEXPECTED
(
"Conversion to float failed."
);
}
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Conversion to float failed."
);
}
return
Status
::
OK
();
}
template
<
typename
T
>
Status
MindRecordOp
::
LoadInt
(
TensorShape
*
new_shape
,
std
::
unique_ptr
<
T
[]
>
*
array_data
,
const
std
::
string
&
column_name
,
const
mindrecord
::
json
&
columns_json
,
const
ColDescriptor
&
column
)
{
if
(
!
columns_json
[
column_name
].
is_array
())
{
T
value
=
0
;
RETURN_IF_NOT_OK
(
GetInt
(
&
value
,
columns_json
[
column_name
]));
*
new_shape
=
TensorShape
::
CreateScalar
();
*
array_data
=
std
::
make_unique
<
T
[]
>
(
1
);
(
*
array_data
)[
0
]
=
value
;
}
else
{
if
(
column
.
hasShape
())
{
*
new_shape
=
TensorShape
(
column
.
shape
());
}
else
{
std
::
vector
<
dsize_t
>
shapeDetails
=
{
static_cast
<
dsize_t
>
(
columns_json
[
column_name
].
size
())};
*
new_shape
=
TensorShape
(
shapeDetails
);
}
int
idx
=
0
;
*
array_data
=
std
::
make_unique
<
T
[]
>
(
new_shape
->
NumOfElements
());
for
(
auto
&
element
:
columns_json
[
column_name
])
{
T
value
=
0
;
RETURN_IF_NOT_OK
(
GetInt
(
&
value
,
element
));
(
*
array_data
)[
idx
++
]
=
value
;
}
}
return
Status
::
OK
();
}
template
<
typename
T
>
Status
MindRecordOp
::
GetInt
(
T
*
value
,
const
mindrecord
::
json
&
data
)
{
int64_t
temp_value
=
0
;
bool
less_than_zero
=
false
;
if
(
data
.
is_number_integer
())
{
const
mindrecord
::
json
json_zero
=
0
;
if
(
data
<
json_zero
)
less_than_zero
=
true
;
temp_value
=
data
;
}
else
if
(
data
.
is_string
())
{
std
::
string
string_value
=
data
;
if
(
!
string_value
.
empty
()
&&
string_value
[
0
]
==
'-'
)
{
try
{
temp_value
=
std
::
stoll
(
string_value
);
less_than_zero
=
true
;
}
catch
(
std
::
invalid_argument
&
e
)
{
RETURN_STATUS_UNEXPECTED
(
"Conversion to int failed, invalid argument."
);
}
catch
(
std
::
out_of_range
&
e
)
{
RETURN_STATUS_UNEXPECTED
(
"Conversion to int failed, out of range."
);
}
}
else
{
try
{
temp_value
=
static_cast
<
int64_t
>
(
std
::
stoull
(
string_value
));
}
catch
(
std
::
invalid_argument
&
e
)
{
RETURN_STATUS_UNEXPECTED
(
"Conversion to int failed, invalid argument."
);
}
catch
(
std
::
out_of_range
&
e
)
{
RETURN_STATUS_UNEXPECTED
(
"Conversion to int failed, out of range."
);
}
}
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Conversion to int failed."
);
}
if
((
less_than_zero
&&
temp_value
<
static_cast
<
int64_t
>
(
std
::
numeric_limits
<
T
>::
min
()))
||
(
!
less_than_zero
&&
static_cast
<
uint64_t
>
(
temp_value
)
>
static_cast
<
uint64_t
>
(
std
::
numeric_limits
<
T
>::
max
())))
{
RETURN_STATUS_UNEXPECTED
(
"Conversion to int failed. Out of range"
);
}
*
value
=
static_cast
<
T
>
(
temp_value
);
return
Status
::
OK
();
}
Status
MindRecordOp
::
LoadByte
(
TensorShape
*
new_shape
,
std
::
string
*
string_data
,
const
std
::
string
&
column_name
,
const
mindrecord
::
json
&
columns_json
)
{
*
string_data
=
columns_json
[
column_name
];
std
::
vector
<
dsize_t
>
shape_details
=
{
static_cast
<
dsize_t
>
(
string_data
->
size
())};
*
new_shape
=
TensorShape
(
shape_details
);
return
Status
::
OK
();
}
Status
MindRecordOp
::
WorkerEntry
(
int32_t
worker_id
)
{
Status
MindRecordOp
::
WorkerEntry
(
int32_t
worker_id
)
{
TaskManager
::
FindMe
()
->
Post
();
TaskManager
::
FindMe
()
->
Post
();
std
::
unique_ptr
<
IOBlock
>
io_block
;
std
::
unique_ptr
<
IOBlock
>
io_block
;
RETURN_IF_NOT_OK
(
io_blk_queues_
[
worker_id
]
->
PopFront
(
&
io_block
));
RETURN_IF_NOT_OK
(
io_blk_queues_
[
worker_id
]
->
PopFront
(
&
io_block
));
while
(
io_block
!=
nullptr
)
{
while
(
io_block
!=
nullptr
)
{
if
(
io_block
->
eoe
()
==
true
)
{
if
(
io_block
->
eoe
())
{
RETURN_IF_NOT_OK
(
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
worker_id
,
std
::
move
(
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
))));
out_connector_
->
Add
(
worker_id
,
std
::
move
(
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
))));
RETURN_IF_NOT_OK
(
io_blk_queues_
[
worker_id
]
->
PopFront
(
&
io_block
));
RETURN_IF_NOT_OK
(
io_blk_queues_
[
worker_id
]
->
PopFront
(
&
io_block
));
continue
;
continue
;
}
}
if
(
io_block
->
eof
()
==
true
)
{
if
(
io_block
->
eof
())
{
RETURN_IF_NOT_OK
(
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
worker_id
,
std
::
move
(
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
))));
out_connector_
->
Add
(
worker_id
,
std
::
move
(
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
))));
RETURN_IF_NOT_OK
(
io_blk_queues_
[
worker_id
]
->
PopFront
(
&
io_block
));
RETURN_IF_NOT_OK
(
io_blk_queues_
[
worker_id
]
->
PopFront
(
&
io_block
));
...
@@ -521,19 +269,10 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
...
@@ -521,19 +269,10 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
if
(
tupled_buffer
.
empty
())
break
;
if
(
tupled_buffer
.
empty
())
break
;
}
}
for
(
const
auto
&
tupled_row
:
tupled_buffer
)
{
for
(
const
auto
&
tupled_row
:
tupled_buffer
)
{
std
::
vector
<
uint8_t
>
columns
B
lob
=
std
::
get
<
0
>
(
tupled_row
);
std
::
vector
<
uint8_t
>
columns
_b
lob
=
std
::
get
<
0
>
(
tupled_row
);
mindrecord
::
json
columns_json
=
std
::
get
<
1
>
(
tupled_row
);
mindrecord
::
json
columns_json
=
std
::
get
<
1
>
(
tupled_row
);
TensorRow
tensor_row
;
TensorRow
tensor_row
;
for
(
uint32_t
j
=
0
;
j
<
columns_to_load_
.
size
();
++
j
)
{
RETURN_IF_NOT_OK
(
LoadTensorRow
(
&
tensor_row
,
columns_blob
,
columns_json
));
std
::
shared_ptr
<
Tensor
>
tensor
;
const
ColDescriptor
&
cur_column
=
data_schema_
->
column
(
j
);
DataType
type
=
cur_column
.
type
();
RETURN_IF_NOT_OK
(
SwitchLoadFeature
(
type
,
&
tensor
,
j
,
columnsBlob
,
columns_json
));
tensor_row
.
push_back
(
std
::
move
(
tensor
));
}
tensor_table
->
push_back
(
std
::
move
(
tensor_row
));
tensor_table
->
push_back
(
std
::
move
(
tensor_row
));
}
}
}
}
...
@@ -543,48 +282,46 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
...
@@ -543,48 +282,46 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
return
Status
::
OK
();
return
Status
::
OK
();
}
}
Status
MindRecordOp
::
SwitchLoadFeature
(
const
DataType
&
type
,
std
::
shared_ptr
<
Tensor
>
*
tensor
,
int32_t
i_col
,
Status
MindRecordOp
::
LoadTensorRow
(
TensorRow
*
tensor_row
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
mindrecord
::
json
&
columns_json
)
{
const
mindrecord
::
json
&
columns_json
)
const
{
for
(
uint32_t
i_col
=
0
;
i_col
<
columns_to_load_
.
size
();
i_col
++
)
{
switch
(
type
.
value
())
{
auto
column_name
=
columns_to_load_
[
i_col
];
case
DataType
::
DE_BOOL
:
{
return
LoadFeature
<
bool
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
// Initialize column parameters
}
const
unsigned
char
*
data
=
nullptr
;
case
DataType
::
DE_INT8
:
{
std
::
unique_ptr
<
unsigned
char
[]
>
data_ptr
;
return
LoadFeature
<
int8_t
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
uint64_t
n_bytes
=
0
;
}
mindrecord
::
ColumnDataType
column_data_type
=
mindrecord
::
ColumnNoDataType
;
case
DataType
::
DE_UINT8
:
{
uint64_t
column_data_type_size
=
1
;
return
LoadFeature
<
uint8_t
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
std
::
vector
<
int64_t
>
column_shape
;
}
case
DataType
::
DE_INT16
:
{
// Get column data
return
LoadFeature
<
int16_t
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
}
auto
has_column
=
shard_reader_
->
get_shard_column
()
->
GetColumnValueByName
(
case
DataType
::
DE_UINT16
:
{
column_name
,
columns_blob
,
columns_json
,
&
data
,
&
data_ptr
,
&
n_bytes
,
&
column_data_type
,
&
column_data_type_size
,
return
LoadFeature
<
uint16_t
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
&
column_shape
);
}
if
(
has_column
==
MSRStatus
::
FAILED
)
{
case
DataType
::
DE_INT32
:
{
RETURN_STATUS_UNEXPECTED
(
"Failed to retrieve data from mindrecord reader."
);
return
LoadFeature
<
int32_t
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
}
case
DataType
::
DE_UINT32
:
{
return
LoadFeature
<
uint32_t
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
}
case
DataType
::
DE_INT64
:
{
return
LoadFeature
<
int64_t
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
}
case
DataType
::
DE_UINT64
:
{
return
LoadFeature
<
uint64_t
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
}
case
DataType
::
DE_FLOAT32
:
{
return
LoadFeature
<
float
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
}
case
DataType
::
DE_FLOAT64
:
{
return
LoadFeature
<
double
>
(
tensor
,
i_col
,
columns_blob
,
columns_json
);
}
}
default:
{
return
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
std
::
shared_ptr
<
Tensor
>
tensor
;
"mindrecord column list type does not match any known types"
);
const
ColDescriptor
&
column
=
data_schema_
->
column
(
i_col
);
DataType
type
=
column
.
type
();
// Set shape
auto
num_elements
=
n_bytes
/
column_data_type_size
;
if
(
column
.
hasShape
())
{
auto
new_shape
=
TensorShape
(
column
.
shape
());
RETURN_IF_NOT_OK
(
column
.
MaterializeTensorShape
(
static_cast
<
int32_t
>
(
num_elements
),
&
new_shape
));
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
tensor
,
column
.
tensorImpl
(),
new_shape
,
type
,
data
));
}
else
{
std
::
vector
<
dsize_t
>
shapeDetails
=
{
static_cast
<
dsize_t
>
(
num_elements
)};
auto
new_shape
=
TensorShape
(
shapeDetails
);
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
tensor
,
column
.
tensorImpl
(),
new_shape
,
type
,
data
));
}
}
tensor_row
->
push_back
(
std
::
move
(
tensor
));
}
}
return
Status
::
OK
();
}
}
Status
MindRecordOp
::
FetchBlockBuffer
(
const
int32_t
&
buffer_id
)
{
Status
MindRecordOp
::
FetchBlockBuffer
(
const
int32_t
&
buffer_id
)
{
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h
浏览文件 @
decf12cd
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
#include <queue>
#include <queue>
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
...
@@ -31,6 +32,7 @@
...
@@ -31,6 +32,7 @@
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/queue.h"
#include "dataset/util/queue.h"
#include "dataset/util/status.h"
#include "dataset/util/status.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_reader.h"
#include "mindrecord/include/shard_reader.h"
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/common/shard_utils.h"
...
@@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp {
...
@@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp {
Status
Init
();
Status
Init
();
Status
SetColumnsBlob
();
// Base-class override for NodePass visitor acceptor.
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @param modified - Whether this node visit modified the pipeline.
...
@@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp {
...
@@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp {
Status
GetBufferFromReader
(
std
::
unique_ptr
<
DataBuffer
>
*
fetched_buffer
,
int64_t
buffer_id
,
int32_t
worker_id
);
Status
GetBufferFromReader
(
std
::
unique_ptr
<
DataBuffer
>
*
fetched_buffer
,
int64_t
buffer_id
,
int32_t
worker_id
);
// Parses a single cell and puts the data into a tensor
// Parses a single cell and puts the data into a tensor
// @param tensor - the tensor to put the parsed data in
// @param tensor_row - the tensor row to put the parsed data in
// @param i_col - the id of column to parse
// @param columns_blob - the blob data received from the reader
// @param columns_blob - the blob data received from the reader
// @param columns_json - the data for fields received from the reader
// @param columns_json - the data for fields received from the reader
template
<
typename
T
>
Status
LoadTensorRow
(
TensorRow
*
tensor_row
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
Status
LoadFeature
(
std
::
shared_ptr
<
Tensor
>
*
tensor
,
int32_t
i_col
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
mindrecord
::
json
&
columns_json
);
const
mindrecord
::
json
&
columns_json
)
const
;
Status
SwitchLoadFeature
(
const
DataType
&
type
,
std
::
shared_ptr
<
Tensor
>
*
tensor
,
int32_t
i_col
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
mindrecord
::
json
&
columns_json
)
const
;
static
Status
LoadBlob
(
TensorShape
*
new_shape
,
const
unsigned
char
**
data
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
int32_t
pos
,
const
ColDescriptor
&
column
);
// Get shape and data (scalar or array) for tensor to be created (for floats and doubles)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
// @param use_double - boolean to choose between float32 and float64
template
<
typename
T
>
static
Status
LoadFloat
(
TensorShape
*
new_shape
,
std
::
unique_ptr
<
T
[]
>
*
array_data
,
const
std
::
string
&
column_name
,
const
mindrecord
::
json
&
columns_json
,
const
ColDescriptor
&
column
,
bool
use_double
);
// Get shape and data (scalar or array) for tensor to be created (for integers)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
template
<
typename
T
>
static
Status
LoadInt
(
TensorShape
*
new_shape
,
std
::
unique_ptr
<
T
[]
>
*
array_data
,
const
std
::
string
&
column_name
,
const
mindrecord
::
json
&
columns_json
,
const
ColDescriptor
&
column
);
static
Status
LoadByte
(
TensorShape
*
new_shape
,
std
::
string
*
string_data
,
const
std
::
string
&
column_name
,
const
mindrecord
::
json
&
columns_json
);
// Get a single float value from the given json
// @param value - the float to put the value in
// @param arrayData - the given json containing the float
// @param use_double - boolean to choose between float32 and float64
template
<
typename
T
>
static
Status
GetFloat
(
T
*
value
,
const
mindrecord
::
json
&
data
,
bool
use_double
);
// Get a single integer value from the given json
// @param value - the integer to put the value in
// @param arrayData - the given json containing the integer
template
<
typename
T
>
static
Status
GetInt
(
T
*
value
,
const
mindrecord
::
json
&
data
);
Status
FetchBlockBuffer
(
const
int32_t
&
buffer_id
);
Status
FetchBlockBuffer
(
const
int32_t
&
buffer_id
);
...
...
mindspore/ccsrc/mindrecord/common/shard_pybind.cc
浏览文件 @
decf12cd
...
@@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) {
...
@@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) {
.
def
(
"launch"
,
&
ShardReader
::
Launch
)
.
def
(
"launch"
,
&
ShardReader
::
Launch
)
.
def
(
"get_header"
,
&
ShardReader
::
GetShardHeader
)
.
def
(
"get_header"
,
&
ShardReader
::
GetShardHeader
)
.
def
(
"get_blob_fields"
,
&
ShardReader
::
GetBlobFields
)
.
def
(
"get_blob_fields"
,
&
ShardReader
::
GetBlobFields
)
.
def
(
"get_next"
,
.
def
(
"get_next"
,
(
std
::
vector
<
std
::
tuple
<
std
::
vector
<
std
::
vector
<
uint8_t
>>
,
pybind11
::
object
>>
(
ShardReader
::*
)())
&
(
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
pybind11
::
object
>>
(
ShardReader
::*
)())
&
ShardReader
::
GetNextPy
)
ShardReader
::
GetNextPy
)
.
def
(
"finish"
,
&
ShardReader
::
Finish
)
.
def
(
"finish"
,
&
ShardReader
::
Finish
)
.
def
(
"close"
,
&
ShardReader
::
Close
);
.
def
(
"close"
,
&
ShardReader
::
Close
);
}
}
...
...
mindspore/ccsrc/mindrecord/include/common/shard_utils.h
浏览文件 @
decf12cd
...
@@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4;
...
@@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4;
enum
LabelCategory
{
kSchemaLabel
,
kStatisticsLabel
,
kIndexLabel
};
enum
LabelCategory
{
kSchemaLabel
,
kStatisticsLabel
,
kIndexLabel
};
const
char
kVersion
[]
=
"3.0"
;
const
std
::
vector
<
std
::
string
>
kSupportedVersion
=
{
"2.0"
,
kVersion
};
enum
ShardType
{
enum
ShardType
{
kNLP
=
0
,
kNLP
=
0
,
kCV
=
1
,
kCV
=
1
,
...
...
mindspore/ccsrc/mindrecord/include/shard_column.h
0 → 100644
浏览文件 @
decf12cd
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_header.h"
namespace
mindspore
{
namespace
mindrecord
{
const
uint64_t
kUnsignedOne
=
1
;
const
uint64_t
kBitsOfByte
=
8
;
const
uint64_t
kDataTypeBits
=
2
;
const
uint64_t
kNumDataOfByte
=
4
;
const
uint64_t
kBytesOfColumnLen
=
4
;
const
uint64_t
kDataTypeBitMask
=
3
;
const
uint64_t
kDataTypes
=
6
;
enum
IntegerType
{
kInt8Type
=
0
,
kInt16Type
,
kInt32Type
,
kInt64Type
};
enum
ColumnCategory
{
ColumnInRaw
,
ColumnInBlob
,
ColumnNotFound
};
enum
ColumnDataType
{
ColumnBytes
=
0
,
ColumnString
=
1
,
ColumnInt32
=
2
,
ColumnInt64
=
3
,
ColumnFloat32
=
4
,
ColumnFloat64
=
5
,
ColumnNoDataType
=
6
};
// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"};
const
uint32_t
ColumnDataTypeSize
[
kDataTypes
]
=
{
1
,
1
,
4
,
8
,
4
,
8
};
const
std
::
vector
<
std
::
string
>
ColumnDataTypeNameNormalized
=
{
"uint8"
,
"uint8"
,
"int32"
,
"int64"
,
"float32"
,
"float64"
};
const
std
::
unordered_map
<
std
::
string
,
ColumnDataType
>
ColumnDataTypeMap
=
{
{
"bytes"
,
ColumnBytes
},
{
"string"
,
ColumnString
},
{
"int32"
,
ColumnInt32
},
{
"int64"
,
ColumnInt64
},
{
"float32"
,
ColumnFloat32
},
{
"float64"
,
ColumnFloat64
}};
class
ShardColumn
{
public:
explicit
ShardColumn
(
const
std
::
shared_ptr
<
ShardHeader
>
&
shard_header
,
bool
compress_integer
=
true
);
~
ShardColumn
()
=
default
;
/// \brief get column value by column name
MSRStatus
GetColumnValueByName
(
const
std
::
string
&
column_name
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
json
&
columns_json
,
const
unsigned
char
**
data
,
std
::
unique_ptr
<
unsigned
char
[]
>
*
data_ptr
,
uint64_t
*
n_bytes
,
ColumnDataType
*
column_data_type
,
uint64_t
*
column_data_type_size
,
std
::
vector
<
int64_t
>
*
column_shape
);
/// \brief compress blob
std
::
vector
<
uint8_t
>
CompressBlob
(
const
std
::
vector
<
uint8_t
>
&
blob
);
/// \brief check if blob compressed
bool
CheckCompressBlob
()
const
{
return
has_compress_blob_
;
}
uint64_t
GetNumBlobColumn
()
const
{
return
num_blob_column_
;
}
std
::
vector
<
std
::
string
>
GetColumnName
()
{
return
column_name_
;
}
std
::
vector
<
ColumnDataType
>
GeColumnDataType
()
{
return
column_data_type_
;
}
std
::
vector
<
std
::
vector
<
int64_t
>>
GetColumnShape
()
{
return
column_shape_
;
}
/// \brief get column value from blob
MSRStatus
GetColumnFromBlob
(
const
std
::
string
&
column_name
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
const
unsigned
char
**
data
,
std
::
unique_ptr
<
unsigned
char
[]
>
*
data_ptr
,
uint64_t
*
n_bytes
);
private:
/// \brief get column value from json
MSRStatus
GetColumnFromJson
(
const
std
::
string
&
column_name
,
const
json
&
columns_json
,
std
::
unique_ptr
<
unsigned
char
[]
>
*
data_ptr
,
uint64_t
*
n_bytes
);
/// \brief get float value from json
template
<
typename
T
>
MSRStatus
GetFloat
(
std
::
unique_ptr
<
unsigned
char
[]
>
*
data_ptr
,
const
json
&
json_column_value
,
bool
use_double
);
/// \brief get integer value from json
template
<
typename
T
>
MSRStatus
GetInt
(
std
::
unique_ptr
<
unsigned
char
[]
>
*
data_ptr
,
const
json
&
json_column_value
);
/// \brief get column offset address and size from blob
MSRStatus
GetColumnAddressInBlock
(
const
uint64_t
&
column_id
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
uint64_t
*
num_bytes
,
uint64_t
*
shift_idx
);
/// \brief check if column name is available
ColumnCategory
CheckColumnName
(
const
std
::
string
&
column_name
);
/// \brief compress integer column
static
vector
<
uint8_t
>
CompressInt
(
const
vector
<
uint8_t
>
&
src_bytes
,
const
IntegerType
&
int_type
);
/// \brief uncompress integer array column
template
<
typename
T
>
static
MSRStatus
UncompressInt
(
const
uint64_t
&
column_id
,
std
::
unique_ptr
<
unsigned
char
[]
>
*
data_ptr
,
const
std
::
vector
<
uint8_t
>
&
columns_blob
,
uint64_t
*
num_bytes
,
uint64_t
shift_idx
);
/// \brief convert big-endian bytes to unsigned int
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param i_type integer type
/// \return unsigned int
static
uint64_t
BytesBigToUInt64
(
const
std
::
vector
<
uint8_t
>
&
bytes_array
,
const
uint64_t
&
pos
,
const
IntegerType
&
i_type
);
/// \brief convert unsigned int to big-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static
std
::
vector
<
uint8_t
>
UIntToBytesBig
(
uint64_t
value
,
const
IntegerType
&
i_type
);
/// \brief convert unsigned int to little-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static
std
::
vector
<
uint8_t
>
UIntToBytesLittle
(
uint64_t
value
,
const
IntegerType
&
i_type
);
/// \brief convert unsigned int to little-endian bytes
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param src_i_type source integer typ0e
/// \param dst_i_type (output), destination integer type
/// \return integer
static
int64_t
BytesLittleToMinIntType
(
const
std
::
vector
<
uint8_t
>
&
bytes_array
,
const
uint64_t
&
pos
,
const
IntegerType
&
src_i_type
,
IntegerType
*
dst_i_type
=
nullptr
);
private:
std
::
vector
<
std
::
string
>
column_name_
;
// column name list
std
::
vector
<
ColumnDataType
>
column_data_type_
;
// column data type list
std
::
vector
<
std
::
vector
<
int64_t
>>
column_shape_
;
// column shape list
std
::
unordered_map
<
string
,
uint64_t
>
column_name_id_
;
// column name id map
std
::
vector
<
std
::
string
>
blob_column_
;
// blob column list
std
::
unordered_map
<
std
::
string
,
uint64_t
>
blob_column_id_
;
// blob column name id map
bool
has_compress_blob_
;
// if has compress blob
uint64_t
num_blob_column_
;
// number of blob columns
};
}
// namespace mindrecord
}
// namespace mindspore
#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_
mindspore/ccsrc/mindrecord/include/shard_header.h
浏览文件 @
decf12cd
...
@@ -118,8 +118,6 @@ class ShardHeader {
...
@@ -118,8 +118,6 @@ class ShardHeader {
void
SetPageSize
(
const
uint64_t
&
page_size
)
{
page_size_
=
page_size
;
}
void
SetPageSize
(
const
uint64_t
&
page_size
)
{
page_size_
=
page_size
;
}
const
string
GetVersion
()
{
return
version_
;
}
std
::
vector
<
std
::
string
>
SerializeHeader
();
std
::
vector
<
std
::
string
>
SerializeHeader
();
MSRStatus
PagesToFile
(
const
std
::
string
dump_file_name
);
MSRStatus
PagesToFile
(
const
std
::
string
dump_file_name
);
...
@@ -175,7 +173,6 @@ class ShardHeader {
...
@@ -175,7 +173,6 @@ class ShardHeader {
uint32_t
shard_count_
;
uint32_t
shard_count_
;
uint64_t
header_size_
;
uint64_t
header_size_
;
uint64_t
page_size_
;
uint64_t
page_size_
;
string
version_
=
"2.0"
;
std
::
shared_ptr
<
Index
>
index_
;
std
::
shared_ptr
<
Index
>
index_
;
std
::
vector
<
std
::
string
>
shard_addresses_
;
std
::
vector
<
std
::
string
>
shard_addresses_
;
...
...
mindspore/ccsrc/mindrecord/include/shard_reader.h
浏览文件 @
decf12cd
...
@@ -43,6 +43,7 @@
...
@@ -43,6 +43,7 @@
#include <vector>
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_operator.h"
...
@@ -111,6 +112,10 @@ class ShardReader {
...
@@ -111,6 +112,10 @@ class ShardReader {
/// \return the metadata
/// \return the metadata
std
::
shared_ptr
<
ShardHeader
>
GetShardHeader
()
const
;
std
::
shared_ptr
<
ShardHeader
>
GetShardHeader
()
const
;
/// \brief aim to get columns context
/// \return the columns
std
::
shared_ptr
<
ShardColumn
>
get_shard_column
()
const
;
/// \brief get the number of shards
/// \brief get the number of shards
/// \return # of shards
/// \return # of shards
int
GetShardCount
()
const
;
int
GetShardCount
()
const
;
...
@@ -185,7 +190,7 @@ class ShardReader {
...
@@ -185,7 +190,7 @@ class ShardReader {
/// \brief return a batch, given that one is ready, python API
/// \brief return a batch, given that one is ready, python API
/// \return a batch of images and image data
/// \return a batch of images and image data
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
pybind11
::
object
>>
GetNextPy
();
std
::
vector
<
std
::
tuple
<
std
::
vector
<
std
::
vector
<
uint8_t
>
>
,
pybind11
::
object
>>
GetNextPy
();
/// \brief get blob filed list
/// \brief get blob filed list
/// \return blob field list
/// \return blob field list
...
@@ -295,16 +300,18 @@ class ShardReader {
...
@@ -295,16 +300,18 @@ class ShardReader {
/// \brief get number of classes
/// \brief get number of classes
int64_t
GetNumClasses
(
const
std
::
string
&
category_field
);
int64_t
GetNumClasses
(
const
std
::
string
&
category_field
);
/// \brief get meta of header
std
::
pair
<
MSRStatus
,
std
::
vector
<
std
::
string
>>
GetMeta
(
const
std
::
string
&
file_path
,
json
&
meta_data
);
std
::
pair
<
MSRStatus
,
std
::
vector
<
std
::
string
>>
GetMeta
(
const
std
::
string
&
file_path
,
json
&
meta_data
);
/// \brief get exactly blob fields data by indices
std
::
vector
<
uint8_t
>
ExtractBlobFieldBySelectColumns
(
std
::
vector
<
uint8_t
>
&
blob_fields_bytes
,
/// \brief extract uncompressed data based on column list
std
::
vector
<
uint32_t
>
&
ordered_selected_columns_index
);
std
::
pair
<
MSRStatus
,
std
::
vector
<
std
::
vector
<
uint8_t
>>>
UnCompressBlob
(
const
std
::
vector
<
uint8_t
>
&
raw_blob_data
);
protected:
protected:
uint64_t
header_size_
;
// header size
uint64_t
header_size_
;
// header size
uint64_t
page_size_
;
// page size
uint64_t
page_size_
;
// page size
int
shard_count_
;
// number of shards
int
shard_count_
;
// number of shards
std
::
shared_ptr
<
ShardHeader
>
shard_header_
;
// shard header
std
::
shared_ptr
<
ShardHeader
>
shard_header_
;
// shard header
std
::
shared_ptr
<
ShardColumn
>
shard_column_
;
// shard column
std
::
vector
<
sqlite3
*>
database_paths_
;
// sqlite handle list
std
::
vector
<
sqlite3
*>
database_paths_
;
// sqlite handle list
std
::
vector
<
string
>
file_paths_
;
// file paths
std
::
vector
<
string
>
file_paths_
;
// file paths
...
...
mindspore/ccsrc/mindrecord/include/shard_writer.h
浏览文件 @
decf12cd
...
@@ -36,6 +36,7 @@
...
@@ -36,6 +36,7 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_header.h"
#include "mindrecord/include/shard_header.h"
#include "mindrecord/include/shard_index.h"
#include "mindrecord/include/shard_index.h"
...
@@ -242,7 +243,8 @@ class ShardWriter {
...
@@ -242,7 +243,8 @@ class ShardWriter {
std
::
vector
<
std
::
string
>
file_paths_
;
// file paths
std
::
vector
<
std
::
string
>
file_paths_
;
// file paths
std
::
vector
<
std
::
shared_ptr
<
std
::
fstream
>>
file_streams_
;
// file handles
std
::
vector
<
std
::
shared_ptr
<
std
::
fstream
>>
file_streams_
;
// file handles
std
::
shared_ptr
<
ShardHeader
>
shard_header_
;
// shard headers
std
::
shared_ptr
<
ShardHeader
>
shard_header_
;
// shard header
std
::
shared_ptr
<
ShardColumn
>
shard_column_
;
// shard columns
std
::
map
<
uint64_t
,
std
::
map
<
int
,
std
::
string
>>
err_mg_
;
// used for storing error raw_data info
std
::
map
<
uint64_t
,
std
::
map
<
int
,
std
::
string
>>
err_mg_
;
// used for storing error raw_data info
...
...
mindspore/ccsrc/mindrecord/io/shard_reader.cc
浏览文件 @
decf12cd
...
@@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
...
@@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
shard_header_
=
std
::
make_shared
<
ShardHeader
>
(
sh
);
shard_header_
=
std
::
make_shared
<
ShardHeader
>
(
sh
);
header_size_
=
shard_header_
->
GetHeaderSize
();
header_size_
=
shard_header_
->
GetHeaderSize
();
page_size_
=
shard_header_
->
GetPageSize
();
page_size_
=
shard_header_
->
GetPageSize
();
// version < 3.0
if
(
first_meta_data
[
"version"
]
<
kVersion
)
{
shard_column_
=
std
::
make_shared
<
ShardColumn
>
(
shard_header_
,
false
);
}
else
{
shard_column_
=
std
::
make_shared
<
ShardColumn
>
(
shard_header_
,
true
);
}
num_rows_
=
0
;
num_rows_
=
0
;
auto
row_group_summary
=
ReadRowGroupSummary
();
auto
row_group_summary
=
ReadRowGroupSummary
();
for
(
const
auto
&
rg
:
row_group_summary
)
{
for
(
const
auto
&
rg
:
row_group_summary
)
{
...
@@ -226,6 +232,8 @@ void ShardReader::Close() {
...
@@ -226,6 +232,8 @@ void ShardReader::Close() {
std
::
shared_ptr
<
ShardHeader
>
ShardReader
::
GetShardHeader
()
const
{
return
shard_header_
;
}
std
::
shared_ptr
<
ShardHeader
>
ShardReader
::
GetShardHeader
()
const
{
return
shard_header_
;
}
std
::
shared_ptr
<
ShardColumn
>
ShardReader
::
get_shard_column
()
const
{
return
shard_column_
;
}
int
ShardReader
::
GetShardCount
()
const
{
return
shard_header_
->
GetShardCount
();
}
int
ShardReader
::
GetShardCount
()
const
{
return
shard_header_
->
GetShardCount
();
}
int
ShardReader
::
GetNumRows
()
const
{
return
num_rows_
;
}
int
ShardReader
::
GetNumRows
()
const
{
return
num_rows_
;
}
...
@@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
...
@@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
return
SUCCESS
;
return
SUCCESS
;
}
}
std
::
vector
<
uint8_t
>
ShardReader
::
ExtractBlobFieldBySelectColumns
(
std
::
vector
<
uint8_t
>
&
blob_fields_bytes
,
std
::
vector
<
uint32_t
>
&
ordered_selected_columns_index
)
{
std
::
vector
<
uint8_t
>
exactly_blob_fields_bytes
;
auto
uint64_from_bytes
=
[
&
](
int64_t
pos
)
{
uint64_t
result
=
0
;
for
(
uint64_t
n
=
0
;
n
<
kInt64Len
;
n
++
)
{
result
=
(
result
<<
8
)
+
blob_fields_bytes
[
pos
+
n
];
}
return
result
;
};
// get the exactly blob fields
uint32_t
current_index
=
0
;
uint64_t
current_offset
=
0
;
uint64_t
data_len
=
uint64_from_bytes
(
current_offset
);
while
(
current_offset
<
blob_fields_bytes
.
size
())
{
if
(
std
::
any_of
(
ordered_selected_columns_index
.
begin
(),
ordered_selected_columns_index
.
end
(),
[
&
current_index
](
uint32_t
&
index
)
{
return
index
==
current_index
;
}))
{
exactly_blob_fields_bytes
.
insert
(
exactly_blob_fields_bytes
.
end
(),
blob_fields_bytes
.
begin
()
+
current_offset
,
blob_fields_bytes
.
begin
()
+
current_offset
+
kInt64Len
+
data_len
);
}
current_index
++
;
current_offset
+=
kInt64Len
+
data_len
;
data_len
=
uint64_from_bytes
(
current_offset
);
}
return
exactly_blob_fields_bytes
;
}
TASK_RETURN_CONTENT
ShardReader
::
ConsumerOneTask
(
int
task_id
,
uint32_t
consumer_id
)
{
TASK_RETURN_CONTENT
ShardReader
::
ConsumerOneTask
(
int
task_id
,
uint32_t
consumer_id
)
{
// All tasks are done
// All tasks are done
if
(
task_id
>=
static_cast
<
int
>
(
tasks_
.
Size
()))
{
if
(
task_id
>=
static_cast
<
int
>
(
tasks_
.
Size
()))
{
...
@@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
...
@@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return
std
::
make_pair
(
FAILED
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
());
return
std
::
make_pair
(
FAILED
,
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
());
}
}
// extract the exactly blob bytes by selected columns
std
::
vector
<
uint8_t
>
images_with_exact_columns
;
if
(
selected_columns_
.
size
()
==
0
)
{
images_with_exact_columns
=
images
;
}
else
{
auto
blob_fields
=
GetBlobFields
();
std
::
vector
<
uint32_t
>
ordered_selected_columns_index
;
uint32_t
index
=
0
;
for
(
auto
&
blob_field
:
blob_fields
.
second
)
{
for
(
auto
&
field
:
selected_columns_
)
{
if
(
field
.
compare
(
blob_field
)
==
0
)
{
ordered_selected_columns_index
.
push_back
(
index
);
break
;
}
}
index
++
;
}
if
(
ordered_selected_columns_index
.
size
()
!=
0
)
{
// extract the images
if
(
blob_fields
.
second
.
size
()
==
1
)
{
if
(
ordered_selected_columns_index
.
size
()
==
1
)
{
images_with_exact_columns
=
images
;
}
}
else
{
images_with_exact_columns
=
ExtractBlobFieldBySelectColumns
(
images
,
ordered_selected_columns_index
);
}
}
}
// Deliver batch data to output map
// Deliver batch data to output map
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
batch
;
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>>
batch
;
batch
.
emplace_back
(
std
::
move
(
images_with_exact_columns
),
std
::
move
(
std
::
get
<
2
>
(
task
)));
batch
.
emplace_back
(
std
::
move
(
images
),
std
::
move
(
std
::
get
<
2
>
(
task
)));
return
std
::
make_pair
(
SUCCESS
,
std
::
move
(
batch
));
return
std
::
make_pair
(
SUCCESS
,
std
::
move
(
batch
));
}
}
...
@@ -1369,16 +1317,41 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNextById(con
...
@@ -1369,16 +1317,41 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNextById(con
return
std
::
move
(
ret
.
second
);
return
std
::
move
(
ret
.
second
);
}
}
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
pybind11
::
object
>>
ShardReader
::
GetNextPy
()
{
std
::
pair
<
MSRStatus
,
std
::
vector
<
std
::
vector
<
uint8_t
>>>
ShardReader
::
UnCompressBlob
(
const
std
::
vector
<
uint8_t
>
&
raw_blob_data
)
{
auto
loaded_columns
=
selected_columns_
.
size
()
==
0
?
shard_column_
->
GetColumnName
()
:
selected_columns_
;
auto
blob_fields
=
GetBlobFields
().
second
;
std
::
vector
<
std
::
vector
<
uint8_t
>>
blob_data
;
for
(
uint32_t
i_col
=
0
;
i_col
<
loaded_columns
.
size
();
++
i_col
)
{
if
(
std
::
find
(
blob_fields
.
begin
(),
blob_fields
.
end
(),
loaded_columns
[
i_col
])
==
blob_fields
.
end
())
continue
;
const
unsigned
char
*
data
=
nullptr
;
std
::
unique_ptr
<
unsigned
char
[]
>
data_ptr
;
uint64_t
n_bytes
=
0
;
auto
ret
=
shard_column_
->
GetColumnFromBlob
(
loaded_columns
[
i_col
],
raw_blob_data
,
&
data
,
&
data_ptr
,
&
n_bytes
);
if
(
ret
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"Error when get data from blob, column name is "
<<
loaded_columns
[
i_col
]
<<
"."
;
return
{
FAILED
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
(
blob_fields
.
size
(),
std
::
vector
<
uint8_t
>
())};
}
if
(
data
==
nullptr
)
{
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
data_ptr
.
get
());
}
std
::
vector
<
uint8_t
>
column
(
data
,
data
+
(
n_bytes
/
sizeof
(
unsigned
char
)));
blob_data
.
push_back
(
column
);
}
return
{
SUCCESS
,
blob_data
};
}
std
::
vector
<
std
::
tuple
<
std
::
vector
<
std
::
vector
<
uint8_t
>>
,
pybind11
::
object
>>
ShardReader
::
GetNextPy
()
{
auto
res
=
GetNext
();
auto
res
=
GetNext
();
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
pybind11
::
object
>>
jsonD
ata
;
vector
<
std
::
tuple
<
std
::
vector
<
std
::
vector
<
uint8_t
>>
,
pybind11
::
object
>>
d
ata
;
std
::
transform
(
res
.
begin
(),
res
.
end
(),
std
::
back_inserter
(
jsonD
ata
),
std
::
transform
(
res
.
begin
(),
res
.
end
(),
std
::
back_inserter
(
d
ata
),
[](
const
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>
&
item
)
{
[
this
](
const
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
json
>
&
item
)
{
auto
&
j
=
std
::
get
<
1
>
(
item
);
auto
&
j
=
std
::
get
<
1
>
(
item
);
pybind11
::
object
obj
=
nlohmann
::
detail
::
FromJsonImpl
(
j
);
pybind11
::
object
obj
=
nlohmann
::
detail
::
FromJsonImpl
(
j
);
return
std
::
make_tuple
(
std
::
get
<
0
>
(
item
),
std
::
move
(
obj
));
auto
ret
=
UnCompressBlob
(
std
::
get
<
0
>
(
item
));
return
std
::
make_tuple
(
ret
.
second
,
std
::
move
(
obj
));
});
});
return
jsonD
ata
;
return
d
ata
;
}
}
void
ShardReader
::
Reset
()
{
void
ShardReader
::
Reset
()
{
...
...
mindspore/ccsrc/mindrecord/io/shard_writer.cc
浏览文件 @
decf12cd
...
@@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
...
@@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
MS_LOG
(
ERROR
)
<<
"Open file failed"
;
MS_LOG
(
ERROR
)
<<
"Open file failed"
;
return
FAILED
;
return
FAILED
;
}
}
shard_column_
=
std
::
make_shared
<
ShardColumn
>
(
shard_header_
);
return
SUCCESS
;
return
SUCCESS
;
}
}
...
@@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
...
@@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
shard_header_
=
header_data
;
shard_header_
=
header_data
;
shard_header_
->
SetHeaderSize
(
header_size_
);
shard_header_
->
SetHeaderSize
(
header_size_
);
shard_header_
->
SetPageSize
(
page_size_
);
shard_header_
->
SetPageSize
(
page_size_
);
shard_column_
=
std
::
make_shared
<
ShardColumn
>
(
shard_header_
);
return
SUCCESS
;
return
SUCCESS
;
}
}
...
@@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
...
@@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
MS_LOG
(
ERROR
)
<<
"IO error / there is no free disk to be used"
;
MS_LOG
(
ERROR
)
<<
"IO error / there is no free disk to be used"
;
return
FAILED
;
return
FAILED
;
}
}
// compress blob
if
(
shard_column_
->
CheckCompressBlob
())
{
for
(
auto
&
blob
:
blob_data
)
{
blob
=
shard_column_
->
CompressBlob
(
blob
);
}
}
// Add 4-bytes dummy blob data if no any blob fields
// Add 4-bytes dummy blob data if no any blob fields
if
(
blob_data
.
size
()
==
0
&&
raw_data
.
size
()
>
0
)
{
if
(
blob_data
.
size
()
==
0
&&
raw_data
.
size
()
>
0
)
{
blob_data
=
std
::
vector
<
std
::
vector
<
uint8_t
>>
(
raw_data
[
0
].
size
(),
std
::
vector
<
uint8_t
>
(
kUnsignedInt4
,
0
));
blob_data
=
std
::
vector
<
std
::
vector
<
uint8_t
>>
(
raw_data
[
0
].
size
(),
std
::
vector
<
uint8_t
>
(
kUnsignedInt4
,
0
));
...
...
mindspore/ccsrc/mindrecord/meta/shard_column.cc
0 → 100644
浏览文件 @
decf12cd
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mindrecord/include/shard_column.h"
#include "common/utils.h"
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_error.h"

namespace mindspore {
namespace mindrecord {
ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
  auto first_schema = shard_header->GetSchemas()[0];
  auto schema = first_schema->GetSchema()["schema"];

  bool has_integer_array = false;
  for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
    const std::string &column_name = it.key();
    column_name_.push_back(column_name);

    json it_value = it.value();

    std::string str_type = it_value["type"];
    column_data_type_.push_back(ColumnDataTypeMap.at(str_type));

    if (it_value.find("shape") != it_value.end()) {
      std::vector<int64_t> vec(it_value["shape"].size());
      std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
      column_shape_.push_back(vec);
      if (str_type == "int32" || str_type == "int64") {
        has_integer_array = true;
      }
    } else {
      std::vector<int64_t> vec = {};
      column_shape_.push_back(vec);
    }
  }

  for (uint64_t i = 0; i < column_name_.size(); i++) {
    column_name_id_[column_name_[i]] = i;
  }

  auto blob_fields = first_schema->GetBlobFields();
  for (const auto &field : blob_fields) {
    blob_column_.push_back(field);
  }
  for (uint64_t i = 0; i < blob_column_.size(); i++) {
    blob_column_id_[blob_column_[i]] = i;
  }

  has_compress_blob_ = (compress_integer && has_integer_array);
  num_blob_column_ = blob_column_.size();
}

MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
                                            const json &columns_json, const unsigned char **data,
                                            std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
                                            ColumnDataType *column_data_type, uint64_t *column_data_type_size,
                                            std::vector<int64_t> *column_shape) {
  // Skip if column not found
  auto column_category = CheckColumnName(column_name);
  if (column_category == ColumnNotFound) {
    return FAILED;
  }

  // Get data type and size
  auto column_id = column_name_id_[column_name];
  *column_data_type = column_data_type_[column_id];
  *column_data_type_size = ColumnDataTypeSize[*column_data_type];
  *column_shape = column_shape_[column_id];

  // Retrieve value from json
  if (column_category == ColumnInRaw) {
    if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) {
      MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << ".";
      return FAILED;
    }
    *data = reinterpret_cast<const unsigned char *>(data_ptr->get());
    return SUCCESS;
  }

  // Retrieve value from blob
  if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) {
    MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << ".";
    return FAILED;
  }
  if (*data == nullptr) {
    *data = reinterpret_cast<const unsigned char *>(data_ptr->get());
  }
  return SUCCESS;
}

MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
                                         std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
  auto column_id = column_name_id_[column_name];
  auto column_data_type = column_data_type_[column_id];

  // Initialize num bytes
  *n_bytes = ColumnDataTypeSize[column_data_type];
  auto json_column_value = columns_json[column_name];
  switch (column_data_type) {
    case ColumnFloat32: {
      return GetFloat<float>(data_ptr, json_column_value, false);
    }
    case ColumnFloat64: {
      return GetFloat<double>(data_ptr, json_column_value, true);
    }
    case ColumnInt32: {
      return GetInt<int32_t>(data_ptr, json_column_value);
    }
    case ColumnInt64: {
      return GetInt<int64_t>(data_ptr, json_column_value);
    }
    default: {
      // Convert string to c_str
      std::string tmp_string = json_column_value;
      *n_bytes = tmp_string.size();
      auto data = reinterpret_cast<const unsigned char *>(common::SafeCStr(tmp_string));
      *data_ptr = std::make_unique<unsigned char[]>(*n_bytes);
      for (uint32_t i = 0; i < *n_bytes; i++) {
        (*data_ptr)[i] = *(data + i);
      }
      break;
    }
  }
  return SUCCESS;
}

template <typename T>
MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
                                bool use_double) {
  std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
  if (!json_column_value.is_string() && !json_column_value.is_number()) {
    MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
    return FAILED;
  }
  if (json_column_value.is_number()) {
    array_data[0] = json_column_value;
  } else {
    // Convert string to float
    try {
      if (use_double) {
        array_data[0] = json_column_value.get<double>();
      } else {
        array_data[0] = json_column_value.get<float>();
      }
    } catch (json::exception &e) {
      MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
      return FAILED;
    }
  }

  auto data = reinterpret_cast<const unsigned char *>(array_data.get());
  *data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
  for (uint32_t i = 0; i < sizeof(T); i++) {
    (*data_ptr)[i] = *(data + i);
  }

  return SUCCESS;
}

template <typename T>
MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
  std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
  int64_t temp_value;
  bool less_than_zero = false;

  if (json_column_value.is_number_integer()) {
    const json json_zero = 0;
    if (json_column_value < json_zero) less_than_zero = true;
    temp_value = json_column_value;
  } else if (json_column_value.is_string()) {
    std::string string_value = json_column_value;

    if (!string_value.empty() && string_value[0] == '-') {
      try {
        temp_value = std::stoll(string_value);
        less_than_zero = true;
      } catch (std::invalid_argument &e) {
        MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
        return FAILED;
      } catch (std::out_of_range &e) {
        MS_LOG(ERROR) << "Conversion to int failed, out of range.";
        return FAILED;
      }
    } else {
      try {
        temp_value = static_cast<int64_t>(std::stoull(string_value));
      } catch (std::invalid_argument &e) {
        MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
        return FAILED;
      } catch (std::out_of_range &e) {
        MS_LOG(ERROR) << "Conversion to int failed, out of range.";
        return FAILED;
      }
    }
  } else {
    MS_LOG(ERROR) << "Conversion to int failed.";
    return FAILED;
  }

  if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
      (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
    MS_LOG(ERROR) << "Conversion to int failed. Out of range";
    return FAILED;
  }
  array_data[0] = static_cast<T>(temp_value);

  auto data = reinterpret_cast<const unsigned char *>(array_data.get());
  *data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
  for (uint32_t i = 0; i < sizeof(T); i++) {
    (*data_ptr)[i] = *(data + i);
  }

  return SUCCESS;
}

MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
                                         const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
                                         uint64_t *n_bytes) {
  uint64_t offset_address = 0;
  auto column_id = column_name_id_[column_name];
  if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) {
    return FAILED;
  }

  auto column_data_type = column_data_type_[column_id];
  if (has_compress_blob_ && column_data_type == ColumnInt32) {
    if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
      return FAILED;
    }
  } else if (has_compress_blob_ && column_data_type == ColumnInt64) {
    if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
      return FAILED;
    }
  } else {
    *data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address]));
  }

  return SUCCESS;
}

ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
  auto it_column = column_name_id_.find(column_name);
  if (it_column == column_name_id_.end()) {
    return ColumnNotFound;
  }
  auto it_blob = blob_column_id_.find(column_name);
  return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
}

std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) {
  // Skip if no compress columns
  if (!CheckCompressBlob()) return blob;

  std::vector<uint8_t> dst_blob;
  uint64_t i_src = 0;
  for (int64_t i = 0; i < num_blob_column_; i++) {
    // Get column data type
    auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]];
    auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type;

    // Compress and return if blob has one column only
    if (num_blob_column_ == 1) {
      return CompressInt(blob, int_type);
    }

    // Just copy and continue if column data type is not int32/int64
    uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type);
    if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) {
      dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes);
      i_src += kInt64Len + num_bytes;
      continue;
    }

    // Get column slice in source blob
    std::vector<uint8_t> blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes);
    // Compress column
    auto dst_blob_slice = CompressInt(blob_slice, int_type);
    // Get new column size
    auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type);
    // Append new column size
    dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end());
    // Append new column data
    dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end());
    i_src += kInt64Len + num_bytes;
  }
  MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
  return dst_blob;
}

vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
  uint64_t i_size = kUnsignedOne << int_type;
  // Get number of elements
  uint64_t src_n_int = src_bytes.size() / i_size;
  // Calculate bitmap size (bytes)
  uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte;
  // Initialize destination blob with more space than needed; it will be resized
  vector<uint8_t> dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0);

  // Write number of elements to destination blob
  vector<uint8_t> size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type);
  for (uint64_t n = 0; n < kBytesOfColumnLen; n++) {
    dst_bytes[n] = size_by_bytes[n];
  }

  // Write compressed int
  uint64_t i_dst = kBytesOfColumnLen + bitmap_size;
  for (uint64_t i = 0; i < src_n_int; i++) {
    // Initialize destination data type
    IntegerType dst_int_type = kInt8Type;
    // Shift to next int position
    uint64_t pos = i * (kUnsignedOne << int_type);
    // Narrow down this int
    int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);

    // Write this int to destination blob
    uint64_t u_n = *reinterpret_cast<uint64_t *>(&i_n);
    auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type);
    for (uint64_t j = 0; j < (kUnsignedOne << dst_int_type); j++) {
      dst_bytes[i_dst++] = temp_bytes[j];
    }

    // Update data type in bit map
    dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |=
      (dst_int_type << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte))));
  }
  // Resize destination blob
  dst_bytes.resize(i_dst);
  MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << ".";
  return dst_bytes;
}

MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
                                               uint64_t *num_bytes, uint64_t *shift_idx) {
  if (num_blob_column_ == 1) {
    *num_bytes = columns_blob.size();
    *shift_idx = 0;
    return SUCCESS;
  }
  auto blob_id = blob_column_id_[column_name_[column_id]];
  for (int32_t i = 0; i < blob_id; i++) {
    *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
  }
  *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);

  (*shift_idx) += kInt64Len;
  return SUCCESS;
}

template <typename T>
MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr,
                                     const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes,
                                     uint64_t shift_idx) {
  auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type);
  *num_bytes = sizeof(T) * num_elements;

  // Parse integer array
  uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte;
  auto array_data = std::make_unique<T[]>(num_elements);

  for (uint64_t i = 0; i < num_elements; i++) {
    uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte];
    uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask;
    auto mr_int_type = static_cast<IntegerType>(i_type);
    int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type);
    i_source += (kUnsignedOne << i_type);
    array_data[i] = static_cast<T>(i64);
  }

  auto data = reinterpret_cast<const unsigned char *>(array_data.get());
  *data_ptr = std::make_unique<unsigned char[]>(*num_bytes);
  memcpy(data_ptr->get(), data, *num_bytes);

  return SUCCESS;
}

uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
                                       const IntegerType &i_type) {
  uint64_t result = 0;
  for (uint64_t i = 0; i < (kUnsignedOne << i_type); i++) {
    result = (result << kBitsOfByte) + bytes_array[pos + i];
  }
  return result;
}

std::vector<uint8_t> ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) {
  uint64_t n_bytes = kUnsignedOne << i_type;
  std::vector<uint8_t> result(n_bytes, 0);
  for (uint64_t i = 0; i < n_bytes; i++) {
    result[n_bytes - 1 - i] = value & std::numeric_limits<uint8_t>::max();
    value >>= kBitsOfByte;
  }
  return result;
}

std::vector<uint8_t> ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) {
  uint64_t n_bytes = kUnsignedOne << i_type;
  std::vector<uint8_t> result(n_bytes, 0);
  for (uint64_t i = 0; i < n_bytes; i++) {
    result[i] = value & std::numeric_limits<uint8_t>::max();
    value >>= kBitsOfByte;
  }
  return result;
}

int64_t ShardColumn::BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
                                             const IntegerType &src_i_type, IntegerType *dst_i_type) {
  uint64_t u_temp = 0;
  for (uint64_t i = 0; i < (kUnsignedOne << src_i_type); i++) {
    u_temp = (u_temp << kBitsOfByte) + bytes_array[pos + (kUnsignedOne << src_i_type) - kUnsignedOne - i];
  }

  int64_t i_out;
  switch (src_i_type) {
    case kInt8Type: {
      i_out = (int8_t)(u_temp & std::numeric_limits<uint8_t>::max());
      break;
    }
    case kInt16Type: {
      i_out = (int16_t)(u_temp & std::numeric_limits<uint16_t>::max());
      break;
    }
    case kInt32Type: {
      i_out = (int32_t)(u_temp & std::numeric_limits<uint32_t>::max());
      break;
    }
    case kInt64Type: {
      i_out = (int64_t)(u_temp & std::numeric_limits<uint64_t>::max());
      break;
    }
    default: {
      i_out = 0;
    }
  }

  if (!dst_i_type) {
    return i_out;
  }

  if (i_out >= static_cast<int64_t>(std::numeric_limits<int8_t>::min()) &&
      i_out <= static_cast<int64_t>(std::numeric_limits<int8_t>::max())) {
    *dst_i_type = kInt8Type;
  } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int16_t>::min()) &&
             i_out <= static_cast<int64_t>(std::numeric_limits<int16_t>::max())) {
    *dst_i_type = kInt16Type;
  } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int32_t>::min()) &&
             i_out <= static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
    *dst_i_type = kInt32Type;
  } else {
    *dst_i_type = kInt64Type;
  }
  return i_out;
}
}  // namespace mindrecord
}  // namespace mindspore
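For intuition, here is a minimal Python sketch of the integer compression scheme that CompressInt/UncompressInt implement above. The concrete constants are assumptions read off the bit arithmetic in the C++ (a 4-byte big-endian element count, 2-bit type codes packed four per bitmap byte, and each value re-encoded little-endian at its narrowest signed width); treat it as an illustration of the technique, not the canonical MindRecord layout.

import struct

def compress_int64(values):
    """Pack int64 values as [4B BE count][2-bit-code bitmap][narrowed LE ints]."""
    count = struct.pack('>I', len(values))           # big-endian element count
    bitmap = bytearray((len(values) + 3) // 4)       # 2 bits of type code per element
    payload = bytearray()
    for i, v in enumerate(values):
        for code, width in enumerate((1, 2, 4, 8)):  # narrowest signed width that fits
            lo, hi = -(1 << (8 * width - 1)), (1 << (8 * width - 1)) - 1
            if lo <= v <= hi:
                break
        payload += v.to_bytes(width, 'little', signed=True)
        bitmap[i // 4] |= code << (2 * (3 - i % 4))  # MSB-first packing, like the C++
    return count + bytes(bitmap) + bytes(payload)

def uncompress_int64(blob):
    """Invert compress_int64: read codes from the bitmap, widen values back."""
    (count,) = struct.unpack('>I', blob[:4])
    bitmap_len = (count + 3) // 4
    pos, out = 4 + bitmap_len, []
    for i in range(count):
        code = (blob[4 + i // 4] >> (2 * (3 - i % 4))) & 0b11
        width = 1 << code
        out.append(int.from_bytes(blob[pos:pos + width], 'little', signed=True))
        pos += width
    return out

assert uncompress_int64(compress_int64([0, -1, 300, 2**40])) == [0, -1, 300, 2**40]

The design choice this illustrates: most integer labels and token ids fit in one or two bytes, so spending 2 bits of bitmap per element to record each value's width typically shrinks int32/int64 columns substantially.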
mindspore/ccsrc/mindrecord/meta/shard_header.cc
...
@@ -201,9 +201,9 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade
     json header;
     header = ret.second;
     header["shard_addresses"] = realAddresses;
-    if (header["version"] != version_) {
+    if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) {
       MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump()
-                    << ", lib version is: " << version_;
+                    << ", lib version is: " << kVersion;
       thread_status = true;
       return;
     }
...
@@ -339,7 +339,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
     s += "\"shard_addresses\":" + address + ",";
     s += "\"shard_id\":" + std::to_string(shardId) + ",";
     s += "\"statistics\":" + stats + ",";
-    s += "\"version\":\"" + version_ + "\"";
+    s += "\"version\":\"" + std::string(kVersion) + "\"";
     s += "}";
     header.emplace_back(s);
   }
...
mindspore/mindrecord/shardutils.py
...
@@ -97,16 +97,13 @@ def populate_data(raw, blob, columns, blob_fields, schema):
     if not blob_fields:
         return raw
-    # Get the order preserving sequence of columns in blob
-    ordered_columns = []
+    loaded_columns = []
     if columns:
-        for blob_field in blob_fields:
-            if blob_field in columns:
-                ordered_columns.append(blob_field)
+        for column in columns:
+            if column in blob_fields:
+                loaded_columns.append(column)
     else:
-        ordered_columns = blob_fields
-
-    blob_bytes = bytes(blob)
+        loaded_columns = blob_fields
 
     def _render_raw(field, blob_data):
         data_type = schema[field]['type']
...
@@ -119,24 +116,6 @@ def populate_data(raw, blob, columns, blob_fields, schema):
         else:
             raw[field] = blob_data
 
-    if len(blob_fields) == 1:
-        if len(ordered_columns) == 1:
-            _render_raw(blob_fields[0], blob_bytes)
-        return raw
-
-    def _int_from_bytes(xbytes: bytes) -> int:
-        return int.from_bytes(xbytes, 'big')
-
-    def _blob_at_position(pos):
-        start = 0
-        for _ in range(pos):
-            n_bytes = _int_from_bytes(blob_bytes[start: start + 8])
-            start += 8 + n_bytes
-        n_bytes = _int_from_bytes(blob_bytes[start: start + 8])
-        start += 8
-        return blob_bytes[start: start + n_bytes]
-
-    for i, blob_field in enumerate(ordered_columns):
-        _render_raw(blob_field, _blob_at_position(i))
+    for i, blob_field in enumerate(loaded_columns):
+        _render_raw(blob_field, bytes(blob[i]))
     return raw
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0
0 → 100644
File added
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord0.db
0 → 100644
File added
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1
0 → 100644
File added
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord1.db
0 → 100644
File added
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2
0 → 100644
File added
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord2.db
0 → 100644
File added
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3
0 → 100644
File added
tests/ut/data/mindrecord/testOldVersion/aclImdb.mindrecord3.db
0 → 100644
File added
tests/ut/python/dataset/test_minddataset.py
This diff is collapsed.
tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py
...
@@ -25,8 +25,24 @@ from mindspore.mindrecord import SUCCESS
 CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
 MINDRECORD_FILE = "./cifar100.mindrecord"
 
-def test_cifar100_to_mindrecord_without_index_fields():
+@pytest.fixture
+def fixture_file():
+    """add/remove file"""
+    def remove_file(x):
+        if os.path.exists("{}".format(x)):
+            os.remove("{}".format(x))
+        if os.path.exists("{}.db".format(x)):
+            os.remove("{}.db".format(x))
+        if os.path.exists("{}_test".format(x)):
+            os.remove("{}_test".format(x))
+        if os.path.exists("{}_test.db".format(x)):
+            os.remove("{}_test.db".format(x))
+
+    remove_file(MINDRECORD_FILE)
+    yield "yield_fixture_data"
+    remove_file(MINDRECORD_FILE)
+
+
+def test_cifar100_to_mindrecord_without_index_fields(fixture_file):
     """test transform cifar100 dataset to mindrecord without index fields."""
     cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
     ret = cifar100_transformer.transform()
...
@@ -34,25 +50,14 @@ def test_cifar100_to_mindrecord_without_index_fields():
     assert os.path.exists(MINDRECORD_FILE)
     assert os.path.exists(MINDRECORD_FILE + "_test")
     read()
-    os.remove("{}".format(MINDRECORD_FILE))
-    os.remove("{}.db".format(MINDRECORD_FILE))
-    os.remove("{}".format(MINDRECORD_FILE + "_test"))
-    os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
 
-def test_cifar100_to_mindrecord():
+def test_cifar100_to_mindrecord(fixture_file):
     """test transform cifar100 dataset to mindrecord."""
     cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
     cifar100_transformer.transform(['fine_label', 'coarse_label'])
     assert os.path.exists(MINDRECORD_FILE)
     assert os.path.exists(MINDRECORD_FILE + "_test")
     read()
-    os.remove("{}".format(MINDRECORD_FILE))
-    os.remove("{}.db".format(MINDRECORD_FILE))
-    os.remove("{}".format(MINDRECORD_FILE + "_test"))
-    os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
 
 def read():
...
@@ -77,8 +82,7 @@ def read():
     assert count == 4
     reader.close()
 
-def test_cifar100_to_mindrecord_illegal_file_name():
+def test_cifar100_to_mindrecord_illegal_file_name(fixture_file):
     """
     test transform cifar100 dataset to mindrecord
     when file name contains illegal character.
...
@@ -88,8 +92,7 @@ def test_cifar100_to_mindrecord_illegal_file_name():
     cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
     cifar100_transformer.transform()
 
-def test_cifar100_to_mindrecord_filename_start_with_space():
+def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file):
     """
     test transform cifar10 dataset to mindrecord
     when file name starts with space.
...
@@ -100,8 +103,7 @@ def test_cifar100_to_mindrecord_filename_start_with_space():
     cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
     cifar100_transformer.transform()
 
-def test_cifar100_to_mindrecord_filename_contain_space():
+def test_cifar100_to_mindrecord_filename_contain_space(fixture_file):
     """
     test transform cifar10 dataset to mindrecord
     when file name contains space.
...
@@ -111,14 +113,8 @@ def test_cifar100_to_mindrecord_filename_contain_space():
     cifar100_transformer.transform()
     assert os.path.exists(filename)
     assert os.path.exists(filename + "_test")
-    os.remove("{}".format(filename))
-    os.remove("{}.db".format(filename))
-    os.remove("{}".format(filename + "_test"))
-    os.remove("{}.db".format(filename + "_test"))
 
-def test_cifar100_to_mindrecord_directory():
+def test_cifar100_to_mindrecord_directory(fixture_file):
     """
     test transform cifar10 dataset to mindrecord
     when destination path is directory.
...
@@ -129,8 +125,7 @@ def test_cifar100_to_mindrecord_directory():
                                         CIFAR100_DIR)
     cifar100_transformer.transform()
 
-def test_cifar100_to_mindrecord_filename_equals_cifar100():
+def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file):
     """
     test transform cifar10 dataset to mindrecord
     when destination path equals source path.
...
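The test refactors in this and the following files all follow the same pattern: per-test os.remove cleanup is replaced by pytest yield fixtures, so stale output files are removed both before and after each test, even when an assertion fails midway. A minimal sketch of the pattern with a hypothetical file name:

import os
import pytest

@pytest.fixture
def tmp_record():
    """Remove stale output before the test and clean up after it."""
    path = "./demo.mindrecord"          # hypothetical output file
    for suffix in ("", ".db"):
        if os.path.exists(path + suffix):
            os.remove(path + suffix)
    yield path                          # the test body runs here
    for suffix in ("", ".db"):
        if os.path.exists(path + suffix):
            os.remove(path + suffix)

def test_writes_record(tmp_record):
    with open(tmp_record, "w") as f:    # stand-in for a FileWriter
        f.write("data")
    assert os.path.exists(tmp_record)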
tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py
...
@@ -24,36 +24,60 @@ from mindspore.mindrecord import MRMOpenError, SUCCESS
 CIFAR10_DIR = "../data/mindrecord/testCifar10Data"
 MINDRECORD_FILE = "./cifar10.mindrecord"
 
-def test_cifar10_to_mindrecord_without_index_fields():
+@pytest.fixture
+def fixture_file():
+    """add/remove file"""
+    def remove_file(x):
+        if os.path.exists("{}".format(x)):
+            os.remove("{}".format(x))
+        if os.path.exists("{}.db".format(x)):
+            os.remove("{}.db".format(x))
+        if os.path.exists("{}_test".format(x)):
+            os.remove("{}_test".format(x))
+        if os.path.exists("{}_test.db".format(x)):
+            os.remove("{}_test.db".format(x))
+
+    remove_file(MINDRECORD_FILE)
+    yield "yield_fixture_data"
+    remove_file(MINDRECORD_FILE)
+
+
+@pytest.fixture
+def fixture_space_file():
+    """add/remove file"""
+    def remove_file(x):
+        if os.path.exists("{}".format(x)):
+            os.remove("{}".format(x))
+        if os.path.exists("{}.db".format(x)):
+            os.remove("{}.db".format(x))
+        if os.path.exists("{}_test".format(x)):
+            os.remove("{}_test".format(x))
+        if os.path.exists("{}_test.db".format(x)):
+            os.remove("{}_test.db".format(x))
+
+    x = "./yes ok"
+    remove_file(x)
+    yield "yield_fixture_data"
+    remove_file(x)
+
+
+def test_cifar10_to_mindrecord_without_index_fields(fixture_file):
     """test transform cifar10 dataset to mindrecord without index fields."""
     cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
     cifar10_transformer.transform()
     assert os.path.exists(MINDRECORD_FILE)
     assert os.path.exists(MINDRECORD_FILE + "_test")
     read()
-    os.remove("{}".format(MINDRECORD_FILE))
-    os.remove("{}.db".format(MINDRECORD_FILE))
-    os.remove("{}".format(MINDRECORD_FILE + "_test"))
-    os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
 
-def test_cifar10_to_mindrecord():
+def test_cifar10_to_mindrecord(fixture_file):
     """test transform cifar10 dataset to mindrecord."""
     cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
     cifar10_transformer.transform(['label'])
     assert os.path.exists(MINDRECORD_FILE)
     assert os.path.exists(MINDRECORD_FILE + "_test")
     read()
-    os.remove("{}".format(MINDRECORD_FILE))
-    os.remove("{}.db".format(MINDRECORD_FILE))
-    os.remove("{}".format(MINDRECORD_FILE + "_test"))
-    os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
 
-def test_cifar10_to_mindrecord_with_return():
+def test_cifar10_to_mindrecord_with_return(fixture_file):
     """test transform cifar10 dataset to mindrecord."""
     cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
     ret = cifar10_transformer.transform(['label'])
...
@@ -61,11 +85,6 @@ def test_cifar10_to_mindrecord_with_return():
     assert os.path.exists(MINDRECORD_FILE)
     assert os.path.exists(MINDRECORD_FILE + "_test")
     read()
-    os.remove("{}".format(MINDRECORD_FILE))
-    os.remove("{}.db".format(MINDRECORD_FILE))
-    os.remove("{}".format(MINDRECORD_FILE + "_test"))
-    os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
 
 def read():
...
@@ -90,8 +109,7 @@ def read():
     assert count == 4
     reader.close()
 
-def test_cifar10_to_mindrecord_illegal_file_name():
+def test_cifar10_to_mindrecord_illegal_file_name(fixture_file):
     """
     test transform cifar10 dataset to mindrecord
     when file name contains illegal character.
...
@@ -101,8 +119,7 @@ def test_cifar10_to_mindrecord_illegal_file_name():
     cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
     cifar10_transformer.transform()
 
-def test_cifar10_to_mindrecord_filename_start_with_space():
+def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file):
     """
     test transform cifar10 dataset to mindrecord
     when file name starts with space.
...
@@ -113,8 +130,7 @@ def test_cifar10_to_mindrecord_filename_start_with_space():
     cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
     cifar10_transformer.transform()
 
-def test_cifar10_to_mindrecord_filename_contain_space():
+def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file):
     """
     test transform cifar10 dataset to mindrecord
     when file name contains space.
...
@@ -124,14 +140,8 @@ def test_cifar10_to_mindrecord_filename_contain_space():
     cifar10_transformer.transform()
     assert os.path.exists(filename)
     assert os.path.exists(filename + "_test")
-    os.remove("{}".format(filename))
-    os.remove("{}.db".format(filename))
-    os.remove("{}".format(filename + "_test"))
-    os.remove("{}.db".format(filename + "_test"))
 
-def test_cifar10_to_mindrecord_directory():
+def test_cifar10_to_mindrecord_directory(fixture_file):
     """
     test transform cifar10 dataset to mindrecord
     when destination path is directory.
...
tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py
...
@@ -25,6 +25,26 @@ IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images"
 MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord"
 PARTITION_NUMBER = 4
 
+@pytest.fixture
+def fixture_file():
+    """add/remove file"""
+    def remove_one_file(x):
+        if os.path.exists(x):
+            os.remove(x)
+
+    def remove_file():
+        x = MINDRECORD_FILE
+        remove_one_file(x)
+        x = MINDRECORD_FILE + ".db"
+        remove_one_file(x)
+        for i in range(PARTITION_NUMBER):
+            x = MINDRECORD_FILE + str(i)
+            remove_one_file(x)
+            x = MINDRECORD_FILE + str(i) + ".db"
+            remove_one_file(x)
+
+    remove_file()
+    yield "yield_fixture_data"
+    remove_file()
+
 def read(filename):
     """test file reade"""
...
@@ -38,8 +58,7 @@ def read(filename):
     assert count == 20
     reader.close()
 
-def test_imagenet_to_mindrecord():
+def test_imagenet_to_mindrecord(fixture_file):
     """test transform imagenet dataset to mindrecord."""
     imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
                                         MINDRECORD_FILE, PARTITION_NUMBER)
...
@@ -48,12 +67,8 @@ def test_imagenet_to_mindrecord():
         assert os.path.exists(MINDRECORD_FILE + str(i))
         assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
     read(MINDRECORD_FILE + "0")
-    for i in range(PARTITION_NUMBER):
-        os.remove(MINDRECORD_FILE + str(i))
-        os.remove(MINDRECORD_FILE + str(i) + ".db")
 
-def test_imagenet_to_mindrecord_default_partition_number():
+def test_imagenet_to_mindrecord_default_partition_number(fixture_file):
     """
     test transform imagenet dataset to mindrecord
     when partition number is default.
...
@@ -64,11 +79,8 @@ def test_imagenet_to_mindrecord_default_partition_number():
     assert os.path.exists(MINDRECORD_FILE)
     assert os.path.exists(MINDRECORD_FILE + ".db")
     read(MINDRECORD_FILE)
-    os.remove("{}".format(MINDRECORD_FILE))
-    os.remove("{}.db".format(MINDRECORD_FILE))
 
-def test_imagenet_to_mindrecord_partition_number_0():
+def test_imagenet_to_mindrecord_partition_number_0(fixture_file):
     """
     test transform imagenet dataset to mindrecord
     when partition number is 0.
...
@@ -79,8 +91,7 @@ def test_imagenet_to_mindrecord_partition_number_0():
                                             MINDRECORD_FILE, 0)
         imagenet_transformer.transform()
 
-def test_imagenet_to_mindrecord_partition_number_none():
+def test_imagenet_to_mindrecord_partition_number_none(fixture_file):
     """
     test transform imagenet dataset to mindrecord
     when partition number is none.
...
@@ -92,8 +103,7 @@ def test_imagenet_to_mindrecord_partition_number_none():
                                             MINDRECORD_FILE, None)
         imagenet_transformer.transform()
 
-def test_imagenet_to_mindrecord_illegal_filename():
+def test_imagenet_to_mindrecord_illegal_filename(fixture_file):
     """
     test transform imagenet dataset to mindrecord
     when file name contains illegal character.
...
tests/ut/python/mindrecord/test_mindrecord_exception.py
...
@@ -26,6 +26,34 @@ CV_FILE_NAME = "./imagenet.mindrecord"
 NLP_FILE_NAME = "./aclImdb.mindrecord"
 FILES_NUM = 4
 
+def remove_one_file(x):
+    if os.path.exists(x):
+        os.remove(x)
+
+def remove_file(file_name):
+    x = file_name
+    remove_one_file(x)
+    x = file_name + ".db"
+    remove_one_file(x)
+    for i in range(FILES_NUM):
+        x = file_name + str(i)
+        remove_one_file(x)
+        x = file_name + str(i) + ".db"
+        remove_one_file(x)
+
+@pytest.fixture
+def fixture_cv_file():
+    """add/remove file"""
+    remove_file(CV_FILE_NAME)
+    yield "yield_fixture_data"
+    remove_file(CV_FILE_NAME)
+
+@pytest.fixture
+def fixture_nlp_file():
+    """add/remove file"""
+    remove_file(NLP_FILE_NAME)
+    yield "yield_fixture_data"
+    remove_file(NLP_FILE_NAME)
+
 def test_cv_file_writer_shard_num_none():
     """test cv file writer when shard num is None."""
...
@@ -83,8 +111,7 @@ def test_lack_partition_and_db():
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
 
-def test_lack_db():
+def test_lack_db(fixture_cv_file):
     """test file reader when db file does not exist."""
     create_cv_mindrecord(1)
     os.remove("{}.db".format(CV_FILE_NAME))
...
@@ -94,10 +121,8 @@ def test_lack_db():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    os.remove(CV_FILE_NAME)
 
-def test_lack_some_partition_and_db():
+def test_lack_some_partition_and_db(fixture_cv_file):
     """test file reader when some partition and db do not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...
@@ -110,16 +135,8 @@ def test_lack_some_partition_and_db():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(FILES_NUM)]
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-def test_lack_some_partition_first():
+def test_lack_some_partition_first(fixture_cv_file):
     """test file reader when first partition does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...
@@ -131,14 +148,8 @@ def test_lack_some_partition_first():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-def test_lack_some_partition_middle():
+def test_lack_some_partition_middle(fixture_cv_file):
     """test file reader when some partition does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...
@@ -150,14 +161,8 @@ def test_lack_some_partition_middle():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-def test_lack_some_partition_last():
+def test_lack_some_partition_last(fixture_cv_file):
     """test file reader when last partition does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...
@@ -169,14 +174,8 @@ def test_lack_some_partition_last():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-def test_mindpage_lack_some_partition():
+def test_mindpage_lack_some_partition(fixture_cv_file):
     """test page reader when some partition does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...
@@ -187,14 +186,8 @@ def test_mindpage_lack_some_partition():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
-def test_lack_some_db():
+def test_lack_some_db(fixture_cv_file):
     """test file reader when some db does not exist."""
     create_cv_mindrecord(4)
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...
@@ -206,11 +199,6 @@ def test_lack_some_db():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    for x in paths:
-        if os.path.exists("{}".format(x)):
-            os.remove("{}".format(x))
-        if os.path.exists("{}.db".format(x)):
-            os.remove("{}.db".format(x))
 
 def test_invalid_mindrecord():
...
@@ -225,8 +213,7 @@ def test_invalid_mindrecord():
            in str(err.value)
     os.remove(CV_FILE_NAME)
 
-def test_invalid_db():
+def test_invalid_db(fixture_cv_file):
     """test file reader when the content of db is illegal."""
     create_cv_mindrecord(1)
     os.remove("imagenet.mindrecord.db")
...
@@ -237,11 +224,8 @@ def test_invalid_db():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    os.remove("imagenet.mindrecord")
-    os.remove("imagenet.mindrecord.db")
 
-def test_overwrite_invalid_mindrecord():
+def test_overwrite_invalid_mindrecord(fixture_cv_file):
     """test file writer when overwrite invalid mindreocrd file."""
     with open(CV_FILE_NAME, 'w') as f:
         f.write('just for test')
...
@@ -250,10 +234,8 @@ def test_overwrite_invalid_mindrecord():
     assert '[MRMOpenError]: error_code: 1347690596, ' \
            'error_msg: MindRecord File could not open successfully.' \
            in str(err.value)
-    os.remove(CV_FILE_NAME)
 
-def test_overwrite_invalid_db():
+def test_overwrite_invalid_db(fixture_cv_file):
     """test file writer when overwrite invalid db file."""
     with open('imagenet.mindrecord.db', 'w') as f:
         f.write('just for test')
...
@@ -261,11 +243,8 @@ def test_overwrite_invalid_db():
         create_cv_mindrecord(1)
     assert '[MRMGenerateIndexError]: error_code: 1347690612, ' \
            'error_msg: Failed to generate index.' in str(err.value)
-    os.remove("imagenet.mindrecord")
-    os.remove("imagenet.mindrecord.db")
 
-def test_read_after_close():
+def test_read_after_close(fixture_cv_file):
     """test file reader when close read."""
     create_cv_mindrecord(1)
     reader = FileReader(CV_FILE_NAME)
...
@@ -275,11 +254,8 @@ def test_read_after_close():
         count = count + 1
         logger.info("#item{}: {}".format(index, x))
     assert count == 0
-    os.remove(CV_FILE_NAME)
-    os.remove("{}.db".format(CV_FILE_NAME))
 
-def test_file_read_after_read():
+def test_file_read_after_read(fixture_cv_file):
     """test file reader when finish read."""
     create_cv_mindrecord(1)
     reader = FileReader(CV_FILE_NAME)
...
@@ -295,8 +271,6 @@ def test_file_read_after_read():
         cnt = cnt + 1
         logger.info("#item{}: {}".format(index, x))
     assert cnt == 0
-    os.remove(CV_FILE_NAME)
-    os.remove("{}.db".format(CV_FILE_NAME))
 
 def test_cv_file_writer_shard_num_greater_than_1000():
...
@@ -312,8 +286,7 @@ def test_add_index_without_add_schema():
         fw.add_index(["label"])
     assert 'Failed to get meta info' in str(err.value)
 
-def test_mindpage_pageno_pagesize_not_int():
+def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
     """test page reader when some partition does not exist."""
     create_cv_mindrecord(4)
     reader = MindPage(CV_FILE_NAME + "0")
...
@@ -342,14 +315,8 @@ def test_mindpage_pageno_pagesize_not_int():
     with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
         reader.read_at_page_by_id(99999, 0, 1)
-    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(FILES_NUM)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
 
-def test_mindpage_filename_not_exist():
+def test_mindpage_filename_not_exist(fixture_cv_file):
     """test page reader when some partition does not exist."""
     create_cv_mindrecord(4)
     reader = MindPage(CV_FILE_NAME + "0")
...
@@ -374,6 +341,3 @@ def test_mindpage_filename_not_exist():
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
              for x in range(FILES_NUM)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
tests/ut/python/mindrecord/test_mnist_to_mr.py
...
@@ -14,6 +14,7 @@
 """test mnist to mindrecord tool"""
 import cv2
 import gzip
+import pytest
 import numpy as np
 import os
...
@@ -27,6 +28,34 @@ PARTITION_NUM = 4
 IMAGE_SIZE = 28
 NUM_CHANNELS = 1
 
+@pytest.fixture
+def fixture_file():
+    """add/remove file"""
+    def remove_one_file(x):
+        if os.path.exists(x):
+            os.remove(x)
+
+    def remove_file():
+        x = "mnist_train.mindrecord"
+        remove_one_file(x)
+        x = "mnist_train.mindrecord.db"
+        remove_one_file(x)
+        x = "mnist_test.mindrecord"
+        remove_one_file(x)
+        x = "mnist_test.mindrecord.db"
+        remove_one_file(x)
+        for i in range(PARTITION_NUM):
+            x = "mnist_train.mindrecord" + str(i)
+            remove_one_file(x)
+            x = "mnist_train.mindrecord" + str(i) + ".db"
+            remove_one_file(x)
+            x = "mnist_test.mindrecord" + str(i)
+            remove_one_file(x)
+            x = "mnist_test.mindrecord" + str(i) + ".db"
+            remove_one_file(x)
+
+    remove_file()
+    yield "yield_fixture_data"
+    remove_file()
+
 def read(train_name, test_name):
     """test file reader"""
...
@@ -51,7 +80,7 @@ def read(train_name, test_name):
     reader.close()
 
-def test_mnist_to_mindrecord():
+def test_mnist_to_mindrecord(fixture_file):
     """test transform mnist dataset to mindrecord."""
     mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
     mnist_transformer.transform()
...
@@ -60,13 +89,7 @@ def test_mnist_to_mindrecord():
     read("mnist_train.mindrecord", "mnist_test.mindrecord")
-    os.remove("{}".format("mnist_train.mindrecord"))
-    os.remove("{}.db".format("mnist_train.mindrecord"))
-    os.remove("{}".format("mnist_test.mindrecord"))
-    os.remove("{}.db".format("mnist_test.mindrecord"))
 
-def test_mnist_to_mindrecord_compare_data():
+def test_mnist_to_mindrecord_compare_data(fixture_file):
     """test transform mnist dataset to mindrecord and compare data."""
     mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
     mnist_transformer.transform()
...
@@ -121,21 +144,10 @@ def test_mnist_to_mindrecord_compare_data():
         assert np.array(x['label']) == label
     reader.close()
-    os.remove("{}".format("mnist_train.mindrecord"))
-    os.remove("{}.db".format("mnist_train.mindrecord"))
-    os.remove("{}".format("mnist_test.mindrecord"))
-    os.remove("{}.db".format("mnist_test.mindrecord"))
 
-def test_mnist_to_mindrecord_multi_partition():
+def test_mnist_to_mindrecord_multi_partition(fixture_file):
     """test transform mnist dataset to multiple mindrecord files."""
     mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
     mnist_transformer.transform()
 
     read("mnist_train.mindrecord0", "mnist_test.mindrecord0")
-    for i in range(PARTITION_NUM):
-        os.remove("{}".format("mnist_train.mindrecord" + str(i)))
-        os.remove("{}.db".format("mnist_train.mindrecord" + str(i)))
-        os.remove("{}".format("mnist_test.mindrecord" + str(i)))
-        os.remove("{}.db".format("mnist_test.mindrecord" + str(i)))