提交 a9443635 编写于 作者: J jonyguo

fix: mindpage enhance parameter check and search by filename failed

上级 aaa8d9ed
......@@ -33,6 +33,7 @@
#include <map>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
......@@ -117,6 +118,12 @@ const char kPoint = '.';
// field type used by check schema validation
const std::set<std::string> kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"};
// can be searched field list
const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"};
// number field list
const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};
/// \brief split a string using a character
/// \param[in] field target string
/// \param[in] separator a character for spliting
......
......@@ -42,11 +42,11 @@ class ShardIndexGenerator {
~ShardIndexGenerator() {}
/// \brief fetch value in json by field path
/// \param[in] field_path
/// \param[in] schema
/// \return the vector of value
static std::vector<std::string> GetField(const std::string &field_path, json schema);
/// \brief fetch value in json by field name
/// \param[in] field
/// \param[in] input
/// \return pair<MSRStatus, value>
std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input);
/// \brief fetch field type in schema n by field path
/// \param[in] field_path
......
......@@ -38,7 +38,7 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
MSRStatus ShardIndexGenerator::Build() {
ShardHeader header = ShardHeader();
if (header.Build(file_path_) != SUCCESS) {
MS_LOG(ERROR) << "Build shard schema failed";
MS_LOG(ERROR) << "Build shard schema failed.";
return FAILED;
}
shard_header_ = header;
......@@ -46,35 +46,49 @@ MSRStatus ShardIndexGenerator::Build() {
return SUCCESS;
}
std::vector<std::string> ShardIndexGenerator::GetField(const string &field_path, json schema) {
std::vector<std::string> field_name = StringSplit(field_path, kPoint);
std::vector<std::string> res;
if (schema.empty()) {
res.emplace_back("null");
return res;
std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) {
if (field.empty()) {
MS_LOG(ERROR) << "The input field is None.";
return {FAILED, ""};
}
for (uint64_t i = 0; i < field_name.size(); i++) {
// Check if field is part of an array of objects
auto &child = schema.at(field_name[i]);
if (child.is_array() && !child.empty() && child[0].is_object()) {
schema = schema[field_name[i]];
std::string new_field_path;
for (uint64_t j = i + 1; j < field_name.size(); j++) {
if (j > i + 1) new_field_path += '.';
new_field_path += field_name[j];
}
// Return multiple field data since multiple objects in array
for (auto &single_schema : schema) {
auto child_res = GetField(new_field_path, single_schema);
res.insert(res.end(), child_res.begin(), child_res.end());
}
return res;
if (input.empty()) {
MS_LOG(ERROR) << "The input json is None.";
return {FAILED, ""};
}
schema = schema.at(field_name[i]);
// parameter input does not contain the field
if (input.find(field) == input.end()) {
MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input;
return {FAILED, ""};
}
// schema does not contain the field
auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"];
if (schema.find(field) == schema.end()) {
MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
return {FAILED, ""};
}
// Return vector of one field data (not array of objects)
return std::vector<std::string>{schema.dump()};
// field should be scalar type
if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) {
MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable";
return {FAILED, ""};
}
if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
auto schema_field_options = schema[field];
if (schema_field_options.find("shape") == schema_field_options.end()) {
return {SUCCESS, input[field].dump()};
} else {
// field with shape option
MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable";
return {FAILED, ""};
}
}
// the field type is string in here
return {SUCCESS, input[field].get<std::string>()};
}
std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
......@@ -304,6 +318,7 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
const auto &place_holder = std::get<0>(field);
const auto &field_type = std::get<1>(field);
const auto &field_value = std::get<2>(field);
int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
if (field_type == "INTEGER") {
if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) {
......@@ -463,17 +478,24 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
if (field.first >= schema_detail.size()) {
return {FAILED, {}};
}
auto field_value = GetField(field.second, schema_detail[field.first]);
auto field_value = GetValueByField(field.second, schema_detail[field.first]);
if (field_value.first != SUCCESS) {
MS_LOG(ERROR) << "Get value from json by field name failed";
return {FAILED, {}};
}
auto result = shard_header_.GetSchemaByID(field.first);
if (result.second != SUCCESS) {
return {FAILED, {}};
}
std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"]));
auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) {
return {FAILED, {}};
}
fields.emplace_back(ret.second, field_type, field_value[0]);
fields.emplace_back(ret.second, field_type, field_value.second);
}
return {SUCCESS, std::move(fields)};
}
......
......@@ -25,6 +25,15 @@ using mindspore::MsLogLevel::INFO;
namespace mindspore {
namespace mindrecord {
template <class Type>
// convert the string to exactly number type (int32_t/int64_t/float/double)
Type StringToNum(const std::string &str) {
std::istringstream iss(str);
Type num;
iss >> num;
return num;
}
ShardReader::ShardReader() {
task_id_ = 0;
deliver_id_ = 0;
......@@ -259,16 +268,25 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
}
column_values[shard_id].emplace_back(tmp);
} else {
string json_str = "{";
json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) {
// construct the string json "f1": value
json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j + 3];
if (j < columns.size() - 1) {
json_str += ",";
// construct json "f1": value
auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
// convert the string to base type by schema
if (schema[columns[j]]["type"] == "int32") {
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "int64") {
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "float32") {
construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "float64") {
construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
} else {
construct_json[columns[j]] = std::string(labels[i][j + 3]);
}
}
json_str += "}";
column_values[shard_id].emplace_back(json::parse(json_str));
column_values[shard_id].emplace_back(construct_json);
}
}
......@@ -402,7 +420,16 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
// whether use index search
if (!criteria.first.empty()) {
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
auto schema = shard_header_->get_schemas()[0]->GetSchema();
// not number field should add '' in sql
if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
sql +=
" AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
} else {
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
criteria.second + "'";
}
}
sql += ";";
std::vector<std::vector<std::string>> image_offsets;
......@@ -603,16 +630,25 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
std::vector<json> ret;
for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{});
for (unsigned int i = 0; i < labels.size(); ++i) {
string json_str = "{";
json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) {
// construct string json "f1": value
json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j];
if (j < columns.size() - 1) {
json_str += ",";
// construct json "f1": value
auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
// convert the string to base type by schema
if (schema[columns[j]]["type"] == "int32") {
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]);
} else if (schema[columns[j]]["type"] == "int64") {
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]);
} else if (schema[columns[j]]["type"] == "float32") {
construct_json[columns[j]] = StringToNum<float>(labels[i][j]);
} else if (schema[columns[j]]["type"] == "float64") {
construct_json[columns[j]] = StringToNum<double>(labels[i][j]);
} else {
construct_json[columns[j]] = std::string(labels[i][j]);
}
}
json_str += "}";
ret[i] = json::parse(json_str);
ret[i] = construct_json;
}
return {SUCCESS, ret};
}
......
......@@ -311,14 +311,23 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS
MS_LOG(ERROR) << "Get category info";
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
// category_name to category_id
int64_t category_id = -1;
for (const auto &categories : ret.second) {
if (std::get<1>(categories) == category_name) {
auto result = ReadAllAtPageById(std::get<0>(categories), page_no, n_rows_of_page);
return {SUCCESS, result.second};
std::string categories_name = std::get<1>(categories);
if (categories_name == category_name) {
category_id = std::get<0>(categories);
break;
}
}
return {SUCCESS, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
if (category_id == -1) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
return ReadAllAtPageById(category_id, page_no, n_rows_of_page);
}
std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy(
......
......@@ -133,15 +133,15 @@ class MindPage:
Raises:
ParamValueError: If any parameter is invalid.
MRMFetchDataError: If failed to read by category id.
MRMFetchDataError: If failed to fetch data by category.
MRMUnsupportedSchemaError: If schema is invalid.
"""
if category_id < 0:
raise ParamValueError("Category id should be greater than 0.")
if page < 0:
raise ParamValueError("Page should be greater than 0.")
if num_row < 0:
raise ParamValueError("num_row should be greater than 0.")
if not isinstance(category_id, int) or category_id < 0:
raise ParamValueError("Category id should be int and greater than or equal to 0.")
if not isinstance(page, int) or page < 0:
raise ParamValueError("Page should be int and greater than or equal to 0.")
if not isinstance(num_row, int) or num_row <= 0:
raise ParamValueError("num_row should be int and greater than 0.")
return self._segment.read_at_page_by_id(category_id, page, num_row)
def read_at_page_by_name(self, category_name, page, num_row):
......@@ -157,8 +157,10 @@ class MindPage:
Returns:
str, read at page.
"""
if page < 0:
raise ParamValueError("Page should be greater than 0.")
if num_row < 0:
raise ParamValueError("num_row should be greater than 0.")
if not isinstance(category_name, str):
raise ParamValueError("Category name should be str.")
if not isinstance(page, int) or page < 0:
raise ParamValueError("Page should be int and greater than or equal to 0.")
if not isinstance(num_row, int) or num_row <= 0:
raise ParamValueError("num_row should be int and greater than 0.")
return self._segment.read_at_page_by_name(category_name, page, num_row)
......@@ -53,6 +53,7 @@ class TestShardIndexGenerator : public UT::Common {
TestShardIndexGenerator() {}
};
/*
TEST_F(TestShardIndexGenerator, GetField) {
MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field");
......@@ -82,6 +83,8 @@ TEST_F(TestShardIndexGenerator, GetField) {
}
}
}
*/
TEST_F(TestShardIndexGenerator, TakeFieldType) {
MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type");
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""test mindrecord base"""
import numpy as np
import os
import uuid
from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
......@@ -25,6 +26,105 @@ CV2_FILE_NAME = "./imagenet_loop.mindrecord"
CV3_FILE_NAME = "./imagenet_append.mindrecord"
NLP_FILE_NAME = "./aclImdb.mindrecord"
def test_write_read_process():
mindrecord_file_name = "test.mindrecord"
data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
"segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
"data": bytes("image bytes abc", encoding='UTF-8')},
{"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
"segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
"data": bytes("image bytes def", encoding='UTF-8')},
{"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
"segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
"data": bytes("image bytes ghi", encoding='UTF-8')},
{"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
"segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
"data": bytes("image bytes jkl", encoding='UTF-8')},
{"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
"segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
"data": bytes("image bytes mno", encoding='UTF-8')},
{"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
"segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
"data": bytes("image bytes pqr", encoding='UTF-8')}
]
writer = FileWriter(mindrecord_file_name)
schema = {"file_name": {"type": "string"},
"label": {"type": "int32"},
"score": {"type": "float64"},
"mask": {"type": "int64", "shape": [-1]},
"segments": {"type": "float32", "shape": [2, 2]},
"data": {"type": "bytes"}}
writer.add_schema(schema, "data is so cool")
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name)
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 6
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
def test_write_read_process_with_define_index_field():
mindrecord_file_name = "test.mindrecord"
data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
"segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
"data": bytes("image bytes abc", encoding='UTF-8')},
{"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
"segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
"data": bytes("image bytes def", encoding='UTF-8')},
{"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
"segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
"data": bytes("image bytes ghi", encoding='UTF-8')},
{"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
"segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
"data": bytes("image bytes jkl", encoding='UTF-8')},
{"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
"segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
"data": bytes("image bytes mno", encoding='UTF-8')},
{"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
"segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
"data": bytes("image bytes pqr", encoding='UTF-8')}
]
writer = FileWriter(mindrecord_file_name)
schema = {"file_name": {"type": "string"},
"label": {"type": "int32"},
"score": {"type": "float64"},
"mask": {"type": "int64", "shape": [-1]},
"segments": {"type": "float32", "shape": [2, 2]},
"data": {"type": "bytes"}}
writer.add_schema(schema, "data is so cool")
writer.add_index(["label"])
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name)
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 6
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
def test_cv_file_writer_tutorial():
"""tutorial for cv dataset writer."""
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
......@@ -137,6 +237,51 @@ def test_cv_page_reader_tutorial():
assert len(row1[0]) == 3
assert row1[0]['label'] == 822
def test_cv_page_reader_tutorial_by_file_name():
"""tutorial for cv page reader."""
reader = MindPage(CV_FILE_NAME + "0")
fields = reader.get_category_fields()
assert fields == ['file_name', 'label'],\
'failed on getting candidate category fields.'
ret = reader.set_category_field("file_name")
assert ret == SUCCESS, 'failed on setting category field.'
info = reader.read_category_info()
logger.info("category info: {}".format(info))
row = reader.read_at_page_by_id(0, 0, 1)
assert len(row) == 1
assert len(row[0]) == 3
assert row[0]['label'] == 490
row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
assert len(row1) == 1
assert len(row1[0]) == 3
assert row1[0]['label'] == 13
def test_cv_page_reader_tutorial_new_api():
"""tutorial for cv page reader."""
reader = MindPage(CV_FILE_NAME + "0")
fields = reader.candidate_fields
assert fields == ['file_name', 'label'],\
'failed on getting candidate category fields.'
reader.category_field = "file_name"
info = reader.read_category_info()
logger.info("category info: {}".format(info))
row = reader.read_at_page_by_id(0, 0, 1)
assert len(row) == 1
assert len(row[0]) == 3
assert row[0]['label'] == 490
row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
assert len(row1) == 1
assert len(row1[0]) == 3
assert row1[0]['label'] == 13
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
......
......@@ -15,8 +15,9 @@
"""test mindrecord exception"""
import os
import pytest
from mindspore.mindrecord import FileWriter, FileReader, MindPage
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError
from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \
MRMFetchDataError
from mindspore import log as logger
from utils import get_data
......@@ -286,3 +287,67 @@ def test_add_index_without_add_schema():
fw = FileWriter(CV_FILE_NAME)
fw.add_index(["label"])
assert 'Failed to get meta info' in str(err.value)
def test_mindpage_pageno_pagesize_not_int():
"""test page reader when some partition does not exist."""
create_cv_mindrecord(4)
reader = MindPage(CV_FILE_NAME + "0")
fields = reader.get_category_fields()
assert fields == ['file_name', 'label'],\
'failed on getting candidate category fields.'
ret = reader.set_category_field("label")
assert ret == SUCCESS, 'failed on setting category field.'
info = reader.read_category_info()
logger.info("category info: {}".format(info))
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_id(0, "0", 1)
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_id(0, 0, "b")
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_name("822", "e", 1)
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_name("822", 0, "qwer")
with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
reader.read_at_page_by_id(99999, 0, 1)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
def test_mindpage_filename_not_exist():
"""test page reader when some partition does not exist."""
create_cv_mindrecord(4)
reader = MindPage(CV_FILE_NAME + "0")
fields = reader.get_category_fields()
assert fields == ['file_name', 'label'],\
'failed on getting candidate category fields.'
ret = reader.set_category_field("file_name")
assert ret == SUCCESS, 'failed on setting category field.'
info = reader.read_category_info()
logger.info("category info: {}".format(info))
with pytest.raises(MRMFetchDataError) as err:
reader.read_at_page_by_id(9999, 0, 1)
with pytest.raises(MRMFetchDataError) as err:
reader.read_at_page_by_name("abc.jpg", 0, 1)
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_name(1, 0, 1)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册