Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0ab7fd98
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0ab7fd98
编写于
6月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2654 fix: tfrecord to mindrecord para check
Merge pull request !2654 from guozhijian/fix_tfrecord_to_mr_para_check
上级
08ff9099
3450c35d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
178 addition
and
13 deletion
+178
-13
mindspore/mindrecord/tools/tfrecord_to_mr.py
mindspore/mindrecord/tools/tfrecord_to_mr.py
+11
-8
tests/ut/python/mindrecord/test_tfrecord_to_mr.py
tests/ut/python/mindrecord/test_tfrecord_to_mr.py
+167
-5
未找到文件。
mindspore/mindrecord/tools/tfrecord_to_mr.py
浏览文件 @
0ab7fd98
...
...
@@ -113,7 +113,7 @@ class TFRecordToMR:
feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string),
\
"yyyy": tf.io.VarLenFeature(tf.int64)},
\
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
bytes_fields (list
): the bytes fields which are in feature_dict
.
bytes_fields (list
, optional): the bytes fields which are in feature_dict and can be images bytes
.
Raises:
ValueError: If parameter is invalid.
...
...
@@ -147,7 +147,7 @@ class TFRecordToMR:
self
.
feature_dict
=
feature_dict
bytes_fields_list
=
[]
if
bytes_fields
:
if
bytes_fields
is
not
None
:
if
not
isinstance
(
bytes_fields
,
list
):
raise
ValueError
(
"Parameter bytes_fields: {} must be list(str)."
.
format
(
bytes_fields
))
for
item
in
bytes_fields
:
...
...
@@ -161,6 +161,9 @@ class TFRecordToMR:
if
not
isinstance
(
self
.
feature_dict
[
item
].
shape
,
list
):
raise
ValueError
(
"Parameter feature_dict[{}].shape should be a list."
.
format
(
item
))
if
self
.
feature_dict
[
item
].
dtype
!=
tf
.
string
:
raise
ValueError
(
"Parameter bytes_field: {} should be tf.string in feature_dict."
.
format
(
item
))
casted_bytes_field
=
_cast_name
(
item
)
bytes_fields_list
.
append
(
casted_bytes_field
)
...
...
@@ -172,7 +175,7 @@ class TFRecordToMR:
for
key
,
val
in
self
.
feature_dict
.
items
():
if
not
val
.
shape
:
self
.
scalar_set
.
add
(
_cast_name
(
key
))
if
key
in
self
.
bytes_fields_list
:
if
_cast_name
(
key
)
in
self
.
bytes_fields_list
:
mindrecord_schema
[
_cast_name
(
key
)]
=
{
"type"
:
"bytes"
}
else
:
mindrecord_schema
[
_cast_name
(
key
)]
=
{
"type"
:
_cast_type
(
val
.
dtype
)}
...
...
@@ -182,8 +185,8 @@ class TFRecordToMR:
if
val
.
shape
[
0
]
<
1
:
raise
ValueError
(
"Parameter feature_dict[{}].shape[0] should > 0"
.
format
(
key
))
if
val
.
dtype
==
tf
.
string
:
raise
ValueError
(
"Parameter feautre_dict[{}].dtype is tf.string which shape[0]
\
is not None. It is not supported."
.
format
(
key
))
raise
ValueError
(
"Parameter feautre_dict[{}].dtype is tf.string which shape[0]
"
\
"
is not None. It is not supported."
.
format
(
key
))
self
.
list_set
.
add
(
_cast_name
(
key
))
mindrecord_schema
[
_cast_name
(
key
)]
=
{
"type"
:
_cast_type
(
val
.
dtype
),
"shape"
:
[
val
.
shape
[
0
]]}
self
.
mindrecord_schema
=
mindrecord_schema
...
...
@@ -219,12 +222,12 @@ class TFRecordToMR:
index_id
=
index_id
+
1
for
key
,
val
in
features
.
items
():
cast_key
=
_cast_name
(
key
)
if
key
in
self
.
scalar_set
:
if
cast_
key
in
self
.
scalar_set
:
self
.
_get_data_when_scalar_field
(
ms_dict
,
cast_key
,
key
,
val
)
else
:
if
not
isinstance
(
val
.
numpy
(),
np
.
ndarray
)
and
not
isinstance
(
val
.
numpy
(),
list
):
raise
ValueError
(
"
he response key: {}, value: {} from TFRecord should be a ndarray or list."
.
format
(
key
,
val
))
raise
ValueError
(
"
The response key: {}, value: {} from TFRecord should be a ndarray or "
\
"list."
.
format
(
key
,
val
))
# list set
ms_dict
[
cast_key
]
=
\
np
.
asarray
(
val
,
_cast_string_type_to_np_type
(
self
.
mindrecord_schema
[
cast_key
][
"type"
]))
...
...
tests/ut/python/mindrecord/test_tfrecord_to_mr.py
浏览文件 @
0ab7fd98
...
...
@@ -15,6 +15,7 @@
import
collections
from
importlib
import
import_module
import
os
from
string
import
punctuation
import
numpy
as
np
import
pytest
...
...
@@ -35,6 +36,27 @@ TFRECORD_FILE_NAME = "test.tfrecord"
MINDRECORD_FILE_NAME
=
"test.mindrecord"
PARTITION_NUM
=
1
def cast_name(key):
    """
    Replace the special characters of a schema field name with "_".

    A special character is a space or any character of ``string.punctuation``
    other than "_" itself. All other characters are copied through unchanged.

    Args:
        key (str): original field name that might contain special characters.

    Returns:
        str, the name with every special character replaced by "_", e.g.
            "a b" becomes "a_b".
    """
    # "_" is part of string.punctuation but is already valid, so keep it.
    replaceable = set(punctuation + ' ') - {'_'}
    pieces = []
    for char in key:
        pieces.append('_' if char in replaceable else char)
    return ''.join(pieces)
def
verify_data
(
transformer
,
reader
):
"""Verify the data by read from mindrecord"""
tf_iter
=
transformer
.
tfrecord_iterator
()
...
...
@@ -43,14 +65,14 @@ def verify_data(transformer, reader):
count
=
0
for
tf_item
,
mr_item
in
zip
(
tf_iter
,
mr_iter
):
count
=
count
+
1
assert
len
(
tf_item
)
==
6
assert
len
(
mr_item
)
==
6
assert
len
(
tf_item
)
==
len
(
mr_item
)
for
key
,
value
in
tf_item
.
items
():
logger
.
info
(
"key: {}, tfrecord: value: {}, mindrecord: value: {}"
.
format
(
key
,
value
,
mr_item
[
key
]))
logger
.
info
(
"key: {}, tfrecord: value: {}, mindrecord: value: {}"
.
format
(
key
,
value
,
mr_item
[
cast_name
(
key
)]))
if
isinstance
(
value
,
np
.
ndarray
):
assert
(
value
==
mr_item
[
key
]).
all
()
assert
(
value
==
mr_item
[
cast_name
(
key
)
]).
all
()
else
:
assert
value
==
mr_item
[
key
]
assert
value
==
mr_item
[
cast_name
(
key
)
]
assert
count
==
10
def
generate_tfrecord
():
...
...
@@ -102,6 +124,39 @@ def generate_tfrecord():
writer
.
close
()
logger
.
info
(
"Write {} rows in tfrecord."
.
format
(
example_count
))
def generate_tfrecord_with_special_field_name():
    """
    Write a small TFRecord file whose feature names contain "/".

    Ten examples are written to TFRECORD_DATA_DIR/TFRECORD_FILE_NAME. Each
    example holds an int64 feature "image/class/label" (the row index) and a
    bytes feature "image/encoded" (b"aaaabbbbcccc" followed by the row index).
    """
    def create_int_feature(values):
        """Wrap an int or a list of ints in an int64 tf.train.Feature."""
        if isinstance(values, list):
            # values: [int, int, int]
            return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
        # values: int
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))

    def create_bytes_feature(values):
        """Wrap bytes, or a str encoded as UTF-8, in a bytes tf.train.Feature."""
        if isinstance(values, bytes):
            # values: bytes
            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
        # values: string
        return tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))

    example_count = 0
    # Use the writer as a context manager so the file is closed even when
    # building or writing an example raises.
    with tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) as writer:
        for i in range(10):
            label = i
            image_bytes = bytes("aaaabbbbcccc" + str(i), encoding="utf-8")

            features = collections.OrderedDict()
            features["image/class/label"] = create_int_feature(label)
            features["image/encoded"] = create_bytes_feature(image_bytes)

            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            writer.write(tf_example.SerializeToString())
            example_count += 1
    logger.info("Write {} rows in tfrecord.".format(example_count))
def
test_tfrecord_to_mindrecord
():
"""test transform tfrecord to mindrecord."""
if
not
tf
or
tf
.
__version__
<
SupportedTensorFlowVersion
:
...
...
@@ -398,3 +453,110 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
os
.
remove
(
MINDRECORD_FILE_NAME
+
".db"
)
os
.
remove
(
os
.
path
.
join
(
TFRECORD_DATA_DIR
,
TFRECORD_FILE_NAME
))
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
    """A bytes_fields entry whose feature dtype is not tf.string must raise ValueError."""
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        # skip the test
        # Adjacent literals are used instead of a backslash inside the string,
        # which would fold the next line's indentation into the message.
        logger.warning("Module tensorflow is not found or version wrong, "
                       "please use pip install it / reinstall version >= {}."
                       .format(SupportedTensorFlowVersion))
        return

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)

    def remove_if_exists(*paths):
        """Delete each given file that is present."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    generate_tfrecord()
    try:
        assert os.path.exists(tfrecord_path)

        feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
                        "image_bytes": tf.io.FixedLenFeature([], tf.string),
                        "int64_scalar": tf.io.FixedLenFeature([], tf.int64),
                        "float_scalar": tf.io.FixedLenFeature([], tf.float32),
                        "int64_list": tf.io.FixedLenFeature([6], tf.int64),
                        "float_list": tf.io.FixedLenFeature([7], tf.float32),
                        }

        # Start from a clean slate in case an earlier run left output behind.
        remove_if_exists(MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

        # "int64_list" is declared as tf.int64, so it is not a valid bytes field.
        with pytest.raises(ValueError):
            tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME,
                                                feature_dict, ["int64_list"])
            tfrecord_transformer.transform()
    finally:
        # Clean up even when an assertion fails, so later tests that reuse the
        # same file names are not affected by leftovers.
        remove_if_exists(MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db", tfrecord_path)
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
    """Passing a bytes_fields argument that is not a list (here "") must raise ValueError."""
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        # skip the test
        # Adjacent literals are used instead of a backslash inside the string,
        # which would fold the next line's indentation into the message.
        logger.warning("Module tensorflow is not found or version wrong, "
                       "please use pip install it / reinstall version >= {}."
                       .format(SupportedTensorFlowVersion))
        return

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)

    def remove_if_exists(*paths):
        """Delete each given file that is present."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    generate_tfrecord()
    try:
        assert os.path.exists(tfrecord_path)

        feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
                        "image_bytes": tf.io.FixedLenFeature([], tf.string),
                        "int64_scalar": tf.io.FixedLenFeature([], tf.int64),
                        "float_scalar": tf.io.FixedLenFeature([], tf.float32),
                        "int64_list": tf.io.FixedLenFeature([6], tf.int64),
                        "float_list": tf.io.FixedLenFeature([7], tf.float32),
                        }

        # Start from a clean slate in case an earlier run left output behind.
        remove_if_exists(MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

        # An empty string is not None and not a list, so it must be rejected.
        with pytest.raises(ValueError):
            tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME,
                                                feature_dict, "")
            tfrecord_transformer.transform()
    finally:
        # Clean up even when an assertion fails, so later tests that reuse the
        # same file names are not affected by leftovers.
        remove_if_exists(MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db", tfrecord_path)
def test_tfrecord_to_mindrecord_with_special_field_name():
    """Convert a TFRecord whose field names contain "/" and verify the MindRecord output."""
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        # skip the test
        # Adjacent literals are used instead of a backslash inside the string,
        # which would fold the next line's indentation into the message.
        logger.warning("Module tensorflow is not found or version wrong, "
                       "please use pip install it / reinstall version >= {}."
                       .format(SupportedTensorFlowVersion))
        return

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)

    def remove_if_exists(*paths):
        """Delete each given file that is present."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    generate_tfrecord_with_special_field_name()
    try:
        assert os.path.exists(tfrecord_path)

        feature_dict = {"image/class/label": tf.io.FixedLenFeature([], tf.int64),
                        "image/encoded": tf.io.FixedLenFeature([], tf.string),
                        }

        # Start from a clean slate in case an earlier run left output behind.
        remove_if_exists(MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

        tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME,
                                            feature_dict, ["image/encoded"])
        tfrecord_transformer.transform()

        assert os.path.exists(MINDRECORD_FILE_NAME)
        assert os.path.exists(MINDRECORD_FILE_NAME + ".db")

        fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
        verify_data(tfrecord_transformer, fr_mindrecord)
    finally:
        # Clean up even when an assertion fails, so later tests that reuse the
        # same file names are not affected by leftovers.
        remove_if_exists(MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db", tfrecord_path)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录