Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
948ff63a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
948ff63a
编写于
4月 01, 2020
作者:
Y
yanzhenxiang2020
提交者:
高东海
4月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix mindrecord ut long time
上级
d75745bc
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
20 addition
and
20 deletion
+20
-20
mindspore/mindrecord/tools/mnist_to_mr.py
mindspore/mindrecord/tools/mnist_to_mr.py
+9
-9
tests/ut/data/mindrecord/testMnistData/t10k-images-idx3-ubyte.gz
...t/data/mindrecord/testMnistData/t10k-images-idx3-ubyte.gz
+0
-0
tests/ut/data/mindrecord/testMnistData/t10k-labels-idx1-ubyte.gz
...t/data/mindrecord/testMnistData/t10k-labels-idx1-ubyte.gz
+0
-0
tests/ut/data/mindrecord/testMnistData/train-images-idx3-ubyte.gz
.../data/mindrecord/testMnistData/train-images-idx3-ubyte.gz
+0
-0
tests/ut/data/mindrecord/testMnistData/train-labels-idx1-ubyte.gz
.../data/mindrecord/testMnistData/train-labels-idx1-ubyte.gz
+0
-0
tests/ut/python/mindrecord/test_mindrecord_base.py
tests/ut/python/mindrecord/test_mindrecord_base.py
+5
-5
tests/ut/python/mindrecord/test_mnist_to_mr.py
tests/ut/python/mindrecord/test_mnist_to_mr.py
+6
-6
未找到文件。
mindspore/mindrecord/tools/mnist_to_mr.py
浏览文件 @
948ff63a
...
...
@@ -77,20 +77,20 @@ class MnistToMR:
self
.
mnist_schema_json
=
{
"label"
:
{
"type"
:
"int64"
},
"data"
:
{
"type"
:
"bytes"
}}
def
_extract_images
(
self
,
filename
,
num_images
):
def
_extract_images
(
self
,
filename
):
"""Extract the images into a 4D tensor [image index, y, x, channels]."""
with
gzip
.
open
(
filename
)
as
bytestream
:
bytestream
.
read
(
16
)
buf
=
bytestream
.
read
(
self
.
image_size
*
self
.
image_size
*
num_images
*
self
.
num_channels
)
buf
=
bytestream
.
read
()
data
=
np
.
frombuffer
(
buf
,
dtype
=
np
.
uint8
)
data
=
data
.
reshape
(
num_images
,
self
.
image_size
,
self
.
image_size
,
self
.
num_channels
)
data
=
data
.
reshape
(
-
1
,
self
.
image_size
,
self
.
image_size
,
self
.
num_channels
)
return
data
def
_extract_labels
(
self
,
filename
,
num_images
):
def
_extract_labels
(
self
,
filename
):
"""Extract the labels into a vector of int64 label IDs."""
with
gzip
.
open
(
filename
)
as
bytestream
:
bytestream
.
read
(
8
)
buf
=
bytestream
.
read
(
1
*
num_images
)
buf
=
bytestream
.
read
()
labels
=
np
.
frombuffer
(
buf
,
dtype
=
np
.
uint8
).
astype
(
np
.
int64
)
return
labels
...
...
@@ -101,8 +101,8 @@ class MnistToMR:
Yields:
data (dict of list): mnist data list which contains dict.
"""
train_data
=
self
.
_extract_images
(
self
.
train_data_filename_
,
60000
)
train_labels
=
self
.
_extract_labels
(
self
.
train_labels_filename_
,
60000
)
train_data
=
self
.
_extract_images
(
self
.
train_data_filename_
)
train_labels
=
self
.
_extract_labels
(
self
.
train_labels_filename_
)
for
data
,
label
in
zip
(
train_data
,
train_labels
):
_
,
img
=
cv2
.
imencode
(
".jpeg"
,
data
)
yield
{
"label"
:
int
(
label
),
"data"
:
img
.
tobytes
()}
...
...
@@ -114,8 +114,8 @@ class MnistToMR:
Yields:
data (dict of list): mnist data list which contains dict.
"""
test_data
=
self
.
_extract_images
(
self
.
test_data_filename_
,
10000
)
test_labels
=
self
.
_extract_labels
(
self
.
test_labels_filename_
,
10000
)
test_data
=
self
.
_extract_images
(
self
.
test_data_filename_
)
test_labels
=
self
.
_extract_labels
(
self
.
test_labels_filename_
)
for
data
,
label
in
zip
(
test_data
,
test_labels
):
_
,
img
=
cv2
.
imencode
(
".jpeg"
,
data
)
yield
{
"label"
:
int
(
label
),
"data"
:
img
.
tobytes
()}
...
...
tests/ut/data/mindrecord/testMnistData/t10k-images-idx3-ubyte.gz
浏览文件 @
948ff63a
无法预览此类型文件
tests/ut/data/mindrecord/testMnistData/t10k-labels-idx1-ubyte.gz
浏览文件 @
948ff63a
无法预览此类型文件
tests/ut/data/mindrecord/testMnistData/train-images-idx3-ubyte.gz
浏览文件 @
948ff63a
无法预览此类型文件
tests/ut/data/mindrecord/testMnistData/train-labels-idx1-ubyte.gz
浏览文件 @
948ff63a
无法预览此类型文件
tests/ut/python/mindrecord/test_mindrecord_base.py
浏览文件 @
948ff63a
...
...
@@ -203,9 +203,9 @@ def test_nlp_page_reader_tutorial():
os
.
remove
(
"{}"
.
format
(
x
))
os
.
remove
(
"{}.db"
.
format
(
x
))
def
test_cv_file_writer_shard_num_10
00
():
"""test file writer when shard num equals 10
00
."""
writer
=
FileWriter
(
CV_FILE_NAME
,
10
00
)
def
test_cv_file_writer_shard_num_10
():
"""test file writer when shard num equals 10."""
writer
=
FileWriter
(
CV_FILE_NAME
,
10
)
data
=
get_data
(
"../data/mindrecord/testImageNetData/"
)
cv_schema_json
=
{
"file_name"
:
{
"type"
:
"string"
},
"label"
:
{
"type"
:
"int64"
},
"data"
:
{
"type"
:
"bytes"
}}
...
...
@@ -214,8 +214,8 @@ def test_cv_file_writer_shard_num_1000():
writer
.
write_raw_data
(
data
)
writer
.
commit
()
paths
=
[
"{}{}"
.
format
(
CV_FILE_NAME
,
str
(
x
).
rjust
(
3
,
'0'
))
for
x
in
range
(
100
0
)]
paths
=
[
"{}{}"
.
format
(
CV_FILE_NAME
,
str
(
x
).
rjust
(
1
,
'0'
))
for
x
in
range
(
1
0
)]
for
x
in
paths
:
os
.
remove
(
"{}"
.
format
(
x
))
os
.
remove
(
"{}.db"
.
format
(
x
))
...
...
tests/ut/python/mindrecord/test_mnist_to_mr.py
浏览文件 @
948ff63a
...
...
@@ -37,7 +37,7 @@ def read(train_name, test_name):
count
=
count
+
1
if
count
==
1
:
logger
.
info
(
"data: {}"
.
format
(
x
))
assert
count
==
6000
0
assert
count
==
2
0
reader
.
close
()
count
=
0
...
...
@@ -47,7 +47,7 @@ def read(train_name, test_name):
count
=
count
+
1
if
count
==
1
:
logger
.
info
(
"data: {}"
.
format
(
x
))
assert
count
==
10
000
assert
count
==
10
reader
.
close
()
...
...
@@ -102,10 +102,10 @@ def test_mnist_to_mindrecord_compare_data():
't10k-images-idx3-ubyte.gz'
)
test_labels_filename_
=
os
.
path
.
join
(
MNIST_DIR
,
't10k-labels-idx1-ubyte.gz'
)
train_data
=
_extract_images
(
train_data_filename_
,
6000
0
)
train_labels
=
_extract_labels
(
train_labels_filename_
,
6000
0
)
test_data
=
_extract_images
(
test_data_filename_
,
10
000
)
test_labels
=
_extract_labels
(
test_labels_filename_
,
10
000
)
train_data
=
_extract_images
(
train_data_filename_
,
2
0
)
train_labels
=
_extract_labels
(
train_labels_filename_
,
2
0
)
test_data
=
_extract_images
(
test_data_filename_
,
10
)
test_labels
=
_extract_labels
(
test_labels_filename_
,
10
)
reader
=
FileReader
(
train_name
)
for
x
,
data
,
label
in
zip
(
reader
.
get_next
(),
train_data
,
train_labels
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录