提交 1de7271a 编写于 作者: J jonyguo

add floatxx test case

上级 9dfb1011
......@@ -335,15 +335,15 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
if (field_type == "INTEGER") {
if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) {
if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stoi(field_value);
<< ", field value: " << std::stoll(field_value);
return FAILED;
}
} else if (field_type == "NUMERIC") {
if (sqlite3_bind_double(stmt, index, std::stod(field_value)) != SQLITE_OK) {
if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stoi(field_value);
<< ", field value: " << std::stold(field_value);
return FAILED;
}
} else if (field_type == "NULL") {
......
......@@ -17,6 +17,7 @@ This is the test module for mindrecord
"""
import collections
import json
import math
import os
import re
import string
......@@ -1605,3 +1606,149 @@ def test_write_with_multi_array_and_MindDataset():
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset():
mindrecord_file_name = "test.mindrecord"
data = [{"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12345,
"float64": 1987654321.123456785,
"int32_array": np.array([1, 2, 3, 4, 5], dtype=np.int32),
"int64_array": np.array([48, 49, 50, 51, 123414314, 87], dtype=np.int64),
"int32": 3456,
"int64": 947654321123},
{"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12445,
"float64": 1987654321.123456786,
"int32_array": np.array([11, 21, 31, 41, 51], dtype=np.int32),
"int64_array": np.array([481, 491, 501, 511, 1234143141, 871], dtype=np.int64),
"int32": 3466,
"int64": 957654321123},
{"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12545,
"float64": 1987654321.123456787,
"int32_array": np.array([12, 22, 32, 42, 52], dtype=np.int32),
"int64_array": np.array([482, 492, 502, 512, 1234143142, 872], dtype=np.int64),
"int32": 3476,
"int64": 967654321123},
{"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12645,
"float64": 1987654321.123456788,
"int32_array": np.array([13, 23, 33, 43, 53], dtype=np.int32),
"int64_array": np.array([483, 493, 503, 513, 1234143143, 873], dtype=np.int64),
"int32": 3486,
"int64": 977654321123},
{"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12745,
"float64": 1987654321.123456789,
"int32_array": np.array([14, 24, 34, 44, 54], dtype=np.int32),
"int64_array": np.array([484, 494, 504, 514, 1234143144, 874], dtype=np.int64),
"int32": 3496,
"int64": 987654321123},
]
writer = FileWriter(mindrecord_file_name)
schema = {"float32_array": {"type": "float32", "shape": [-1]},
"float64_array": {"type": "float64", "shape": [-1]},
"float32": {"type": "float32"},
"float64": {"type": "float64"},
"int32_array": {"type": "int32", "shape": [-1]},
"int64_array": {"type": "int64", "shape": [-1]},
"int32": {"type": "int32"},
"int64": {"type": "int64"}}
writer.add_schema(schema, "data is so cool")
writer.write_raw_data(data)
writer.commit()
# change data value to list - do none
data_value_to_list = []
for item in data:
new_data = {}
new_data['float32_array'] = item["float32_array"]
new_data['float64_array'] = item["float64_array"]
new_data['float32'] = item["float32"]
new_data['float64'] = item["float64"]
new_data['int32_array'] = item["int32_array"]
new_data['int64_array'] = item["int64_array"]
new_data['int32'] = item["int32"]
new_data['int64'] = item["int64"]
data_value_to_list.append(new_data)
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 5
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 8
for field in item:
if isinstance(item[field], np.ndarray):
if item[field].dtype == np.float32:
assert (item[field] ==
np.array(data_value_to_list[num_iter][field], np.float32)).all()
else:
assert (item[field] ==
data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 5
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["float32", "int32"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 5
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 2
for field in item:
if isinstance(item[field], np.ndarray):
if item[field].dtype == np.float32:
assert (item[field] ==
np.array(data_value_to_list[num_iter][field], np.float32)).all()
else:
assert (item[field] ==
data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 5
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["float64", "int64"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 5
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 2
for field in item:
if isinstance(item[field], np.ndarray):
if item[field].dtype == np.float32:
assert (item[field] ==
np.array(data_value_to_list[num_iter][field], np.float32)).all()
elif item[field].dtype == np.float64:
assert math.isclose(item[field],
np.array(data_value_to_list[num_iter][field], np.float64), rel_tol=1e-14)
else:
assert (item[field] ==
data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 5
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册