From 5637f80692854eeae6e1fd4a439b5194905fb465 Mon Sep 17 00:00:00 2001 From: Zirui Wu Date: Wed, 1 Apr 2020 11:24:25 -0400 Subject: [PATCH] implemented multi-thread index writer for mindrecord num threads cannot be more than num shards minor fix clang style fix address review comments --- .../include/shard_index_generator.h | 10 +- .../mindrecord/io/shard_index_generator.cc | 98 ++++++++++++------- 2 files changed, 72 insertions(+), 36 deletions(-) diff --git a/mindspore/ccsrc/mindrecord/include/shard_index_generator.h b/mindspore/ccsrc/mindrecord/include/shard_index_generator.h index f59dbe9bf..1febd28fc 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_index_generator.h +++ b/mindspore/ccsrc/mindrecord/include/shard_index_generator.h @@ -85,14 +85,14 @@ class ShardIndexGenerator { /// \param sql /// \param data /// \return - MSRStatus BindParamaterExecuteSQL( + MSRStatus BindParameterExecuteSQL( sqlite3 *db, const std::string &sql, const std::vector>> &data); INDEX_FIELDS GenerateIndexFields(const std::vector &schema_detail); - MSRStatus ExcuteTransaction(const int &shard_no, const std::pair &db, - const std::vector &raw_page_ids, const std::map &blob_id_to_page_id); + MSRStatus ExecuteTransaction(const int &shard_no, const std::pair &db, + const std::vector &raw_page_ids, const std::map &blob_id_to_page_id); MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name); @@ -103,12 +103,16 @@ class ShardIndexGenerator { void AddIndexFieldByRawData(const std::vector &schema_detail, std::vector> &row_data); + void DatabaseWriter(); // worker thread + std::string file_path_; bool append_; ShardHeader shard_header_; uint64_t page_size_; uint64_t header_size_; int schema_count_; + std::atomic_int task_; + std::atomic_bool write_success_; std::vector> fields_; }; } // namespace mindrecord diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc index 1c14d30f3..c0108241a 100644 --- 
a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <thread>
 
 #include "mindrecord/include/shard_index_generator.h"
 #include "common/utils.h"
@@ -26,7 +27,13 @@ using mindspore::MsLogLevel::INFO;
 namespace mindspore {
 namespace mindrecord {
 ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
-    : file_path_(file_path), append_(append), page_size_(0), header_size_(0), schema_count_(0) {}
+    : file_path_(file_path),
+      append_(append),
+      page_size_(0),
+      header_size_(0),
+      schema_count_(0),
+      task_(0),
+      write_success_(true) {}
 
 MSRStatus ShardIndexGenerator::Build() {
   ShardHeader header = ShardHeader();
@@ -284,7 +291,7 @@ std::pair ShardIndexGenerator::GenerateRawSQL(
   return {SUCCESS, sql};
 }
 
-MSRStatus ShardIndexGenerator::BindParamaterExecuteSQL(
+MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
   sqlite3 *db, const std::string &sql,
   const std::vector>> &data) {
   sqlite3_stmt *stmt = nullptr;
@@ -471,9 +478,9 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector &s
   return {SUCCESS, std::move(fields)};
 }
 
-MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std::pair &db,
-                                                 const std::vector &raw_page_ids,
-                                                 const std::map &blob_id_to_page_id) {
+MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std::pair &db,
+                                                  const std::vector &raw_page_ids,
+                                                  const std::map &blob_id_to_page_id) {
   // Add index data to database
   std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
   if (shard_address.empty()) {
@@ -493,7 +500,7 @@ MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std:
     if (data.first != SUCCESS) {
       return FAILED;
     }
-    if (BindParamaterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
+    if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
       return FAILED;
     }
     MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
@@ -514,37 +521,62 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
   page_size_ = shard_header_.get_page_size();
   header_size_ = shard_header_.get_header_size();
   schema_count_ = shard_header_.get_schema_count();
-  if (shard_header_.get_shard_count() <= kMaxShardCount) {
-    // Create one database per shard
-    for (int shard_no = 0; shard_no < shard_header_.get_shard_count(); ++shard_no) {
-      // Create database
-      auto db = CreateDatabase(shard_no);
-      if (db.first != SUCCESS || db.second == nullptr) {
-        return FAILED;
-      }
-      MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
-
-      // Pre-processing page information
-      auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
-
-      std::map blob_id_to_page_id;
-      std::vector raw_page_ids;
-      for (uint64_t i = 0; i < total_pages; ++i) {
-        std::shared_ptr cur_page = shard_header_.GetPage(shard_no, i).first;
-        if (cur_page->get_page_type() == "RAW_DATA") {
-          raw_page_ids.push_back(i);
-        } else if (cur_page->get_page_type() == "BLOB_DATA") {
-          blob_id_to_page_id[cur_page->get_page_type_id()] = i;
-        }
-      }
+  if (shard_header_.get_shard_count() > kMaxShardCount) {
+    MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxShardCount;
+    return FAILED;
+  }
+  task_ = 0;  // both atomics are reset so WriteToDatabase() starts from shard 0 with a clean status
+  write_success_ = true;
 
-      if (ExcuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
-        return FAILED;
+  // spawn (hardware threads / 2 + 1) workers, capped at the number of shards so no worker starts idle
+  const unsigned int num_workers =
+    std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count()));
+
+  std::vector<std::thread> threads;
+  threads.reserve(num_workers);
+
+  for (unsigned int t = 0; t < num_workers; ++t) {  // reserve() only guarantees capacity() >= num_workers
+    threads.emplace_back(&ShardIndexGenerator::DatabaseWriter, this);
+  }
+
+  for (auto &worker : threads) {  // wait for every worker before reading the shared status flag
+    worker.join();
+  }
+  return write_success_ ? SUCCESS : FAILED;  // a failure in any worker fails the whole write
+}
+
+void ShardIndexGenerator::DatabaseWriter() {  // worker thread body: builds the index db of each shard it claims
+  int shard_no = task_++;  // atomically claim the next unprocessed shard
+  while (write_success_ && shard_no < shard_header_.get_shard_count()) {  // stop once any worker has failed
+    auto db = CreateDatabase(shard_no);
+    if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
+      write_success_ = false;
+      return;
+    }
+
+    MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
+
+    // Pre-processing page information
+    auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
+
+    std::map blob_id_to_page_id;
+    std::vector raw_page_ids;
+    for (uint64_t i = 0; i < total_pages; ++i) {
+      std::shared_ptr cur_page = shard_header_.GetPage(shard_no, i).first;
+      if (cur_page->get_page_type() == "RAW_DATA") {
+        raw_page_ids.push_back(i);
+      } else if (cur_page->get_page_type() == "BLOB_DATA") {
+        blob_id_to_page_id[cur_page->get_page_type_id()] = i;
       }
-      MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
     }
+
+    if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
+      write_success_ = false;
+      return;
+    }
+    MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
+    shard_no = task_++;  // claim the next shard; the loop ends when the counter passes the shard count
   }
-  return SUCCESS;
 }
 }  // namespace mindrecord
 }  // namespace mindspore
-- 
GitLab