diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index 9cd02d9120f2a5bb8fa4343ff06cfdc54bf974bd..dd34615f7e0a59258584943a73e2fce00e52747d 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -346,7 +346,8 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; return; } - MS_LOG(INFO) << "Get" << static_cast(columns.size()) << " records from shard " << shard_id << " index."; + MS_LOG(INFO) << "Get " << static_cast(columns.size()) << " records from shard " << shard_id << " index."; + std::lock_guard lck(shard_locker_); for (int i = 0; i < static_cast(columns.size()); ++i) { categories.emplace(columns[i][0]); } diff --git a/tests/ut/cpp/mindrecord/ut_common.cc b/tests/ut/cpp/mindrecord/ut_common.cc index 76aa5fc503263936f5e0b579e167cb00e429b9c5..2d2d69bd546db389b2fa1bbf224450cce1cbdabc 100644 --- a/tests/ut/cpp/mindrecord/ut_common.cc +++ b/tests/ut/cpp/mindrecord/ut_common.cc @@ -16,9 +16,9 @@ #include "ut_common.h" -using mindspore::MsLogLevel::ERROR; -using mindspore::ExceptionType::NoExceptionType; using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; namespace mindspore { namespace mindrecord { @@ -33,23 +33,6 @@ void Common::SetUp() {} void Common::TearDown() {} -void Common::LoadData(const std::string &directory, std::vector &json_buffer, const int max_num) { - int count = 0; - string input_path = directory; - ifstream infile(input_path); - if (!infile.is_open()) { - MS_LOG(ERROR) << "can not open the file "; - return; - } - string temp; - while (getline(infile, temp) && count != max_num) { - count++; - json j = json::parse(temp); - json_buffer.push_back(j); - } - infile.close(); -} - #ifdef __cplusplus #if __cplusplus } @@ -70,5 +53,353 @@ const std::string FormatInfo(const std::string &message, uint32_t message_total_ std::string right_padding(static_cast(floor(padding_length / 2.0)), '='); return left_padding + part_message + right_padding; } + +void LoadData(const std::string &directory, std::vector &json_buffer, const int max_num) { + int count = 0; + string input_path = directory; + ifstream infile(input_path); + if (!infile.is_open()) { + MS_LOG(ERROR) << "can not open the file "; + return; + } + string temp; + while (getline(infile, temp) && count != max_num) { + count++; + json j = json::parse(temp); + json_buffer.push_back(j); + } + infile.close(); +} + +void LoadDataFromImageNet(const std::string &directory, std::vector &json_buffer, const int max_num) { + int count = 0; + string input_path = directory; + ifstream infile(input_path); + if (!infile.is_open()) { + MS_LOG(ERROR) << "can not open the file "; + return; + } + string temp; + string filename; + string label; + json j; + while (getline(infile, temp) && count != max_num) { + count++; + std::size_t pos = temp.find(",", 0); + if (pos != std::string::npos) { + j["file_name"] = temp.substr(0, pos); + j["label"] = atoi(common::SafeCStr(temp.substr(pos + 1, temp.length()))); + json_buffer.push_back(j); + } + } + infile.close(); +} + +int Img2DataUint8(const std::vector &img_absolute_path, std::vector> &bin_data) { + for (auto &file : img_absolute_path) { + // read image file + std::ifstream in(common::SafeCStr(file), std::ios::in | std::ios::binary | std::ios::ate); + if (!in) { + MS_LOG(ERROR) << common::SafeCStr(file) << " is not a directory or not exist!"; + return -1; + } + + // get the file size + uint64_t size = in.tellg(); + in.seekg(0, std::ios::beg); + std::vector file_data(size); + in.read(reinterpret_cast(&file_data[0]), size); + in.close(); + bin_data.push_back(file_data); + } + return 0; +} + +int GetAbsoluteFiles(std::string directory, std::vector &files_absolute_path) { + DIR *dir = opendir(common::SafeCStr(directory)); + if (dir == nullptr) { + MS_LOG(ERROR) << common::SafeCStr(directory) << " is not a directory or not exist!"; + return -1; + } + struct dirent *d_ent = nullptr; + char dot[3] = "."; + char dotdot[6] = ".."; + while ((d_ent = readdir(dir)) != nullptr) { + if ((strcmp(d_ent->d_name, dot) != 0) && (strcmp(d_ent->d_name, dotdot) != 0)) { + if (d_ent->d_type == DT_DIR) { + std::string new_directory = directory + std::string("/") + std::string(d_ent->d_name); + if (directory[directory.length() - 1] == '/') { + new_directory = directory + string(d_ent->d_name); + } + if (-1 == GetAbsoluteFiles(new_directory, files_absolute_path)) { + closedir(dir); + return -1; + } + } else { + std::string absolute_path = directory + std::string("/") + std::string(d_ent->d_name); + if (directory[directory.length() - 1] == '/') { + absolute_path = directory + std::string(d_ent->d_name); + } + files_absolute_path.push_back(absolute_path); + } + } + } + closedir(dir); + return 0; +} + +void ShardWriterImageNet() { + MS_LOG(INFO) << common::SafeCStr(FormatInfo("Write imageNet")); + + // load binary data + std::vector> bin_data; + std::vector filenames; + if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { + MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; + return; + } + mindrecord::Img2DataUint8(filenames, bin_data); + + // init shardHeader + ShardHeader header_data; + MS_LOG(INFO) << "Init ShardHeader Already."; + + // create schema + json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; + std::shared_ptr anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); + if (anno_schema == nullptr) { + MS_LOG(ERROR) << "Build annotation schema failed"; + return; + } + + // add schema to shardHeader + int anno_schema_id = header_data.AddSchema(anno_schema); + MS_LOG(INFO) << "Init Schema Already."; + + // create index + std::pair index_field1(anno_schema_id, "file_name"); + std::pair index_field2(anno_schema_id, "label"); + std::vector> fields; + fields.push_back(index_field1); + fields.push_back(index_field2); + + // add index to shardHeader + header_data.AddIndexFields(fields); + MS_LOG(INFO) << "Init Index Fields Already."; + // load meta data + std::vector annotations; + LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 10); + + // add data + std::map> rawdatas; + rawdatas.insert(pair>(anno_schema_id, annotations)); + MS_LOG(INFO) << "Init Images Already."; + + // init file_writer + std::vector file_names; + int file_count = 4; + for (int i = 1; i <= file_count; i++) { + file_names.emplace_back(std::string("./imagenet.shard0") + std::to_string(i)); + MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); + } + + MS_LOG(INFO) << "Init Output Files Already."; + { + ShardWriter fw_init; + fw_init.Open(file_names); + + // set shardHeader + fw_init.SetShardHeader(std::make_shared(header_data)); + + // close file_writer + fw_init.Commit(); + } + std::string filename = "./imagenet.shard01"; + { + MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; + mindrecord::ShardWriter fw; + fw.OpenForAppend(filename); + fw.WriteRawData(rawdatas, bin_data); + fw.Commit(); + } + mindrecord::ShardIndexGenerator sg{filename}; + sg.Build(); + sg.WriteToDatabase(); + + MS_LOG(INFO) << "Done create index"; +} + +void ShardWriterImageNetOneSample() { + // load binary data + std::vector> bin_data; + std::vector filenames; + if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { + MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; + return; + } + mindrecord::Img2DataUint8(filenames, bin_data); + + // init shardHeader + mindrecord::ShardHeader header_data; + MS_LOG(INFO) << "Init ShardHeader Already."; + + // create schema + json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; + std::shared_ptr anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); + if (anno_schema == nullptr) { + MS_LOG(ERROR) << "Build annotation schema failed"; + return; + } + + // add schema to shardHeader + int anno_schema_id = header_data.AddSchema(anno_schema); + MS_LOG(INFO) << "Init Schema Already."; + + // create index + std::pair index_field1(anno_schema_id, "file_name"); + std::pair index_field2(anno_schema_id, "label"); + std::vector> fields; + fields.push_back(index_field1); + fields.push_back(index_field2); + + // add index to shardHeader + header_data.AddIndexFields(fields); + MS_LOG(INFO) << "Init Index Fields Already."; + + // load meta data + std::vector annotations; + LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1); + + // add data + std::map> rawdatas; + rawdatas.insert(pair>(anno_schema_id, annotations)); + MS_LOG(INFO) << "Init Images Already."; + + // init file_writer + std::vector file_names; + for (int i = 1; i <= 4; i++) { + file_names.emplace_back(std::string("./OneSample.shard0") + std::to_string(i)); + MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); + } + + MS_LOG(INFO) << "Init Output Files Already."; + { + mindrecord::ShardWriter fw_init; + fw_init.Open(file_names); + + // set shardHeader + fw_init.SetShardHeader(std::make_shared(header_data)); + + // close file_writer + fw_init.Commit(); + } + + std::string filename = "./OneSample.shard01"; + { + MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; + mindrecord::ShardWriter fw; + fw.OpenForAppend(filename); + bin_data = std::vector>(bin_data.begin(), bin_data.begin() + 1); + fw.WriteRawData(rawdatas, bin_data); + fw.Commit(); + } + + mindrecord::ShardIndexGenerator sg{filename}; + sg.Build(); + sg.WriteToDatabase(); + MS_LOG(INFO) << "Done create index"; +} + +void ShardWriterImageNetOpenForAppend(string filename) { + for (int i = 1; i <= 4; i++) { + string filename = std::string("./OpenForAppendSample.shard0") + std::to_string(i); + string db_name = std::string("./OpenForAppendSample.shard0") + std::to_string(i) + ".db"; + remove(common::SafeCStr(filename)); + remove(common::SafeCStr(db_name)); + } + + // load binary data + std::vector> bin_data; + std::vector filenames; + if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { + MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; + return; + } + mindrecord::Img2DataUint8(filenames, bin_data); + + // init shardHeader + mindrecord::ShardHeader header_data; + MS_LOG(INFO) << "Init ShardHeader Already."; + + // create schema + json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; + std::shared_ptr anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); + if (anno_schema == nullptr) { + MS_LOG(ERROR) << "Build annotation schema failed"; + return; + } + + // add schema to shardHeader + int anno_schema_id = header_data.AddSchema(anno_schema); + MS_LOG(INFO) << "Init Schema Already."; + + // create index + std::pair index_field1(anno_schema_id, "file_name"); + std::pair index_field2(anno_schema_id, "label"); + std::vector> fields; + fields.push_back(index_field1); + fields.push_back(index_field2); + + // add index to shardHeader + header_data.AddIndexFields(fields); + MS_LOG(INFO) << "Init Index Fields Already."; + + // load meta data + std::vector annotations; + LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1); + + // add data + std::map> rawdatas; + rawdatas.insert(pair>(anno_schema_id, annotations)); + MS_LOG(INFO) << "Init Images Already."; + + // init file_writer + std::vector file_names; + for (int i = 1; i <= 4; i++) { + file_names.emplace_back(std::string("./OpenForAppendSample.shard0") + std::to_string(i)); + MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); + } + + MS_LOG(INFO) << "Init Output Files Already."; + { + mindrecord::ShardWriter fw_init; + fw_init.Open(file_names); + + // set shardHeader + fw_init.SetShardHeader(std::make_shared(header_data)); + + // close file_writer + fw_init.Commit(); + } + { + MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; + mindrecord::ShardWriter fw; + auto ret = fw.OpenForAppend(filename); + if (ret == FAILED) { + return; + } + + bin_data = std::vector>(bin_data.begin(), bin_data.begin() + 1); + fw.WriteRawData(rawdatas, bin_data); + fw.Commit(); + } + + ShardIndexGenerator sg{filename}; + sg.Build(); + sg.WriteToDatabase(); + MS_LOG(INFO) << "Done create index"; +} + + } // namespace mindrecord } // namespace mindspore diff --git a/tests/ut/cpp/mindrecord/ut_common.h b/tests/ut/cpp/mindrecord/ut_common.h index 398c59779bbdee2ea064034e870771e21d5df50a..8b244bf87aeaa156bb2c042b04f8ea058fe33c5f 100644 --- a/tests/ut/cpp/mindrecord/ut_common.h +++ b/tests/ut/cpp/mindrecord/ut_common.h @@ -17,6 +17,7 @@ #ifndef TESTS_MINDRECORD_UT_UT_COMMON_H_ #define TESTS_MINDRECORD_UT_UT_COMMON_H_ +#include #include #include #include @@ -25,7 +26,9 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "mindrecord/include/shard_index.h" - +#include "mindrecord/include/shard_header.h" +#include "mindrecord/include/shard_index_generator.h" +#include "mindrecord/include/shard_writer.h" using json = nlohmann::json; using std::ifstream; using std::pair; @@ -40,11 +43,10 @@ class Common : public testing::Test { std::string install_root; // every TEST_F macro will enter one - void SetUp(); + virtual void SetUp(); - void TearDown(); + virtual void TearDown(); - static void LoadData(const std::string &directory, std::vector &json_buffer, const int max_num); }; } // namespace UT @@ -55,6 +57,21 @@ class Common : public testing::Test { /// /// return the formatted string const std::string FormatInfo(const std::string &message, uint32_t message_total_length = 128); + + +void LoadData(const std::string &directory, std::vector &json_buffer, const int max_num); + +void LoadDataFromImageNet(const std::string &directory, std::vector &json_buffer, const int max_num); + +int Img2DataUint8(const std::vector &img_absolute_path, std::vector> &bin_data); + +int GetAbsoluteFiles(std::string directory, std::vector &files_absolute_path); + +void ShardWriterImageNet(); + +void ShardWriterImageNetOneSample(); + +void ShardWriterImageNetOpenForAppend(string filename); } // namespace mindrecord } // namespace mindspore #endif // TESTS_MINDRECORD_UT_UT_COMMON_H_ diff --git a/tests/ut/cpp/mindrecord/ut_shard.cc b/tests/ut/cpp/mindrecord/ut_shard.cc index 88fdb7e167c3c55fc2acb8ab9ffa87f0d950f194..994ff1b859bad6c5d192f32cf6ed348fcf0aba53 100644 --- a/tests/ut/cpp/mindrecord/ut_shard.cc +++ b/tests/ut/cpp/mindrecord/ut_shard.cc @@ -29,7 +29,6 @@ #include "mindrecord/include/shard_statistics.h" #include "securec.h" #include "ut_common.h" -#include "ut_shard_writer_test.h" using mindspore::MsLogLevel::INFO; using mindspore::ExceptionType::NoExceptionType; @@ -43,7 +42,7 @@ class TestShard : public UT::Common { }; TEST_F(TestShard, TestShardSchemaPart) { - TestShardWriterImageNet(); + ShardWriterImageNet(); MS_LOG(INFO) << FormatInfo("Test schema"); @@ -55,6 +54,12 @@ TEST_F(TestShard, TestShardSchemaPart) { ASSERT_TRUE(schema != nullptr); MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " << common::SafeCStr(schema->GetSchema().dump()); + for (int i = 1; i <= 4; i++) { + string filename = std::string("./imagenet.shard0") + std::to_string(i); + string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; + remove(common::SafeCStr(filename)); + remove(common::SafeCStr(db_name)); + } } TEST_F(TestShard, TestStatisticPart) { @@ -128,6 +133,5 @@ TEST_F(TestShard, TestShardHeaderPart) { ASSERT_EQ(resFields, fields); } -TEST_F(TestShard, TestShardWriteImage) { MS_LOG(INFO) << FormatInfo("Test writer"); } } // namespace mindrecord } // namespace mindspore diff --git a/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc index 0c33d33ffd45cb232fc6643b1c0aa64b80614747..140fff4166cf1254a32afa0208f03c12c72799f2 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc @@ -53,38 +53,6 @@ class TestShardIndexGenerator : public UT::Common { TestShardIndexGenerator() {} }; -/* -TEST_F(TestShardIndexGenerator, GetField) { - MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field"); - - int max_num = 1; - string input_path1 = install_root + "/test/testCBGData/data/annotation.data"; - std::vector json_buffer1; // store the image_raw_meta.data - Common::LoadData(input_path1, json_buffer1, max_num); - - MS_LOG(INFO) << "Fetch fields: "; - for (auto &j : json_buffer1) { - auto v_name = ShardIndexGenerator::GetField("anno_tool", j); - auto v_attr_name = ShardIndexGenerator::GetField("entity_instances.attributes.attr_name", j); - auto v_entity_name = ShardIndexGenerator::GetField("entity_instances.entity_name", j); - vector names = {"\"CVAT\""}; - for (unsigned int i = 0; i != names.size(); i++) { - ASSERT_EQ(names[i], v_name[i]); - } - vector attr_names = {"\"脸部评分\"", "\"特征点\"", "\"points_example\"", "\"polyline_example\"", - "\"polyline_example\""}; - for (unsigned int i = 0; i != attr_names.size(); i++) { - ASSERT_EQ(attr_names[i], v_attr_name[i]); - } - vector entity_names = {"\"276点人脸\"", "\"points_example\"", "\"polyline_example\"", - "\"polyline_example\""}; - for (unsigned int i = 0; i != entity_names.size(); i++) { - ASSERT_EQ(entity_names[i], v_entity_name[i]); - } - } -} -*/ - TEST_F(TestShardIndexGenerator, TakeFieldType) { MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type"); diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index bfd49069b20cd9c710db90a4f618d5934173dbba..9c177d7a4084e45dd94b71a33d3d352021bb2c15 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -40,6 +40,17 @@ namespace mindrecord { class TestShardOperator : public UT::Common { public: TestShardOperator() {} + + void SetUp() override { ShardWriterImageNet(); } + + void TearDown() override { + for (int i = 1; i <= 4; i++) { + string filename = std::string("./imagenet.shard0") + std::to_string(i); + string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; + remove(common::SafeCStr(filename)); + remove(common::SafeCStr(db_name)); + } + } }; TEST_F(TestShardOperator, TestShardSampleBasic) { @@ -165,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { auto x = dataset.GetNext(); if (x.empty()) break; std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) - << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; i++; } dataset.Finish(); @@ -191,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { if (x.empty()) break; std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) - << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; i++; } dataset.Finish(); diff --git a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc index f7ed39a00613dcaa2eab30d108bd58abc884b99f..e88c2fe3d61365b632ccd17626d0d424e1e5bc15 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc @@ -37,6 +37,16 @@ namespace mindrecord { class TestShardReader : public UT::Common { public: TestShardReader() {} + void SetUp() override { ShardWriterImageNet(); } + + void TearDown() override { + for (int i = 1; i <= 4; i++) { + string filename = std::string("./imagenet.shard0") + std::to_string(i); + string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; + remove(common::SafeCStr(filename)); + remove(common::SafeCStr(db_name)); + } + } }; TEST_F(TestShardReader, TestShardReaderGeneral) { @@ -51,8 +61,8 @@ TEST_F(TestShardReader, TestShardReaderGeneral) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - for (auto& j : x) { - for (auto& item : std::get<1>(j).items()) { + for (auto &j : x) { + for (auto &item : std::get<1>(j).items()) { MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); } } @@ -74,8 +84,8 @@ TEST_F(TestShardReader, TestShardReaderSample) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - for (auto& j : x) { - for (auto& item : std::get<1>(j).items()) { + for (auto &j : x) { + for (auto &item : std::get<1>(j).items()) { MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); } } @@ -99,8 +109,8 @@ TEST_F(TestShardReader, TestShardReaderBlock) { while (true) { auto x = dataset.GetBlockNext(); if (x.empty()) break; - for (auto& j : x) { - for (auto& item : std::get<1>(j).items()) { + for (auto &j : x) { + for (auto &item : std::get<1>(j).items()) { MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); } } @@ -119,8 +129,8 @@ TEST_F(TestShardReader, TestShardReaderEasy) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - for (auto& j : x) { - for (auto& item : std::get<1>(j).items()) { + for (auto &j : x) { + for (auto &item : std::get<1>(j).items()) { MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); } } @@ -140,8 +150,8 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - for (auto& j : x) { - for (auto& item : std::get<1>(j).items()) { + for (auto &j : x) { + for (auto &item : std::get<1>(j).items()) { MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump(); } } @@ -169,9 +179,9 @@ TEST_F(TestShardReader, TestShardVersion) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - for (auto& j : x) { + for (auto &j : x) { MS_LOG(INFO) << "result size: " << std::get<0>(j).size(); - for (auto& item : std::get<1>(j).items()) { + for (auto &item : std::get<1>(j).items()) { MS_LOG(INFO) << "key: " << common::SafeCStr(item.key()) << ", value: " << common::SafeCStr(item.value().dump()); } } @@ -201,8 +211,8 @@ TEST_F(TestShardReader, TestShardReaderConsumer) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - for (auto& j : x) { - for (auto& item : std::get<1>(j).items()) { + for (auto &j : x) { + for (auto &item : std::get<1>(j).items()) { MS_LOG(INFO) << "key: " << common::SafeCStr(item.key()) << ", value: " << common::SafeCStr(item.value().dump()); } } diff --git a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc index c803f584aaa5832028ec94ed69f377f80ab144af..bf0a35df7dabb2bf79a79516b6f7f7e9e2f40b13 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc @@ -33,15 +33,25 @@ #include "mindrecord/include/shard_segment.h" #include "ut_common.h" -using mindspore::MsLogLevel::INFO; -using mindspore::ExceptionType::NoExceptionType; using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; namespace mindspore { namespace mindrecord { class TestShardSegment : public UT::Common { public: TestShardSegment() {} + void SetUp() override { ShardWriterImageNet(); } + + void TearDown() override { + for (int i = 1; i <= 4; i++) { + string filename = std::string("./imagenet.shard0") + std::to_string(i); + string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; + remove(common::SafeCStr(filename)); + remove(common::SafeCStr(db_name)); + } + } }; TEST_F(TestShardSegment, TestShardSegment) { diff --git a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc index 18e9214b08eb7d6cbfcd5ea89a436ef1ac4044fe..3fa248c2e0540b2cd964fdbda34663b0681dcd13 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc @@ -16,7 +16,6 @@ #include #include -#include #include #include #include @@ -30,7 +29,6 @@ #include "mindrecord/include/shard_index_generator.h" #include "securec.h" #include "ut_common.h" -#include "ut_shard_writer_test.h" using mindspore::LogStream; using mindspore::ExceptionType::NoExceptionType; @@ -44,249 +42,10 @@ class TestShardWriter : public UT::Common { TestShardWriter() {} }; -void LoadDataFromImageNet(const std::string &directory, std::vector &json_buffer, const int max_num) { - int count = 0; - string input_path = directory; - ifstream infile(input_path); - if (!infile.is_open()) { - MS_LOG(ERROR) << "can not open the file "; - return; - } - string temp; - string filename; - string label; - json j; - while (getline(infile, temp) && count != max_num) { - count++; - std::size_t pos = temp.find(",", 0); - if (pos != std::string::npos) { - j["file_name"] = temp.substr(0, pos); - j["label"] = atoi(common::SafeCStr(temp.substr(pos + 1, temp.length()))); - json_buffer.push_back(j); - } - } - infile.close(); -} - -int Img2DataUint8(const std::vector &img_absolute_path, std::vector> &bin_data) { - for (auto &file : img_absolute_path) { - // read image file - std::ifstream in(common::SafeCStr(file), std::ios::in | std::ios::binary | std::ios::ate); - if (!in) { - MS_LOG(ERROR) << common::SafeCStr(file) << " is not a directory or not exist!"; - return -1; - } - - // get the file size - uint64_t size = in.tellg(); - in.seekg(0, std::ios::beg); - std::vector file_data(size); - in.read(reinterpret_cast(&file_data[0]), size); - in.close(); - bin_data.push_back(file_data); - } - return 0; -} - -int GetAbsoluteFiles(std::string directory, std::vector &files_absolute_path) { - DIR *dir = opendir(common::SafeCStr(directory)); - if (dir == nullptr) { - MS_LOG(ERROR) << common::SafeCStr(directory) << " is not a directory or not exist!"; - return -1; - } - struct dirent *d_ent = nullptr; - char dot[3] = "."; - char dotdot[6] = ".."; - while ((d_ent = readdir(dir)) != nullptr) { - if ((strcmp(d_ent->d_name, dot) != 0) && (strcmp(d_ent->d_name, dotdot) != 0)) { - if (d_ent->d_type == DT_DIR) { - std::string new_directory = directory + std::string("/") + std::string(d_ent->d_name); - if (directory[directory.length() - 1] == '/') { - new_directory = directory + string(d_ent->d_name); - } - if (-1 == GetAbsoluteFiles(new_directory, files_absolute_path)) { - closedir(dir); - return -1; - } - } else { - std::string absolute_path = directory + std::string("/") + std::string(d_ent->d_name); - if (directory[directory.length() - 1] == '/') { - absolute_path = directory + std::string(d_ent->d_name); - } - files_absolute_path.push_back(absolute_path); - } - } - } - closedir(dir); - return 0; -} - -void TestShardWriterImageNet() { - MS_LOG(INFO) << common::SafeCStr(FormatInfo("Write imageNet")); - - // load binary data - std::vector> bin_data; - std::vector filenames; - if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { - MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; - return; - } - mindrecord::Img2DataUint8(filenames, bin_data); - - // init shardHeader - mindrecord::ShardHeader header_data; - MS_LOG(INFO) << "Init ShardHeader Already."; - - // create schema - json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; - std::shared_ptr anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); - if (anno_schema == nullptr) { - MS_LOG(ERROR) << "Build annotation schema failed"; - return; - } - - // add schema to shardHeader - int anno_schema_id = header_data.AddSchema(anno_schema); - MS_LOG(INFO) << "Init Schema Already."; - - // create index - std::pair index_field1(anno_schema_id, "file_name"); - std::pair index_field2(anno_schema_id, "label"); - std::vector> fields; - fields.push_back(index_field1); - fields.push_back(index_field2); - - // add index to shardHeader - header_data.AddIndexFields(fields); - MS_LOG(INFO) << "Init Index Fields Already."; - // load meta data - std::vector annotations; - LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 10); - - // add data - std::map> rawdatas; - rawdatas.insert(pair>(anno_schema_id, annotations)); - MS_LOG(INFO) << "Init Images Already."; - - // init file_writer - std::vector file_names; - int file_count = 4; - for (int i = 1; i <= file_count; i++) { - file_names.emplace_back(std::string("./imagenet.shard0") + std::to_string(i)); - MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); - } - - MS_LOG(INFO) << "Init Output Files Already."; - { - mindrecord::ShardWriter fw_init; - fw_init.Open(file_names); - - // set shardHeader - fw_init.SetShardHeader(std::make_shared(header_data)); - - // close file_writer - fw_init.Commit(); - } - std::string filename = "./imagenet.shard01"; - { - MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; - mindrecord::ShardWriter fw; - fw.OpenForAppend(filename); - fw.WriteRawData(rawdatas, bin_data); - fw.Commit(); - } - mindrecord::ShardIndexGenerator sg{filename}; - sg.Build(); - sg.WriteToDatabase(); - - MS_LOG(INFO) << "Done create index"; -} - -void TestShardWriterImageNetOneSample() { - // load binary data - std::vector> bin_data; - std::vector filenames; - if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { - MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; - return; - } - mindrecord::Img2DataUint8(filenames, bin_data); - - // init shardHeader - mindrecord::ShardHeader header_data; - MS_LOG(INFO) << "Init ShardHeader Already."; - - // create schema - json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; - std::shared_ptr anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); - if (anno_schema == nullptr) { - MS_LOG(ERROR) << "Build annotation schema failed"; - return; - } - - // add schema to shardHeader - int anno_schema_id = header_data.AddSchema(anno_schema); - MS_LOG(INFO) << "Init Schema Already."; - - // create index - std::pair index_field1(anno_schema_id, "file_name"); - std::pair index_field2(anno_schema_id, "label"); - std::vector> fields; - fields.push_back(index_field1); - fields.push_back(index_field2); - - // add index to shardHeader - header_data.AddIndexFields(fields); - MS_LOG(INFO) << "Init Index Fields Already."; - - // load meta data - std::vector annotations; - LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1); - - // add data - std::map> rawdatas; - rawdatas.insert(pair>(anno_schema_id, annotations)); - MS_LOG(INFO) << "Init Images Already."; - - // init file_writer - std::vector file_names; - for (int i = 1; i <= 4; i++) { - file_names.emplace_back(std::string("./OneSample.shard0") + std::to_string(i)); - MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); - } - - MS_LOG(INFO) << "Init Output Files Already."; - { - mindrecord::ShardWriter fw_init; - fw_init.Open(file_names); - - // set shardHeader - fw_init.SetShardHeader(std::make_shared(header_data)); - - // close file_writer - fw_init.Commit(); - } - - std::string filename = "./OneSample.shard01"; - { - MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; - mindrecord::ShardWriter fw; - fw.OpenForAppend(filename); - bin_data = std::vector>(bin_data.begin(), bin_data.begin() + 1); - fw.WriteRawData(rawdatas, bin_data); - fw.Commit(); - } - - mindrecord::ShardIndexGenerator sg{filename}; - sg.Build(); - sg.WriteToDatabase(); - MS_LOG(INFO) << "Done create index"; -} - TEST_F(TestShardWriter, TestShardWriterBench) { MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet")); - TestShardWriterImageNet(); + ShardWriterImageNet(); for (int i = 1; i <= 4; i++) { string filename = std::string("./imagenet.shard0") + std::to_string(i); string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db"; @@ -297,7 +56,7 @@ TEST_F(TestShardWriter, TestShardWriterBench) { TEST_F(TestShardWriter, TestShardWriterOneSample) { MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet int32 of sample less than num of shards")); - TestShardWriterImageNetOneSample(); + ShardWriterImageNetOneSample(); std::string filename = "./OneSample.shard01"; ShardReader dataset; @@ -342,7 +101,7 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) { std::vector image_filenames; // save all files' path within path_dir // read image_raw_meta.data - Common::LoadData(input_path1, json_buffer1, kMaxNum); + LoadData(input_path1, json_buffer1, kMaxNum); MS_LOG(INFO) << "Load Meta Data Already."; // get files' pathes stored in vector image_filenames @@ -375,7 +134,7 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) { MS_LOG(INFO) << "Init Schema Already."; // create/init statistics - Common::LoadData(input_path3, json_buffer4, 2); + LoadData(input_path3, json_buffer4, 2); json static1_json = json_buffer4[0]; json static2_json = json_buffer4[1]; MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump()); @@ -474,7 +233,7 @@ TEST_F(TestShardWriter, TestShardWriterTrial) { std::vector image_filenames; // save all files' path within path_dir // read image_raw_meta.data - Common::LoadData(input_path1, json_buffer1, kMaxNum); + LoadData(input_path1, json_buffer1, kMaxNum); MS_LOG(INFO) << "Load Meta Data Already."; // get files' pathes stored in vector image_filenames @@ -508,7 +267,7 @@ TEST_F(TestShardWriter, TestShardWriterTrial) { MS_LOG(INFO) << "Init Schema Already."; // create/init statistics - Common::LoadData(input_path3, json_buffer4, 2); + LoadData(input_path3, json_buffer4, 2); json static1_json = json_buffer4[0]; json static2_json = json_buffer4[1]; MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump()); @@ -613,7 +372,7 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) { std::vector image_filenames; // save all files' path within path_dir // read image_raw_meta.data - Common::LoadData(input_path1, json_buffer1, kMaxNum); + LoadData(input_path1, json_buffer1, kMaxNum); MS_LOG(INFO) << "Load Meta Data Already."; // get files' pathes stored in vector image_filenames @@ -644,7 +403,7 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) { MS_LOG(INFO) << "Init Schema Already."; // create/init statistics - Common::LoadData(input_path3, json_buffer4, 2); + LoadData(input_path3, json_buffer4, 2); json static1_json = json_buffer4[0]; json static2_json = json_buffer4[1]; MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump()); @@ -1357,107 +1116,24 @@ TEST_F(TestShardWriter, TestWriteOpenFileName) { } } -void TestShardWriterImageNetOpenForAppend(string filename) { - for (int i = 1; i <= 4; i++) { - string filename = std::string("./OpenForAppendSample.shard0") + std::to_string(i); - string db_name = std::string("./OpenForAppendSample.shard0") + std::to_string(i) + ".db"; - remove(common::SafeCStr(filename)); - remove(common::SafeCStr(db_name)); - } - - // load binary data - std::vector> bin_data; - std::vector filenames; - if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) { - MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------"; - return; - } - mindrecord::Img2DataUint8(filenames, bin_data); - - // init shardHeader - mindrecord::ShardHeader header_data; - MS_LOG(INFO) << "Init ShardHeader Already."; - - // create schema - json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json; - std::shared_ptr anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json); - if (anno_schema == nullptr) { - MS_LOG(ERROR) << "Build annotation schema failed"; - return; - } - - // add schema to shardHeader - int anno_schema_id = header_data.AddSchema(anno_schema); - MS_LOG(INFO) << "Init Schema Already."; - - // create index - std::pair index_field1(anno_schema_id, "file_name"); - std::pair index_field2(anno_schema_id, "label"); - std::vector> fields; - fields.push_back(index_field1); - fields.push_back(index_field2); - - // add index to shardHeader - header_data.AddIndexFields(fields); - MS_LOG(INFO) << "Init Index Fields Already."; - - // load meta data - std::vector annotations; - LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1); - - // add data - std::map> rawdatas; - rawdatas.insert(pair>(anno_schema_id, annotations)); - MS_LOG(INFO) << "Init Images Already."; - - // init file_writer - std::vector file_names; - for (int i = 1; i <= 4; i++) { - file_names.emplace_back(std::string("./OpenForAppendSample.shard0") + std::to_string(i)); - MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]); - } - - MS_LOG(INFO) << "Init Output Files Already."; - { - mindrecord::ShardWriter fw_init; - fw_init.Open(file_names); - - // set shardHeader - fw_init.SetShardHeader(std::make_shared(header_data)); - - // close file_writer - fw_init.Commit(); - } - { - MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; - mindrecord::ShardWriter fw; - auto ret = fw.OpenForAppend(filename); - if (ret == FAILED) { - return; - } - - bin_data = std::vector>(bin_data.begin(), bin_data.begin() + 1); - fw.WriteRawData(rawdatas, bin_data); - fw.Commit(); - } - - mindrecord::ShardIndexGenerator sg{filename}; - sg.Build(); - sg.WriteToDatabase(); - MS_LOG(INFO) << "Done create index"; -} - TEST_F(TestShardWriter, TestOpenForAppend) { MS_LOG(INFO) << "start ---- TestOpenForAppend\n"; string filename = "./"; - TestShardWriterImageNetOpenForAppend(filename); + ShardWriterImageNetOpenForAppend(filename); string filename1 = "./▒AppendSample.shard01"; - TestShardWriterImageNetOpenForAppend(filename1); + ShardWriterImageNetOpenForAppend(filename1); string filename2 = "./ä\xA9ü"; - TestShardWriterImageNetOpenForAppend(filename2); + ShardWriterImageNetOpenForAppend(filename2); + MS_LOG(INFO) << "end ---- TestOpenForAppend\n"; + for (int i = 1; i <= 4; i++) { + string filename = std::string("./OpenForAppendSample.shard0") + std::to_string(i); + string db_name = std::string("./OpenForAppendSample.shard0") + std::to_string(i) + ".db"; + remove(common::SafeCStr(filename)); + remove(common::SafeCStr(db_name)); + } } } // namespace mindrecord diff --git a/tests/ut/cpp/mindrecord/ut_shard_writer_test.h b/tests/ut/cpp/mindrecord/ut_shard_writer_test.h deleted file mode 100644 index f665297b17d10d2b1ccef33fe96a28793cd27146..0000000000000000000000000000000000000000 --- a/tests/ut/cpp/mindrecord/ut_shard_writer_test.h +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TESTS_MINDRECORD_UT_SHARDWRITER_H -#define TESTS_MINDRECORD_UT_SHARDWRITER_H - -namespace mindspore { -namespace mindrecord { -void TestShardWriterImageNet(); -} // namespace mindrecord -} // namespace mindspore - -#endif // TESTS_MINDRECORD_UT_SHARDWRITER_H