Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
28ebd730
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
28ebd730
编写于
7月 16, 2020
作者:
L
liyong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug when int or float is numpy type
上级
74f2c89d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
47 addition
and
1 deletion
+47
-1
mindspore/mindrecord/shardwriter.py
mindspore/mindrecord/shardwriter.py
+8
-1
tests/ut/python/dataset/test_minddataset.py
tests/ut/python/dataset/test_minddataset.py
+39
-0
未找到文件。
mindspore/mindrecord/shardwriter.py
浏览文件 @
28ebd730
...
...
@@ -29,6 +29,7 @@ class ShardWriter:
The class would write MindRecord File series.
"""
def
__init__
(
self
):
self
.
_writer
=
ms
.
ShardWriter
()
self
.
_header
=
None
...
...
@@ -161,7 +162,7 @@ class ShardWriter:
if
row_blob
:
blob_data
.
append
(
list
(
row_blob
))
# filter raw data according to schema
row_raw
=
{
field
:
item
[
field
]
row_raw
=
{
field
:
self
.
_convert_np_types
(
item
[
field
])
for
field
in
self
.
_header
.
schema
.
keys
()
-
self
.
_header
.
blob_fields
if
field
in
item
}
if
row_raw
:
raw_data
.
append
(
row_raw
)
...
...
@@ -172,6 +173,12 @@ class ShardWriter:
raise
MRMWriteDatasetError
return
ret
def
_convert_np_types
(
self
,
val
):
"""convert numpy type to python primitive type"""
if
isinstance
(
val
,
(
np
.
int32
,
np
.
int64
,
np
.
float32
,
np
.
float64
)):
return
val
.
item
()
return
val
def
_merge_blob
(
self
,
blob_data
):
"""
Merge multiple blob data whose type is bytes or ndarray
...
...
tests/ut/python/dataset/test_minddataset.py
浏览文件 @
28ebd730
...
...
@@ -1853,3 +1853,42 @@ def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset(
os
.
remove
(
"{}"
.
format
(
mindrecord_file_name
))
os
.
remove
(
"{}.db"
.
format
(
mindrecord_file_name
))
def
test_numpy_generic
():
paths
=
[
"{}{}"
.
format
(
CV_FILE_NAME
,
str
(
x
).
rjust
(
1
,
'0'
))
for
x
in
range
(
FILES_NUM
)]
for
x
in
paths
:
if
os
.
path
.
exists
(
"{}"
.
format
(
x
)):
os
.
remove
(
"{}"
.
format
(
x
))
if
os
.
path
.
exists
(
"{}.db"
.
format
(
x
)):
os
.
remove
(
"{}.db"
.
format
(
x
))
writer
=
FileWriter
(
CV_FILE_NAME
,
FILES_NUM
)
cv_schema_json
=
{
"label1"
:
{
"type"
:
"int32"
},
"label2"
:
{
"type"
:
"int64"
},
"label3"
:
{
"type"
:
"float32"
},
"label4"
:
{
"type"
:
"float64"
}}
data
=
[]
for
idx
in
range
(
10
):
row
=
{}
row
[
'label1'
]
=
np
.
int32
(
idx
)
row
[
'label2'
]
=
np
.
int64
(
idx
*
10
)
row
[
'label3'
]
=
np
.
float32
(
idx
+
0.12345
)
row
[
'label4'
]
=
np
.
float64
(
idx
+
0.12345789
)
data
.
append
(
row
)
writer
.
add_schema
(
cv_schema_json
,
"img_schema"
)
writer
.
write_raw_data
(
data
)
writer
.
commit
()
num_readers
=
4
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
None
,
num_readers
,
shuffle
=
False
)
assert
data_set
.
get_dataset_size
()
==
10
idx
=
0
for
item
in
data_set
.
create_dict_iterator
():
assert
item
[
'label1'
]
==
item
[
'label1'
]
assert
item
[
'label2'
]
==
item
[
'label2'
]
assert
item
[
'label3'
]
==
item
[
'label3'
]
assert
item
[
'label4'
]
==
item
[
'label4'
]
idx
+=
1
assert
idx
==
10
for
x
in
paths
:
os
.
remove
(
"{}"
.
format
(
x
))
os
.
remove
(
"{}.db"
.
format
(
x
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录