magicwindyyd/mindspore (fork of MindSpore/mindspore)
Commit 526770e0
Authored Jul 16, 2020 by mindspore-ci-bot; committed via Gitee on Jul 16, 2020.
!3049 [dataset] add save operator in dataset
Merge pull request !3049 from liyong126/dataset_save_op
Parents: 6335598f, bc676fe2
Showing 16 changed files with 828 additions and 9 deletions (+828 -9).
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc                          +226  -0
mindspore/ccsrc/minddata/dataset/api/de_pipeline.h                            +16  -0
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc                        +5  -1
mindspore/ccsrc/minddata/dataset/core/tensor.h                                 +5  -5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc     +1  -1
mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h               +4  -0
mindspore/ccsrc/minddata/mindrecord/include/shard_header.h                     +4  -0
mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h            +2  -0
mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h                     +7  -0
mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc               +16  -0
mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc                        +52  -0
mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc                      +30  -0
mindspore/dataset/engine/datasets.py                                          +30  -2
mindspore/dataset/engine/iterators.py                                         +23  -0
mindspore/dataset/engine/validators.py                                        +17  -0
tests/ut/python/dataset/test_save_op.py                                      +390  -0
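
The user-facing entry point added by this change is Dataset.save(). A minimal round-trip sketch, modeled on the unit tests at the bottom of this diff (the /tmp path is illustrative):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(10):
        yield (np.array([i], dtype=np.int32),)

# Build a simple pipeline and persist it in MindRecord format.
d1 = ds.GeneratorDataset(gen, ["data"], shuffle=False)
d1.save("/tmp/auto.mindrecord")  # defaults: num_files=1, file_type='mindrecord'

# Read the saved file back and verify the contents round-trip.
d2 = ds.MindDataset(dataset_file="/tmp/auto.mindrecord", shuffle=False)
for i, item in enumerate(d2.create_dict_iterator()):
    assert np.array_equal(item["data"], np.array([i]))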
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
...
...
@@ -42,11 +42,17 @@
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_category.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_header.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "minddata/mindrecord/include/shard_writer.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
using json = nlohmann::json;
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *,
                                         std::shared_ptr<DatasetOp> *);

static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
...
...
@@ -355,6 +361,226 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
  return Status::OK();
}

Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type) {
  Status s;
  auto mr_header = std::make_shared<mindrecord::ShardHeader>();
  auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
  std::vector<std::string> blob_fields;
  uint64_t mr_schema_id = 0;
  if (mindrecord::SUCCESS != mindrecord::ShardWriter::initialize(&mr_writer, file_names)) {
    RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter.");
  }

  TensorRow row;
  std::unordered_map<std::string, int32_t> column_name_id_map = iterator_->GetColumnNameMap();  // map of column name, id
  bool first_loop = true;  // build schema in first loop
  do {
    json row_raw_data;
    std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
    {
      py::gil_scoped_release gil_release;  // release the GIL while fetching the next row
      s = iterator_->FetchNextTensorRow(&row);
    }
    RETURN_IF_NOT_OK(s);
    if (row.empty()) break;
    if (first_loop) {
      json mr_json;
      std::vector<std::string> index_fields;
      s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
      RETURN_IF_NOT_OK(s);
      mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
      mr_writer->SetShardHeader(mr_header);
      first_loop = false;
    }
    // construct data
    if (!row.empty()) {  // write data
      s = FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data);
      RETURN_IF_NOT_OK(s);
      std::shared_ptr<std::vector<uint8_t>> output_bin_data;
      mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data);
      std::map<std::uint64_t, std::vector<json>> raw_data;
      raw_data.insert(std::pair<uint64_t, std::vector<json>>(mr_schema_id, std::vector<json>{row_raw_data}));
      std::vector<std::vector<uint8_t>> bin_data;
      if (nullptr != output_bin_data) {
        bin_data.emplace_back(*output_bin_data);
      }
      mr_writer->WriteRawData(raw_data, bin_data);
    }
  } while (!row.empty());

  mr_writer->Commit();
  mindrecord::ShardIndexGenerator::finalize(file_names);
  return Status::OK();
}
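
SaveDataset drains the pipeline row by row: the first fetched row fixes the MindRecord schema, then each row is split into JSON raw data and binary blob data and streamed to the ShardWriter. A rough Python analogue of that flow using the public FileWriter API (the one-line schema inference is simplified and illustrative, not the internal code path):

from mindspore.mindrecord import FileWriter

def save_like_savedataset(dataset, file_name):
    # Sketch only: derive a schema from the first row, then stream every row.
    writer = None
    for row in dataset.create_dict_iterator():
        if writer is None:
            # Simplified schema inference: treat every column as a 1-D array.
            schema = {k: {"type": str(v.dtype), "shape": [-1]} for k, v in row.items()}
            writer = FileWriter(file_name, 1)
            writer.add_schema(schema, "inferred")
        writer.write_raw_data([dict(row)])
    if writer is not None:
        writer.commit()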
Status DEPipeline::FetchDataFromTensorRow(const TensorRow &row,
                                          const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                          json *row_raw_data,
                                          std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
  if (row_raw_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("error: row raw data is NULL.");
  }
  if (row_bin_data == nullptr) {
    RETURN_STATUS_UNEXPECTED("error: row bin data is NULL.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("Error: column not found");
  }
  Status s;
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();

    std::unique_ptr<std::vector<uint8_t>> data_ptr;
    if (column_type == DataType::DE_INT8) {
      std::unique_ptr<int32_t> data;
      std::unique_ptr<int8_t> dummy;
      s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
      RETURN_IF_NOT_OK(s);
      if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
    } else if (column_type == DataType::DE_INT16) {
      std::unique_ptr<int32_t> data;
      std::unique_ptr<int16_t> dummy;
      s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
      RETURN_IF_NOT_OK(s);
      if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
    } else if (column_type == DataType::DE_UINT16) {
      std::unique_ptr<int32_t> data;
      std::unique_ptr<uint16_t> dummy;
      s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
      RETURN_IF_NOT_OK(s);
      if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
    } else if (column_type == DataType::DE_UINT8) {
      std::unique_ptr<uint8_t> data, dummy;
      s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
      RETURN_IF_NOT_OK(s);
      if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
    } else if (column_type == DataType::DE_INT32) {
      std::unique_ptr<int32_t> data, dummy;
      s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
      RETURN_IF_NOT_OK(s);
      if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
    } else if (column_type == DataType::DE_UINT32) {
      std::unique_ptr<int64_t> data;
      std::unique_ptr<uint32_t> dummy;
      s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
      RETURN_IF_NOT_OK(s);
      if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
    } else if (column_type == DataType::DE_INT64) {
      std::unique_ptr<int64_t> data, dummy;
      s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
      RETURN_IF_NOT_OK(s);
      if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
    } else if (column_type == DataType::DE_FLOAT32) {
      std::unique_ptr<float> data, dummy;
      s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
      RETURN_IF_NOT_OK(s);
      if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
    } else if (column_type == DataType::DE_FLOAT64) {
      std::unique_ptr<double> data, dummy;
      s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
      RETURN_IF_NOT_OK(s);
      if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
    } else if (column_type == DataType::DE_STRING) {
      auto buffer = tensor->GetStringsBuffer();
      std::string ss(reinterpret_cast<const char *>(buffer));  // assume scalar string tensor
      (*row_raw_data)[column_name] = std::move(ss);
      continue;
    } else {
      RETURN_STATUS_UNEXPECTED("Got unexpected type when casting data.");
    }
    RETURN_IF_NOT_OK(s);
    if (data_ptr != nullptr) {
      (*row_bin_data)[column_name] = std::move(data_ptr);
    }
  }
  return Status::OK();
}
template <typename T, typename S>
Status DEPipeline::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
                                   std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
                                   std::unique_ptr<S> *s, bool need_convert) {
  if (nullptr == src) {
    RETURN_STATUS_UNEXPECTED("Error: buffer of Tensor is NULL.");
  }
  *data_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(T));
  if (need_convert) {
    auto tmp_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(S));
    std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin());
    auto s_ptr = reinterpret_cast<S *>(&(*(tmp_ptr->begin())));
    auto el = std::make_unique<T>();
    for (uint32_t i = 0; i < num_of_elements; ++i) {
      *el = *(s_ptr + i);
      auto t_ptr = reinterpret_cast<uint8_t *>(el.get());
      for (uint32_t j = 0; j < sizeof(T); ++j) {
        *((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j);
      }
    }
  } else {
    std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin());
  }
  if (shape.empty()) {
    *data = std::make_unique<T>();
    auto t_ptr = reinterpret_cast<uint8_t *>((*data).get());
    for (uint32_t i = 0; i < sizeof(T); ++i) {
      *(t_ptr + i) = *((*data_ptr)->begin() + i);
    }
  }
  return Status::OK();
}
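
TransfromTensor copies a tensor buffer into a byte vector, first widening the element type when MindRecord has no native representation for it (the need_convert path). The same widening expressed in NumPy terms (illustrative only):

import numpy as np

WIDEN = {np.dtype('int8'): np.int32, np.dtype('int16'): np.int32,
         np.dtype('uint16'): np.int32, np.dtype('uint32'): np.int64}

def widen_for_mindrecord(arr):
    # Mirror of the need_convert path: copy element-wise into the wider type.
    target = WIDEN.get(arr.dtype)
    return arr.astype(target) if target is not None else arr

assert widen_for_mindrecord(np.array([1, 2], dtype=np.int8)).dtype == np.int32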
Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                          const TensorRow &row, json *schema, std::vector<std::string> *index_fields) {
  if (schema == nullptr) {
    RETURN_STATUS_UNEXPECTED("error: schema is NULL.");
  }
  if (index_fields == nullptr) {
    RETURN_STATUS_UNEXPECTED("error: index fields is NULL.");
  }
  if (column_name_id_map.empty()) {
    RETURN_STATUS_UNEXPECTED("Error: column not found.");
  }
  for (auto &col : column_name_id_map) {
    auto idx = col.second;
    auto column_name = col.first;
    auto &tensor = row[idx];
    auto column_type = tensor->type();
    auto column_shape = tensor->shape();

    std::string mr_type;
    auto shapes = column_shape.AsVector();
    std::vector<int> mr_shape(shapes.begin(), shapes.end());
    std::string el = column_type.ToString();
    if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
      std::string err_msg("Error: can not support data type: " + el);
      RETURN_STATUS_UNEXPECTED(err_msg);
    } else {
      mr_type = mindrecord::kTypesMap.at(el);
    }
    if (mr_shape.empty()) {
      if (mr_type == "bytes") {  // map to int32 when bytes without shape.
        mr_type = "int32";
      }
      (*schema)[column_name] = {{"type", mr_type}};
    } else {
      if (mr_type == "string") {  // mindrecord can not support string with shape.
        std::string err_msg("Error: mindrecord can not support multi-dimensional string tensor.");
        RETURN_STATUS_UNEXPECTED(err_msg);
      }
      if (mr_type == "bytes") {  // ignore shape of bytes in mindrecord
        (*schema)[column_name] = {{"type", mr_type}};
      } else {
        (*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}};
      }
    }
    if (mr_type == "bytes" || !mr_shape.empty()) continue;
    index_fields->emplace_back(column_name);  // candidate of index fields
  }
  return Status::OK();
}
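
A column's schema entry thus depends on both the mapped type and its shape: scalars get only a "type", shaped columns also get a "shape", multi-dimensional strings are rejected, and only scalar non-bytes columns become index-field candidates. A small Python mirror of that decision table (hypothetical helper; kTypesMap itself is defined in shard_utils.h further down):

TYPES_MAP = {"bool": "int32", "int8": "int32", "uint8": "bytes", "int16": "int32",
             "uint16": "int32", "int32": "int32", "uint32": "int64", "int64": "int64",
             "float16": "float32", "float32": "float32", "float64": "float64",
             "string": "string"}

def schema_entry(de_type, shape):
    # Returns (schema entry, is_index_candidate) for one column.
    mr_type = TYPES_MAP[de_type]
    if not shape:                      # scalar column
        if mr_type == "bytes":         # bytes without a shape map to int32
            mr_type = "int32"
        return {"type": mr_type}, True
    if mr_type == "string":
        raise ValueError("mindrecord can not support multi-dimensional string tensor.")
    if mr_type == "bytes":             # the shape of bytes is ignored
        return {"type": mr_type}, False
    return {"type": mr_type, "shape": list(shape)}, False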
Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
                                               std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
                                               int num_padded) {
...
...
mindspore/ccsrc/minddata/dataset/api/de_pipeline.h
...
...
@@ -17,6 +17,7 @@
#define DATASET_API_DE_PIPELINE_H_
#include <iostream>
#include <map>
#include <memory>
#include <stack>
#include <string>
...
...
@@ -33,6 +34,7 @@
namespace py = pybind11;

namespace mindspore {
namespace dataset {
using json = nlohmann::json;
using DsOpPtr = std::shared_ptr<DatasetOp>;

class CacheClient;
...
...
@@ -100,6 +102,8 @@ class DEPipeline {
  Status GetOutputTypes(py::list *output);

  Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);

  int GetDatasetSize() const;

  int GetBatchSize() const;
...
...
@@ -110,6 +114,18 @@ class DEPipeline {
  Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

  template <typename T, typename S>
  Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
                         std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
                         std::unique_ptr<S> *s, bool need_convert = false);

  Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                const TensorRow &row, json *schema, std::vector<std::string> *index_fields);

  Status FetchDataFromTensorRow(const TensorRow &row,
                                const std::unordered_map<std::string, int32_t> &column_name_id_map,
                                json *row_raw_data,
                                std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);

  Status BuildMindrecordSamplerChain(const py::handle &handle,
                                     std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
                                     int num_padded);
...
...
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
...
...
@@ -184,7 +184,11 @@ void bindDEPipeline(py::module *m) {
       .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
       .def("GetBatchSize", &DEPipeline::GetBatchSize)
       .def("GetNumClasses", &DEPipeline::GetNumClasses)
-      .def("GetRepeatCount", &DEPipeline::GetRepeatCount);
+      .def("GetRepeatCount", &DEPipeline::GetRepeatCount)
+      .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names,
+                             const std::string &file_type) {
+        THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
+        return true;
+      });
 }

 void bindDatasetOps(py::module *m) {
   (void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
...
...
mindspore/ccsrc/minddata/dataset/core/tensor.h
...
...
@@ -312,6 +312,11 @@ class Tensor {
   // @return const unsigned char*
   const unsigned char *GetBuffer() const;

+  // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the
+  // tensor's type is a string, otherwise undefined address would be returned.
+  // @return address of the first string of the tensor.
+  uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }

   // Getter of the type
   // @return
   DataType type() const { return type_; }
...
...
@@ -643,11 +648,6 @@ class Tensor {
   // @return length of the string
   Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const;

-  // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the
-  // tensor's type is a string, otherwise undefined address would be returned.
-  // @return address of the first string of the tensor.
-  uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }

   // all access to shape_ should be via shape
   TensorShape shape_;
   // data type of tensor
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc
...
...
@@ -215,7 +215,7 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
   // Call the super class for displaying any common detailed info
   ParallelOp::Print(out, show_all);
   // Then show any custom derived-internal stuff
-  out << "\nDataset file : ";
+  out << "\nDataset file : ";
   for (auto &file : dataset_file_) {
     out << file << " ";
   }
...
...
mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h
...
...
@@ -137,6 +137,10 @@ const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "
// number field list
const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};

const std::unordered_map<std::string, std::string> kTypesMap = {
  {"bool", "int32"},      {"int8", "int32"},      {"uint8", "bytes"},     {"int16", "int32"},
  {"uint16", "int32"},    {"int32", "int32"},     {"uint32", "int64"},    {"int64", "int64"},
  {"float16", "float32"}, {"float32", "float32"}, {"float64", "float64"}, {"string", "string"}};

/// \brief split a string using a character
/// \param[in] field target string
/// \param[in] separator a character for splitting
...
...
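
One practical consequence of kTypesMap: dtypes without a native MindRecord type are widened on save, so the round-tripped dtype can differ from the input. A sketch of the expected effect (paths illustrative; int8 -> int32 per the table above):

import numpy as np
import mindspore.dataset as ds

def gen_i8():
    yield (np.array([1, 2, 3], dtype=np.int8),)

d = ds.GeneratorDataset(gen_i8, ["data"], shuffle=False)
d.save("/tmp/i8.mindrecord")
d2 = ds.MindDataset(dataset_file="/tmp/i8.mindrecord", shuffle=False)
for item in d2.create_dict_iterator():
    print(item["data"].dtype)  # expected int32, widened from int8 per kTypesMap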
mindspore/ccsrc/minddata/mindrecord/include/shard_header.h
...
...
@@ -124,6 +124,10 @@ class ShardHeader {
  MSRStatus FileToPages(const std::string dump_file_name);

  static MSRStatus initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
                              const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
                              uint64_t &schema_id);

 private:
  MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h
...
...
@@ -57,6 +57,8 @@ class ShardIndexGenerator {
  /// \brief create databases for indexes
  MSRStatus WriteToDatabase();

  static MSRStatus finalize(const std::vector<std::string> file_names);

 private:
  static int Callback(void *not_used, int argc, char **argv, char **az_col_name);
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h
...
...
@@ -108,6 +108,13 @@ class ShardWriter {
                         std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
                         bool parallel_writer = false);

  MSRStatus MergeBlobData(const std::vector<string> &blob_fields,
                          const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
                          std::shared_ptr<std::vector<uint8_t>> *output);

  static MSRStatus initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
                              const std::vector<std::string> &file_names);

 private:
  /// \brief write shard header data to disk
  MSRStatus WriteShardHeader();
...
...
mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc
...
...
@@ -622,5 +622,21 @@ void ShardIndexGenerator::DatabaseWriter() {
    shard_no = task_++;
  }
}

MSRStatus ShardIndexGenerator::finalize(const std::vector<std::string> file_names) {
  if (file_names.empty()) {
    MS_LOG(ERROR) << "Mindrecord files is empty.";
    return FAILED;
  }
  ShardIndexGenerator sg{file_names[0]};
  if (SUCCESS != sg.Build()) {
    MS_LOG(ERROR) << "Failed to build index generator.";
    return FAILED;
  }
  if (SUCCESS != sg.WriteToDatabase()) {
    MS_LOG(ERROR) << "Failed to write to database.";
    return FAILED;
  }
  return SUCCESS;
}
}  // namespace mindrecord
}  // namespace mindspore
mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc
...
...
@@ -637,6 +637,42 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
  *row_count = std::get<2>(v);
  return SUCCESS;
}

MSRStatus ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
                                     const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
                                     std::shared_ptr<std::vector<uint8_t>> *output) {
  if (blob_fields.empty()) {
    return SUCCESS;
  }
  if (blob_fields.size() == 1) {
    auto &blob = row_bin_data.at(blob_fields[0]);
    auto blob_size = blob->size();
    *output = std::make_shared<std::vector<uint8_t>>(blob_size);
    std::copy(blob->begin(), blob->end(), (*output)->begin());
  } else {
    size_t output_size = 0;
    for (auto &field : blob_fields) {
      output_size += row_bin_data.at(field)->size();
    }
    output_size += blob_fields.size() * sizeof(uint64_t);
    *output = std::make_shared<std::vector<uint8_t>>(output_size);
    std::vector<uint8_t> buf(sizeof(uint64_t), 0);
    size_t idx = 0;
    for (auto &field : blob_fields) {
      auto &blob = row_bin_data.at(field);
      uint64_t blob_size = blob->size();
      // big endian
      for (size_t i = 0; i < buf.size(); ++i) {
        buf[buf.size() - 1 - i] = std::numeric_limits<uint8_t>::max() & blob_size;
        blob_size >>= 8u;
      }
      std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
      idx += buf.size();
      std::copy(blob->begin(), blob->end(), (*output)->begin() + idx);
      idx += blob->size();
    }
  }
  return SUCCESS;
}

MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                    std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
...
...
@@ -1250,5 +1286,21 @@ void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &la
    last_blob_page = page.first;
  }
}

MSRStatus ShardWriter::initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
                                  const std::vector<std::string> &file_names) {
  if (nullptr == writer_ptr) {
    MS_LOG(ERROR) << "ShardWriter pointer is NULL.";
    return FAILED;
  }
  auto res = (*writer_ptr)->Open(file_names, false);
  if (SUCCESS != res) {
    MS_LOG(ERROR) << "Failed to open mindrecord files to writer.";
    return FAILED;
  }
  (*writer_ptr)->SetHeaderSize(1 << 24);
  (*writer_ptr)->SetPageSize(1 << 25);
  return SUCCESS;
}
}  // namespace mindrecord
}  // namespace mindspore
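
When a row has more than one blob field, MergeBlobData prefixes each blob with its size as an 8-byte big-endian integer and concatenates them; a single blob is stored without a prefix. A pure-Python sketch of that layout (illustrative helper, not library API):

import struct

def merge_blobs(blobs):
    if len(blobs) == 1:
        return bytes(blobs[0])               # single blob: no length prefix
    out = bytearray()
    for blob in blobs:
        out += struct.pack(">Q", len(blob))  # big-endian uint64 size
        out += blob
    return bytes(out)

# merge_blobs([b"abc", b"de"]) -> 8-byte size 3, b"abc", 8-byte size 2, b"de"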
mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc
...
...
@@ -721,5 +721,35 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
  page_in_handle.close();
  return SUCCESS;
}

MSRStatus ShardHeader::initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
                                  const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
                                  uint64_t &schema_id) {
  if (nullptr == header_ptr) {
    MS_LOG(ERROR) << "ShardHeader pointer is NULL.";
    return FAILED;
  }
  auto schema_ptr = Schema::Build("mindrecord", schema);
  if (nullptr == schema_ptr) {
    MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema.";
    return FAILED;
  }
  schema_id = (*header_ptr)->AddSchema(schema_ptr);
  // create index
  std::vector<std::pair<uint64_t, std::string>> id_index_fields;
  if (!index_fields.empty()) {
    for (auto &el : index_fields) {
      id_index_fields.emplace_back(schema_id, el);
    }
    if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) {
      MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index.";
      return FAILED;
    }
  }

  auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
  blob_fields = build_schema_ptr->GetBlobFields();
  return SUCCESS;
}
}  // namespace mindrecord
}  // namespace mindspore
mindspore/dataset/engine/datasets.py
...
...
@@ -38,13 +38,13 @@ from mindspore._c_expression import typing
 from mindspore import log as logger
 from . import samplers
-from .iterators import DictIterator, TupleIterator, DummyIterator
+from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
 from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, check_numpyslicesdataset, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
     check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
-    check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
+    check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

 try:
...
...
@@ -1044,6 +1044,34 @@ class Dataset:
        return TransferDataset(self, queue_name, device_id, device_type, num_batch)

    @check_save
    def save(self, file_name, num_files=1, file_type='mindrecord'):
        """
        Save the dynamic data processed by the dataset pipeline in a common dataset format; supported: 'mindrecord'.

        Note:
            1. To save the samples in order, set the dataset's shuffle to False and num_files to 1.
            2. Before calling the function, do not use the batch or repeat operator, or data augmentation
               operators with a random attribute in a map operator.
            3. MindRecord does not support np.uint64, multi-dimensional np.uint8 (the dimension is dropped),
               or multi-dimensional strings.

        Args:
            file_name (str): Path to the dataset file.
            num_files (int, optional): Number of dataset files (default=1).
            file_type (str, optional): Dataset format (default='mindrecord').
        """
        if num_files == 1:
            file_names = [file_name]
        else:
            suffix = len(str(num_files - 1))
            file_names = ["{}{}".format(file_name, str(x).rjust(suffix, '0'))
                          for x in range(num_files)]

        return SaveOp(self).save(file_names, file_type)
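
With num_files > 1, each shard gets the base path plus a zero-padded index wide enough for num_files - 1; for example (illustrative values):

file_name, num_files = "out.mindrecord", 12
suffix = len(str(num_files - 1))  # 2 digits for 12 files
file_names = ["{}{}".format(file_name, str(x).rjust(suffix, '0')) for x in range(num_files)]
print(file_names[0], file_names[-1])  # out.mindrecord00 out.mindrecord11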
    def create_tuple_iterator(self, columns=None):
        """
        Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data.
...
...
mindspore/dataset/engine/iterators.py
...
...
@@ -173,6 +173,7 @@ class Iterator:
# Convert python node into C node and add to C layer execution tree in postorder traversal.
    def __convert_node_postorder(self, node):
        self.check_node_type(node)
        op_type = self.__get_dataset_type(node)
        c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
...
...
@@ -224,6 +225,10 @@ class Iterator:
        self._index += 1
        return data

    @abstractmethod
    def check_node_type(self, node):
        pass

    def get_output_shapes(self):
        return [t for t in self.depipeline.GetOutputShapes()]
...
...
@@ -245,11 +250,27 @@ class Iterator:
    def __deepcopy__(self, memo):
        return self


class SaveOp(Iterator):
    """
    The derived class of Iterator for the save operator.
    """
    def get_next(self):
        pass

    def check_node_type(self, node):
        if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
            logger.warning("Used shuffle, repeat, batch before save operator.")

    def save(self, file_names, file_type):
        return self.depipeline.SaveDataset(file_names, file_type)


class DictIterator(Iterator):
    """
    The derived class of Iterator with dict type.
    """
    def check_node_type(self, node):
        pass

    def __iter__(self):
        return self
...
...
@@ -269,6 +290,8 @@ class TupleIterator(Iterator):
"""
The derived class of Iterator with list type.
"""
def
check_node_type
(
self
,
node
):
pass
def
__init__
(
self
,
dataset
,
columns
=
None
):
if
columns
is
not
None
:
...
...
mindspore/dataset/engine/validators.py
...
...
@@ -246,7 +246,24 @@ def check_celebadataset(method):
    return new_method


def check_save(method):
    """A wrapper that wraps a parameter checker around the save op."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_files']
        nreq_param_str = ['file_name', 'file_type']
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
            raise ValueError("num_files should between {} and {}.".format(1, 1000))
        validate_dataset_param_value(nreq_param_str, param_dict, str)
        if param_dict.get('file_type') != 'mindrecord':
            raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
        return method(self, *args, **kwargs)

    return new_method
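
check_save therefore fails fast, before the pipeline executes; its exact messages are matched by test_case_05 and test_case_06 below. For example (dataset construction and path are illustrative):

import numpy as np
import mindspore.dataset as ds

def gen():
    yield (np.array([0]),)

d = ds.GeneratorDataset(gen, ["data"], shuffle=False)
# d.save("out.mindrecord", 0)              # ValueError: num_files should between 1 and 1000.
# d.save("out.mindrecord", 1, "tfrecord")  # ValueError: tfrecord dataset format is not supported.
d.save("out.mindrecord")                   # passes validation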
def check_minddataset(method):
    """A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
...
...
tests/ut/python/dataset/test_save_op.py
new file mode 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This is the test module for saveOp.
"""
import os

import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
import numpy as np
import pytest

CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
FILES_NUM = 1
num_readers = 1
@pytest.fixture(name="add_and_remove_cv_file")
def fixture_remove():
    """add/remove cv file"""
    if os.path.exists("{}".format(CV_FILE_NAME1)):
        os.remove("{}".format(CV_FILE_NAME1))
    if os.path.exists("{}.db".format(CV_FILE_NAME1)):
        os.remove("{}.db".format(CV_FILE_NAME1))
    if os.path.exists("{}".format(CV_FILE_NAME2)):
        os.remove("{}".format(CV_FILE_NAME2))
    if os.path.exists("{}.db".format(CV_FILE_NAME2)):
        os.remove("{}.db".format(CV_FILE_NAME2))
    yield "yield_cv_data"
    if os.path.exists("{}".format(CV_FILE_NAME1)):
        os.remove("{}".format(CV_FILE_NAME1))
    if os.path.exists("{}.db".format(CV_FILE_NAME1)):
        os.remove("{}.db".format(CV_FILE_NAME1))
    if os.path.exists("{}".format(CV_FILE_NAME2)):
        os.remove("{}".format(CV_FILE_NAME2))
    if os.path.exists("{}.db".format(CV_FILE_NAME2)):
        os.remove("{}.db".format(CV_FILE_NAME2))
def test_case_00(add_and_remove_cv_file):  # only bin data
    data = [{"image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8')}]
    schema = {"image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "image3": {"type": "bytes"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"}}
    writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
    d1.save(CV_FILE_NAME2, FILES_NUM)

    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 5
    num_iter = 0
    for item in d2.create_dict_iterator():
        assert len(item) == 5
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] == data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 5
def test_case_01(add_and_remove_cv_file):  # only raw data
    data = [{"file_name": "001.jpg", "label": 43},
            {"file_name": "002.jpg", "label": 91},
            {"file_name": "003.jpg", "label": 61},
            {"file_name": "004.jpg", "label": 29},
            {"file_name": "005.jpg", "label": 78},
            {"file_name": "006.jpg", "label": 37}]
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"}}

    writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
    d1.save(CV_FILE_NAME2, FILES_NUM)

    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d2.create_dict_iterator():
        logger.info(item)
        assert len(item) == 2
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] == data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6
def test_case_02(add_and_remove_cv_file):  # multi-bytes
    data = [{"file_name": "001.jpg", "label": 43,
             "float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12345,
             "float64": 1987654321.123456785,
             "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8')},
            {"file_name": "002.jpg", "label": 91,
             "float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12445,
             "float64": 1987654321.123456786,
             "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8')},
            {"file_name": "003.jpg", "label": 61,
             "float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12545,
             "float64": 1987654321.123456787,
             "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8')},
            {"file_name": "004.jpg", "label": 29,
             "float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12645,
             "float64": 1987654321.123456788,
             "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image4 bytes abc", encoding='UTF-8'),
             "image2": bytes("image4 bytes def", encoding='UTF-8'),
             "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image4 bytes mno", encoding='UTF-8')},
            {"file_name": "005.jpg", "label": 78,
             "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12745,
             "float64": 1987654321.123456789,
             "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8')},
            {"file_name": "006.jpg", "label": 37,
             "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12745,
             "float64": 1987654321.123456789,
             "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8')}
            ]
    schema = {"file_name": {"type": "string"},
              "float32_array": {"type": "float32", "shape": [-1]},
              "float64_array": {"type": "float64", "shape": [-1]},
              "float32": {"type": "float32"},
              "float64": {"type": "float64"},
              "source_sos_ids": {"type": "int32", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "image3": {"type": "bytes"},
              "label": {"type": "int32"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"}}
    writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
    d1.save(CV_FILE_NAME2, FILES_NUM)
    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['float32_array'] = item["float32_array"]
        new_data['float64_array'] = item["float64_array"]
        new_data['float32'] = item["float32"]
        new_data['float64'] = item["float64"]
        new_data['source_sos_ids'] = item["source_sos_ids"]
        new_data['source_sos_mask'] = item["source_sos_mask"]
        new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d2.create_dict_iterator():
        assert len(item) == 13
        for field in item:
            if isinstance(item[field], np.ndarray):
                if item[field].dtype == np.float32:
                    assert (item[field] ==
                            np.array(data_value_to_list[num_iter][field], np.float32)).all()
                else:
                    assert (item[field] == data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6
def generator_1d():
    for i in range(10):
        yield (np.array([i]),)
def test_case_03(add_and_remove_cv_file):
    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)

    d1.save(CV_FILE_NAME2)

    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                        num_parallel_workers=num_readers,
                        shuffle=False)

    i = 0
    for item in d2.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["data"], golden)
        i = i + 1
def generator_with_type(t):
    for i in range(64):
        yield (np.array([i], dtype=t),)
def type_tester(t):
    logger.info("Test with Type {}".format(t.__name__))

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], shuffle=False)
    data1 = data1.batch(4)
    data1 = data1.repeat(3)

    data1.save(CV_FILE_NAME2)
    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                        num_parallel_workers=num_readers,
                        shuffle=False)

    i = 0
    num_repeat = 0
    for item in d2.create_dict_iterator():  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        logger.info(item)
        assert np.array_equal(item["data"], golden)
        i = i + 4
        if i == 64:
            i = 0
            num_repeat += 1
    assert num_repeat == 3
    if os.path.exists("{}".format(CV_FILE_NAME2)):
        os.remove("{}".format(CV_FILE_NAME2))
    if os.path.exists("{}.db".format(CV_FILE_NAME2)):
        os.remove("{}.db".format(CV_FILE_NAME2))
def test_case_04():
    # uint8 will drop its shape, as mindrecord stores uint8 as bytes
    types = [np.int8, np.int16, np.int32, np.int64,
             np.uint16, np.uint32, np.float32, np.float64]

    for t in types:
        type_tester(t)
def test_case_05(add_and_remove_cv_file):
    d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)

    with pytest.raises(Exception, match="num_files should between 1 and 1000."):
        d1.save(CV_FILE_NAME2, 0)
def test_case_06(add_and_remove_cv_file):
    d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)

    with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
        d1.save(CV_FILE_NAME2, 1, "tfrecord")