提交 a9443635 编写于 作者: J jonyguo

fix: mindpage enhance parameter check and search by filename failed

上级 aaa8d9ed
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <map> #include <map>
#include <random> #include <random>
#include <set> #include <set>
#include <sstream>
#include <string> #include <string>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
...@@ -117,6 +118,12 @@ const char kPoint = '.'; ...@@ -117,6 +118,12 @@ const char kPoint = '.';
// field type used by check schema validation // field type used by check schema validation
const std::set<std::string> kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"}; const std::set<std::string> kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"};
// can be searched field list
const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"};
// number field list
const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};
/// \brief split a string using a character /// \brief split a string using a character
/// \param[in] field target string /// \param[in] field target string
/// \param[in] separator a character for spliting /// \param[in] separator a character for spliting
......
...@@ -42,11 +42,11 @@ class ShardIndexGenerator { ...@@ -42,11 +42,11 @@ class ShardIndexGenerator {
~ShardIndexGenerator() {} ~ShardIndexGenerator() {}
/// \brief fetch value in json by field path /// \brief fetch value in json by field name
/// \param[in] field_path /// \param[in] field
/// \param[in] schema /// \param[in] input
/// \return the vector of value /// \return pair<MSRStatus, value>
static std::vector<std::string> GetField(const std::string &field_path, json schema); std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input);
/// \brief fetch field type in schema n by field path /// \brief fetch field type in schema n by field path
/// \param[in] field_path /// \param[in] field_path
......
...@@ -38,7 +38,7 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe ...@@ -38,7 +38,7 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
MSRStatus ShardIndexGenerator::Build() { MSRStatus ShardIndexGenerator::Build() {
ShardHeader header = ShardHeader(); ShardHeader header = ShardHeader();
if (header.Build(file_path_) != SUCCESS) { if (header.Build(file_path_) != SUCCESS) {
MS_LOG(ERROR) << "Build shard schema failed"; MS_LOG(ERROR) << "Build shard schema failed.";
return FAILED; return FAILED;
} }
shard_header_ = header; shard_header_ = header;
...@@ -46,35 +46,49 @@ MSRStatus ShardIndexGenerator::Build() { ...@@ -46,35 +46,49 @@ MSRStatus ShardIndexGenerator::Build() {
return SUCCESS; return SUCCESS;
} }
std::vector<std::string> ShardIndexGenerator::GetField(const string &field_path, json schema) { std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) {
std::vector<std::string> field_name = StringSplit(field_path, kPoint); if (field.empty()) {
std::vector<std::string> res; MS_LOG(ERROR) << "The input field is None.";
if (schema.empty()) { return {FAILED, ""};
res.emplace_back("null");
return res;
} }
for (uint64_t i = 0; i < field_name.size(); i++) {
// Check if field is part of an array of objects if (input.empty()) {
auto &child = schema.at(field_name[i]); MS_LOG(ERROR) << "The input json is None.";
if (child.is_array() && !child.empty() && child[0].is_object()) { return {FAILED, ""};
schema = schema[field_name[i]]; }
std::string new_field_path;
for (uint64_t j = i + 1; j < field_name.size(); j++) { // parameter input does not contain the field
if (j > i + 1) new_field_path += '.'; if (input.find(field) == input.end()) {
new_field_path += field_name[j]; MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input;
} return {FAILED, ""};
// Return multiple field data since multiple objects in array }
for (auto &single_schema : schema) {
auto child_res = GetField(new_field_path, single_schema); // schema does not contain the field
res.insert(res.end(), child_res.begin(), child_res.end()); auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"];
} if (schema.find(field) == schema.end()) {
return res; MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
return {FAILED, ""};
}
// field should be scalar type
if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) {
MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable";
return {FAILED, ""};
}
if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
auto schema_field_options = schema[field];
if (schema_field_options.find("shape") == schema_field_options.end()) {
return {SUCCESS, input[field].dump()};
} else {
// field with shape option
MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable";
return {FAILED, ""};
} }
schema = schema.at(field_name[i]);
} }
// Return vector of one field data (not array of objects) // the field type is string in here
return std::vector<std::string>{schema.dump()}; return {SUCCESS, input[field].get<std::string>()};
} }
std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
...@@ -304,6 +318,7 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( ...@@ -304,6 +318,7 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
const auto &place_holder = std::get<0>(field); const auto &place_holder = std::get<0>(field);
const auto &field_type = std::get<1>(field); const auto &field_type = std::get<1>(field);
const auto &field_value = std::get<2>(field); const auto &field_value = std::get<2>(field);
int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder)); int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
if (field_type == "INTEGER") { if (field_type == "INTEGER") {
if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) { if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) {
...@@ -463,17 +478,24 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s ...@@ -463,17 +478,24 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
if (field.first >= schema_detail.size()) { if (field.first >= schema_detail.size()) {
return {FAILED, {}}; return {FAILED, {}};
} }
auto field_value = GetField(field.second, schema_detail[field.first]); auto field_value = GetValueByField(field.second, schema_detail[field.first]);
if (field_value.first != SUCCESS) {
MS_LOG(ERROR) << "Get value from json by field name failed";
return {FAILED, {}};
}
auto result = shard_header_.GetSchemaByID(field.first); auto result = shard_header_.GetSchemaByID(field.first);
if (result.second != SUCCESS) { if (result.second != SUCCESS) {
return {FAILED, {}}; return {FAILED, {}};
} }
std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"])); std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"]));
auto ret = GenerateFieldName(field); auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) { if (ret.first != SUCCESS) {
return {FAILED, {}}; return {FAILED, {}};
} }
fields.emplace_back(ret.second, field_type, field_value[0]);
fields.emplace_back(ret.second, field_type, field_value.second);
} }
return {SUCCESS, std::move(fields)}; return {SUCCESS, std::move(fields)};
} }
......
...@@ -25,6 +25,15 @@ using mindspore::MsLogLevel::INFO; ...@@ -25,6 +25,15 @@ using mindspore::MsLogLevel::INFO;
namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
template <class Type>
// convert the string to exactly number type (int32_t/int64_t/float/double)
Type StringToNum(const std::string &str) {
std::istringstream iss(str);
Type num;
iss >> num;
return num;
}
ShardReader::ShardReader() { ShardReader::ShardReader() {
task_id_ = 0; task_id_ = 0;
deliver_id_ = 0; deliver_id_ = 0;
...@@ -259,16 +268,25 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str ...@@ -259,16 +268,25 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
} }
column_values[shard_id].emplace_back(tmp); column_values[shard_id].emplace_back(tmp);
} else { } else {
string json_str = "{"; json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) { for (unsigned int j = 0; j < columns.size(); ++j) {
// construct the string json "f1": value // construct json "f1": value
json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j + 3]; auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
if (j < columns.size() - 1) {
json_str += ","; // convert the string to base type by schema
if (schema[columns[j]]["type"] == "int32") {
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "int64") {
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "float32") {
construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "float64") {
construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
} else {
construct_json[columns[j]] = std::string(labels[i][j + 3]);
} }
} }
json_str += "}"; column_values[shard_id].emplace_back(construct_json);
column_values[shard_id].emplace_back(json::parse(json_str));
} }
} }
...@@ -402,7 +420,16 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int ...@@ -402,7 +420,16 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
// whether use index search // whether use index search
if (!criteria.first.empty()) { if (!criteria.first.empty()) {
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second; auto schema = shard_header_->get_schemas()[0]->GetSchema();
// not number field should add '' in sql
if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
sql +=
" AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
} else {
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
criteria.second + "'";
}
} }
sql += ";"; sql += ";";
std::vector<std::vector<std::string>> image_offsets; std::vector<std::vector<std::string>> image_offsets;
...@@ -603,16 +630,25 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int ...@@ -603,16 +630,25 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
std::vector<json> ret; std::vector<json> ret;
for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{}); for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{});
for (unsigned int i = 0; i < labels.size(); ++i) { for (unsigned int i = 0; i < labels.size(); ++i) {
string json_str = "{"; json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) { for (unsigned int j = 0; j < columns.size(); ++j) {
// construct string json "f1": value // construct json "f1": value
json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j]; auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
if (j < columns.size() - 1) {
json_str += ","; // convert the string to base type by schema
if (schema[columns[j]]["type"] == "int32") {
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]);
} else if (schema[columns[j]]["type"] == "int64") {
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]);
} else if (schema[columns[j]]["type"] == "float32") {
construct_json[columns[j]] = StringToNum<float>(labels[i][j]);
} else if (schema[columns[j]]["type"] == "float64") {
construct_json[columns[j]] = StringToNum<double>(labels[i][j]);
} else {
construct_json[columns[j]] = std::string(labels[i][j]);
} }
} }
json_str += "}"; ret[i] = construct_json;
ret[i] = json::parse(json_str);
} }
return {SUCCESS, ret}; return {SUCCESS, ret};
} }
......
...@@ -311,14 +311,23 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS ...@@ -311,14 +311,23 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS
MS_LOG(ERROR) << "Get category info"; MS_LOG(ERROR) << "Get category info";
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
} }
// category_name to category_id
int64_t category_id = -1;
for (const auto &categories : ret.second) { for (const auto &categories : ret.second) {
if (std::get<1>(categories) == category_name) { std::string categories_name = std::get<1>(categories);
auto result = ReadAllAtPageById(std::get<0>(categories), page_no, n_rows_of_page);
return {SUCCESS, result.second}; if (categories_name == category_name) {
category_id = std::get<0>(categories);
break;
} }
} }
return {SUCCESS, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; if (category_id == -1) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
return ReadAllAtPageById(category_id, page_no, n_rows_of_page);
} }
std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy( std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy(
......
...@@ -133,15 +133,15 @@ class MindPage: ...@@ -133,15 +133,15 @@ class MindPage:
Raises: Raises:
ParamValueError: If any parameter is invalid. ParamValueError: If any parameter is invalid.
MRMFetchDataError: If failed to read by category id. MRMFetchDataError: If failed to fetch data by category.
MRMUnsupportedSchemaError: If schema is invalid. MRMUnsupportedSchemaError: If schema is invalid.
""" """
if category_id < 0: if not isinstance(category_id, int) or category_id < 0:
raise ParamValueError("Category id should be greater than 0.") raise ParamValueError("Category id should be int and greater than or equal to 0.")
if page < 0: if not isinstance(page, int) or page < 0:
raise ParamValueError("Page should be greater than 0.") raise ParamValueError("Page should be int and greater than or equal to 0.")
if num_row < 0: if not isinstance(num_row, int) or num_row <= 0:
raise ParamValueError("num_row should be greater than 0.") raise ParamValueError("num_row should be int and greater than 0.")
return self._segment.read_at_page_by_id(category_id, page, num_row) return self._segment.read_at_page_by_id(category_id, page, num_row)
def read_at_page_by_name(self, category_name, page, num_row): def read_at_page_by_name(self, category_name, page, num_row):
...@@ -157,8 +157,10 @@ class MindPage: ...@@ -157,8 +157,10 @@ class MindPage:
Returns: Returns:
str, read at page. str, read at page.
""" """
if page < 0: if not isinstance(category_name, str):
raise ParamValueError("Page should be greater than 0.") raise ParamValueError("Category name should be str.")
if num_row < 0: if not isinstance(page, int) or page < 0:
raise ParamValueError("num_row should be greater than 0.") raise ParamValueError("Page should be int and greater than or equal to 0.")
if not isinstance(num_row, int) or num_row <= 0:
raise ParamValueError("num_row should be int and greater than 0.")
return self._segment.read_at_page_by_name(category_name, page, num_row) return self._segment.read_at_page_by_name(category_name, page, num_row)
...@@ -53,6 +53,7 @@ class TestShardIndexGenerator : public UT::Common { ...@@ -53,6 +53,7 @@ class TestShardIndexGenerator : public UT::Common {
TestShardIndexGenerator() {} TestShardIndexGenerator() {}
}; };
/*
TEST_F(TestShardIndexGenerator, GetField) { TEST_F(TestShardIndexGenerator, GetField) {
MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field"); MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field");
...@@ -82,6 +83,8 @@ TEST_F(TestShardIndexGenerator, GetField) { ...@@ -82,6 +83,8 @@ TEST_F(TestShardIndexGenerator, GetField) {
} }
} }
} }
*/
TEST_F(TestShardIndexGenerator, TakeFieldType) { TEST_F(TestShardIndexGenerator, TakeFieldType) {
MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type"); MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type");
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test mindrecord base""" """test mindrecord base"""
import numpy as np
import os import os
import uuid import uuid
from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
...@@ -25,6 +26,105 @@ CV2_FILE_NAME = "./imagenet_loop.mindrecord" ...@@ -25,6 +26,105 @@ CV2_FILE_NAME = "./imagenet_loop.mindrecord"
CV3_FILE_NAME = "./imagenet_append.mindrecord" CV3_FILE_NAME = "./imagenet_append.mindrecord"
NLP_FILE_NAME = "./aclImdb.mindrecord" NLP_FILE_NAME = "./aclImdb.mindrecord"
def test_write_read_process():
mindrecord_file_name = "test.mindrecord"
data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
"segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
"data": bytes("image bytes abc", encoding='UTF-8')},
{"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
"segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
"data": bytes("image bytes def", encoding='UTF-8')},
{"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
"segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
"data": bytes("image bytes ghi", encoding='UTF-8')},
{"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
"segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
"data": bytes("image bytes jkl", encoding='UTF-8')},
{"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
"segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
"data": bytes("image bytes mno", encoding='UTF-8')},
{"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
"segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
"data": bytes("image bytes pqr", encoding='UTF-8')}
]
writer = FileWriter(mindrecord_file_name)
schema = {"file_name": {"type": "string"},
"label": {"type": "int32"},
"score": {"type": "float64"},
"mask": {"type": "int64", "shape": [-1]},
"segments": {"type": "float32", "shape": [2, 2]},
"data": {"type": "bytes"}}
writer.add_schema(schema, "data is so cool")
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name)
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 6
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
def test_write_read_process_with_define_index_field():
mindrecord_file_name = "test.mindrecord"
data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
"segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
"data": bytes("image bytes abc", encoding='UTF-8')},
{"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
"segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
"data": bytes("image bytes def", encoding='UTF-8')},
{"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
"segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
"data": bytes("image bytes ghi", encoding='UTF-8')},
{"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
"segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
"data": bytes("image bytes jkl", encoding='UTF-8')},
{"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
"segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
"data": bytes("image bytes mno", encoding='UTF-8')},
{"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
"segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
"data": bytes("image bytes pqr", encoding='UTF-8')}
]
writer = FileWriter(mindrecord_file_name)
schema = {"file_name": {"type": "string"},
"label": {"type": "int32"},
"score": {"type": "float64"},
"mask": {"type": "int64", "shape": [-1]},
"segments": {"type": "float32", "shape": [2, 2]},
"data": {"type": "bytes"}}
writer.add_schema(schema, "data is so cool")
writer.add_index(["label"])
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name)
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 6
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
def test_cv_file_writer_tutorial(): def test_cv_file_writer_tutorial():
"""tutorial for cv dataset writer.""" """tutorial for cv dataset writer."""
writer = FileWriter(CV_FILE_NAME, FILES_NUM) writer = FileWriter(CV_FILE_NAME, FILES_NUM)
...@@ -137,6 +237,51 @@ def test_cv_page_reader_tutorial(): ...@@ -137,6 +237,51 @@ def test_cv_page_reader_tutorial():
assert len(row1[0]) == 3 assert len(row1[0]) == 3
assert row1[0]['label'] == 822 assert row1[0]['label'] == 822
def test_cv_page_reader_tutorial_by_file_name():
"""tutorial for cv page reader."""
reader = MindPage(CV_FILE_NAME + "0")
fields = reader.get_category_fields()
assert fields == ['file_name', 'label'],\
'failed on getting candidate category fields.'
ret = reader.set_category_field("file_name")
assert ret == SUCCESS, 'failed on setting category field.'
info = reader.read_category_info()
logger.info("category info: {}".format(info))
row = reader.read_at_page_by_id(0, 0, 1)
assert len(row) == 1
assert len(row[0]) == 3
assert row[0]['label'] == 490
row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
assert len(row1) == 1
assert len(row1[0]) == 3
assert row1[0]['label'] == 13
def test_cv_page_reader_tutorial_new_api():
"""tutorial for cv page reader."""
reader = MindPage(CV_FILE_NAME + "0")
fields = reader.candidate_fields
assert fields == ['file_name', 'label'],\
'failed on getting candidate category fields.'
reader.category_field = "file_name"
info = reader.read_category_info()
logger.info("category info: {}".format(info))
row = reader.read_at_page_by_id(0, 0, 1)
assert len(row) == 1
assert len(row[0]) == 3
assert row[0]['label'] == 490
row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
assert len(row1) == 1
assert len(row1[0]) == 3
assert row1[0]['label'] == 13
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
for x in paths: for x in paths:
......
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
"""test mindrecord exception""" """test mindrecord exception"""
import os import os
import pytest import pytest
from mindspore.mindrecord import FileWriter, FileReader, MindPage from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \
MRMFetchDataError
from mindspore import log as logger from mindspore import log as logger
from utils import get_data from utils import get_data
...@@ -286,3 +287,67 @@ def test_add_index_without_add_schema(): ...@@ -286,3 +287,67 @@ def test_add_index_without_add_schema():
fw = FileWriter(CV_FILE_NAME) fw = FileWriter(CV_FILE_NAME)
fw.add_index(["label"]) fw.add_index(["label"])
assert 'Failed to get meta info' in str(err.value) assert 'Failed to get meta info' in str(err.value)
def test_mindpage_pageno_pagesize_not_int():
"""test page reader when some partition does not exist."""
create_cv_mindrecord(4)
reader = MindPage(CV_FILE_NAME + "0")
fields = reader.get_category_fields()
assert fields == ['file_name', 'label'],\
'failed on getting candidate category fields.'
ret = reader.set_category_field("label")
assert ret == SUCCESS, 'failed on setting category field.'
info = reader.read_category_info()
logger.info("category info: {}".format(info))
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_id(0, "0", 1)
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_id(0, 0, "b")
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_name("822", "e", 1)
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_name("822", 0, "qwer")
with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
reader.read_at_page_by_id(99999, 0, 1)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
def test_mindpage_filename_not_exist():
"""test page reader when some partition does not exist."""
create_cv_mindrecord(4)
reader = MindPage(CV_FILE_NAME + "0")
fields = reader.get_category_fields()
assert fields == ['file_name', 'label'],\
'failed on getting candidate category fields.'
ret = reader.set_category_field("file_name")
assert ret == SUCCESS, 'failed on setting category field.'
info = reader.read_category_info()
logger.info("category info: {}".format(info))
with pytest.raises(MRMFetchDataError) as err:
reader.read_at_page_by_id(9999, 0, 1)
with pytest.raises(MRMFetchDataError) as err:
reader.read_at_page_by_name("abc.jpg", 0, 1)
with pytest.raises(ParamValueError) as err:
reader.read_at_page_by_name(1, 0, 1)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册