Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
f521532a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f521532a
编写于
7月 20, 2020
作者:
L
liyong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix field_name probelem from tfrecord to mindrecord
上级
b5d8dad4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
60 addition
and
6 deletion
+60
-6
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
+15
-5
tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord
tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord
+0
-0
tests/ut/python/dataset/test_save_op.py
tests/ut/python/dataset/test_save_op.py
+45
-1
未找到文件。
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
浏览文件 @
f521532a
...
...
@@ -385,9 +385,14 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
}
TensorRow
row
;
std
::
unordered_map
<
std
::
string
,
int32_t
>
column_name_id_map
=
iterator_
->
GetColumnNameMap
();
// map of column name, id
bool
first_loop
=
true
;
// build schema in first loop
std
::
unordered_map
<
std
::
string
,
int32_t
>
column_name_id_map
;
for
(
auto
el
:
iterator_
->
GetColumnNameMap
())
{
std
::
string
column_name
=
el
.
first
;
std
::
transform
(
column_name
.
begin
(),
column_name
.
end
(),
column_name
.
begin
(),
[](
unsigned
char
c
)
{
return
ispunct
(
c
)
?
'_'
:
c
;
});
column_name_id_map
[
column_name
]
=
el
.
second
;
}
bool
first_loop
=
true
;
// build schema in first loop
do
{
json
row_raw_data
;
std
::
map
<
std
::
string
,
std
::
unique_ptr
<
std
::
vector
<
uint8_t
>>>
row_bin_data
;
...
...
@@ -402,7 +407,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
std
::
vector
<
std
::
string
>
index_fields
;
s
=
FetchMetaFromTensorRow
(
column_name_id_map
,
row
,
&
mr_json
,
&
index_fields
);
RETURN_IF_NOT_OK
(
s
);
mindrecord
::
ShardHeader
::
initialize
(
&
mr_header
,
mr_json
,
index_fields
,
blob_fields
,
mr_schema_id
);
if
(
mindrecord
::
SUCCESS
!=
mindrecord
::
ShardHeader
::
initialize
(
&
mr_header
,
mr_json
,
index_fields
,
blob_fields
,
mr_schema_id
))
{
RETURN_STATUS_UNEXPECTED
(
"Error: failed to initialize ShardHeader."
);
}
mr_writer
->
SetShardHeader
(
mr_header
);
first_loop
=
false
;
}
...
...
@@ -422,7 +430,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
}
}
while
(
!
row
.
empty
());
mr_writer
->
Commit
();
mindrecord
::
ShardIndexGenerator
::
finalize
(
file_names
);
if
(
mindrecord
::
SUCCESS
!=
mindrecord
::
ShardIndexGenerator
::
finalize
(
file_names
))
{
RETURN_STATUS_UNEXPECTED
(
"Error: failed to finalize ShardIndexGenerator."
);
}
return
Status
::
OK
();
}
...
...
tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord
0 → 100644
浏览文件 @
f521532a
文件已添加
tests/ut/python/dataset/test_save_op.py
浏览文件 @
f521532a
...
...
@@ -16,6 +16,7 @@
This is the test module for saveOp.
"""
import
os
from
string
import
punctuation
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
from
mindspore.mindrecord
import
FileWriter
...
...
@@ -24,7 +25,7 @@ import pytest
CV_FILE_NAME1
=
"../data/mindrecord/testMindDataSet/temp.mindrecord"
CV_FILE_NAME2
=
"../data/mindrecord/testMindDataSet/auto.mindrecord"
TFRECORD_FILES
=
"../data/mindrecord/testTFRecordData/dummy.tfrecord"
FILES_NUM
=
1
num_readers
=
1
...
...
@@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):
with
pytest
.
raises
(
Exception
,
match
=
"tfrecord dataset format is not supported."
):
d1
.
save
(
CV_FILE_NAME2
,
1
,
"tfrecord"
)
def
cast_name
(
key
):
"""
Cast schema names which containing special characters to valid names.
"""
special_symbols
=
set
(
'{}{}'
.
format
(
punctuation
,
' '
))
special_symbols
.
remove
(
'_'
)
new_key
=
[
'_'
if
x
in
special_symbols
else
x
for
x
in
key
]
casted_key
=
''
.
join
(
new_key
)
return
casted_key
def
test_case_07
():
if
os
.
path
.
exists
(
"{}"
.
format
(
CV_FILE_NAME2
)):
os
.
remove
(
"{}"
.
format
(
CV_FILE_NAME2
))
if
os
.
path
.
exists
(
"{}.db"
.
format
(
CV_FILE_NAME2
)):
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME2
))
d1
=
ds
.
TFRecordDataset
(
TFRECORD_FILES
,
shuffle
=
False
)
tf_data
=
[]
for
x
in
d1
.
create_dict_iterator
():
tf_data
.
append
(
x
)
d1
.
save
(
CV_FILE_NAME2
,
FILES_NUM
)
d2
=
ds
.
MindDataset
(
dataset_file
=
CV_FILE_NAME2
,
num_parallel_workers
=
num_readers
,
shuffle
=
False
)
mr_data
=
[]
for
x
in
d2
.
create_dict_iterator
():
mr_data
.
append
(
x
)
count
=
0
for
x
in
tf_data
:
for
k
,
v
in
x
.
items
():
if
isinstance
(
v
,
np
.
ndarray
):
assert
(
v
==
mr_data
[
count
][
cast_name
(
k
)]).
all
()
else
:
assert
v
==
mr_data
[
count
][
cast_name
(
k
)]
count
+=
1
assert
count
==
10
if
os
.
path
.
exists
(
"{}"
.
format
(
CV_FILE_NAME2
)):
os
.
remove
(
"{}"
.
format
(
CV_FILE_NAME2
))
if
os
.
path
.
exists
(
"{}.db"
.
format
(
CV_FILE_NAME2
)):
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME2
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录