Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1de7271a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1de7271a
编写于
6月 11, 2020
作者:
J
jonyguo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add floatxx test case
上级
9dfb1011
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
151 addition
and
4 deletion
+151
-4
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
+4
-4
tests/ut/python/dataset/test_minddataset.py
tests/ut/python/dataset/test_minddataset.py
+147
-0
未找到文件。
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
浏览文件 @
1de7271a
...
...
@@ -335,15 +335,15 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
int
index
=
sqlite3_bind_parameter_index
(
stmt
,
common
::
SafeCStr
(
place_holder
));
if
(
field_type
==
"INTEGER"
)
{
if
(
sqlite3_bind_int
(
stmt
,
index
,
std
::
stoi
(
field_value
))
!=
SQLITE_OK
)
{
if
(
sqlite3_bind_int
64
(
stmt
,
index
,
std
::
stoll
(
field_value
))
!=
SQLITE_OK
)
{
MS_LOG
(
ERROR
)
<<
"SQL error: could not bind parameter, index: "
<<
index
<<
", field value: "
<<
std
::
sto
i
(
field_value
);
<<
", field value: "
<<
std
::
sto
ll
(
field_value
);
return
FAILED
;
}
}
else
if
(
field_type
==
"NUMERIC"
)
{
if
(
sqlite3_bind_double
(
stmt
,
index
,
std
::
stod
(
field_value
))
!=
SQLITE_OK
)
{
if
(
sqlite3_bind_double
(
stmt
,
index
,
std
::
sto
l
d
(
field_value
))
!=
SQLITE_OK
)
{
MS_LOG
(
ERROR
)
<<
"SQL error: could not bind parameter, index: "
<<
index
<<
", field value: "
<<
std
::
sto
i
(
field_value
);
<<
", field value: "
<<
std
::
sto
ld
(
field_value
);
return
FAILED
;
}
}
else
if
(
field_type
==
"NULL"
)
{
...
...
tests/ut/python/dataset/test_minddataset.py
浏览文件 @
1de7271a
...
...
@@ -17,6 +17,7 @@ This is the test module for mindrecord
"""
import
collections
import
json
import
math
import
os
import
re
import
string
...
...
@@ -1605,3 +1606,149 @@ def test_write_with_multi_array_and_MindDataset():
os
.
remove
(
"{}"
.
format
(
mindrecord_file_name
))
os
.
remove
(
"{}.db"
.
format
(
mindrecord_file_name
))
def
test_write_with_float32_float64_float32_array_float64_array_and_MindDataset
():
mindrecord_file_name
=
"test.mindrecord"
data
=
[{
"float32_array"
:
np
.
array
([
1.2
,
2.78
,
3.1234
,
4.9871
,
5.12341
],
dtype
=
np
.
float32
),
"float64_array"
:
np
.
array
([
48.1234556789
,
49.3251241431
,
50.13514312414
,
51.8971298471
,
123414314.2141243
,
87.1212122
],
dtype
=
np
.
float64
),
"float32"
:
3456.12345
,
"float64"
:
1987654321.123456785
,
"int32_array"
:
np
.
array
([
1
,
2
,
3
,
4
,
5
],
dtype
=
np
.
int32
),
"int64_array"
:
np
.
array
([
48
,
49
,
50
,
51
,
123414314
,
87
],
dtype
=
np
.
int64
),
"int32"
:
3456
,
"int64"
:
947654321123
},
{
"float32_array"
:
np
.
array
([
1.2
,
2.78
,
4.1234
,
4.9871
,
5.12341
],
dtype
=
np
.
float32
),
"float64_array"
:
np
.
array
([
48.1234556789
,
49.3251241431
,
60.13514312414
,
51.8971298471
,
123414314.2141243
,
87.1212122
],
dtype
=
np
.
float64
),
"float32"
:
3456.12445
,
"float64"
:
1987654321.123456786
,
"int32_array"
:
np
.
array
([
11
,
21
,
31
,
41
,
51
],
dtype
=
np
.
int32
),
"int64_array"
:
np
.
array
([
481
,
491
,
501
,
511
,
1234143141
,
871
],
dtype
=
np
.
int64
),
"int32"
:
3466
,
"int64"
:
957654321123
},
{
"float32_array"
:
np
.
array
([
1.2
,
2.78
,
5.1234
,
4.9871
,
5.12341
],
dtype
=
np
.
float32
),
"float64_array"
:
np
.
array
([
48.1234556789
,
49.3251241431
,
70.13514312414
,
51.8971298471
,
123414314.2141243
,
87.1212122
],
dtype
=
np
.
float64
),
"float32"
:
3456.12545
,
"float64"
:
1987654321.123456787
,
"int32_array"
:
np
.
array
([
12
,
22
,
32
,
42
,
52
],
dtype
=
np
.
int32
),
"int64_array"
:
np
.
array
([
482
,
492
,
502
,
512
,
1234143142
,
872
],
dtype
=
np
.
int64
),
"int32"
:
3476
,
"int64"
:
967654321123
},
{
"float32_array"
:
np
.
array
([
1.2
,
2.78
,
6.1234
,
4.9871
,
5.12341
],
dtype
=
np
.
float32
),
"float64_array"
:
np
.
array
([
48.1234556789
,
49.3251241431
,
80.13514312414
,
51.8971298471
,
123414314.2141243
,
87.1212122
],
dtype
=
np
.
float64
),
"float32"
:
3456.12645
,
"float64"
:
1987654321.123456788
,
"int32_array"
:
np
.
array
([
13
,
23
,
33
,
43
,
53
],
dtype
=
np
.
int32
),
"int64_array"
:
np
.
array
([
483
,
493
,
503
,
513
,
1234143143
,
873
],
dtype
=
np
.
int64
),
"int32"
:
3486
,
"int64"
:
977654321123
},
{
"float32_array"
:
np
.
array
([
1.2
,
2.78
,
7.1234
,
4.9871
,
5.12341
],
dtype
=
np
.
float32
),
"float64_array"
:
np
.
array
([
48.1234556789
,
49.3251241431
,
90.13514312414
,
51.8971298471
,
123414314.2141243
,
87.1212122
],
dtype
=
np
.
float64
),
"float32"
:
3456.12745
,
"float64"
:
1987654321.123456789
,
"int32_array"
:
np
.
array
([
14
,
24
,
34
,
44
,
54
],
dtype
=
np
.
int32
),
"int64_array"
:
np
.
array
([
484
,
494
,
504
,
514
,
1234143144
,
874
],
dtype
=
np
.
int64
),
"int32"
:
3496
,
"int64"
:
987654321123
},
]
writer
=
FileWriter
(
mindrecord_file_name
)
schema
=
{
"float32_array"
:
{
"type"
:
"float32"
,
"shape"
:
[
-
1
]},
"float64_array"
:
{
"type"
:
"float64"
,
"shape"
:
[
-
1
]},
"float32"
:
{
"type"
:
"float32"
},
"float64"
:
{
"type"
:
"float64"
},
"int32_array"
:
{
"type"
:
"int32"
,
"shape"
:
[
-
1
]},
"int64_array"
:
{
"type"
:
"int64"
,
"shape"
:
[
-
1
]},
"int32"
:
{
"type"
:
"int32"
},
"int64"
:
{
"type"
:
"int64"
}}
writer
.
add_schema
(
schema
,
"data is so cool"
)
writer
.
write_raw_data
(
data
)
writer
.
commit
()
# change data value to list - do none
data_value_to_list
=
[]
for
item
in
data
:
new_data
=
{}
new_data
[
'float32_array'
]
=
item
[
"float32_array"
]
new_data
[
'float64_array'
]
=
item
[
"float64_array"
]
new_data
[
'float32'
]
=
item
[
"float32"
]
new_data
[
'float64'
]
=
item
[
"float64"
]
new_data
[
'int32_array'
]
=
item
[
"int32_array"
]
new_data
[
'int64_array'
]
=
item
[
"int64_array"
]
new_data
[
'int32'
]
=
item
[
"int32"
]
new_data
[
'int64'
]
=
item
[
"int64"
]
data_value_to_list
.
append
(
new_data
)
num_readers
=
2
data_set
=
ds
.
MindDataset
(
dataset_file
=
mindrecord_file_name
,
num_parallel_workers
=
num_readers
,
shuffle
=
False
)
assert
data_set
.
get_dataset_size
()
==
5
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
assert
len
(
item
)
==
8
for
field
in
item
:
if
isinstance
(
item
[
field
],
np
.
ndarray
):
if
item
[
field
].
dtype
==
np
.
float32
:
assert
(
item
[
field
]
==
np
.
array
(
data_value_to_list
[
num_iter
][
field
],
np
.
float32
)).
all
()
else
:
assert
(
item
[
field
]
==
data_value_to_list
[
num_iter
][
field
]).
all
()
else
:
assert
item
[
field
]
==
data_value_to_list
[
num_iter
][
field
]
num_iter
+=
1
assert
num_iter
==
5
num_readers
=
2
data_set
=
ds
.
MindDataset
(
dataset_file
=
mindrecord_file_name
,
columns_list
=
[
"float32"
,
"int32"
],
num_parallel_workers
=
num_readers
,
shuffle
=
False
)
assert
data_set
.
get_dataset_size
()
==
5
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
assert
len
(
item
)
==
2
for
field
in
item
:
if
isinstance
(
item
[
field
],
np
.
ndarray
):
if
item
[
field
].
dtype
==
np
.
float32
:
assert
(
item
[
field
]
==
np
.
array
(
data_value_to_list
[
num_iter
][
field
],
np
.
float32
)).
all
()
else
:
assert
(
item
[
field
]
==
data_value_to_list
[
num_iter
][
field
]).
all
()
else
:
assert
item
[
field
]
==
data_value_to_list
[
num_iter
][
field
]
num_iter
+=
1
assert
num_iter
==
5
num_readers
=
2
data_set
=
ds
.
MindDataset
(
dataset_file
=
mindrecord_file_name
,
columns_list
=
[
"float64"
,
"int64"
],
num_parallel_workers
=
num_readers
,
shuffle
=
False
)
assert
data_set
.
get_dataset_size
()
==
5
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
assert
len
(
item
)
==
2
for
field
in
item
:
if
isinstance
(
item
[
field
],
np
.
ndarray
):
if
item
[
field
].
dtype
==
np
.
float32
:
assert
(
item
[
field
]
==
np
.
array
(
data_value_to_list
[
num_iter
][
field
],
np
.
float32
)).
all
()
elif
item
[
field
].
dtype
==
np
.
float64
:
assert
math
.
isclose
(
item
[
field
],
np
.
array
(
data_value_to_list
[
num_iter
][
field
],
np
.
float64
),
rel_tol
=
1e-14
)
else
:
assert
(
item
[
field
]
==
data_value_to_list
[
num_iter
][
field
]).
all
()
else
:
assert
item
[
field
]
==
data_value_to_list
[
num_iter
][
field
]
num_iter
+=
1
assert
num_iter
==
5
os
.
remove
(
"{}"
.
format
(
mindrecord_file_name
))
os
.
remove
(
"{}.db"
.
format
(
mindrecord_file_name
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录