提交 71b81c8f 编写于 作者: Z Zirui Wu 提交者: 高东海

implemented multi-thread index writer for mindrecord

num threads cannot be more than num shards

minor fix

clang style fix

address review comments
上级 824d9e49
......@@ -85,14 +85,14 @@ class ShardIndexGenerator {
/// \param sql
/// \param data
/// \return
MSRStatus BindParamaterExecuteSQL(
MSRStatus BindParameterExecuteSQL(
sqlite3 *db, const std::string &sql,
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data);
INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail);
MSRStatus ExcuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id);
MSRStatus ExecuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id);
MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name);
......@@ -103,12 +103,16 @@ class ShardIndexGenerator {
void AddIndexFieldByRawData(const std::vector<json> &schema_detail,
std::vector<std::tuple<std::string, std::string, std::string>> &row_data);
void DatabaseWriter(); // worker thread
std::string file_path_;
bool append_;
ShardHeader shard_header_;
uint64_t page_size_;
uint64_t header_size_;
int schema_count_;
std::atomic_int task_;
std::atomic_bool write_success_;
std::vector<std::pair<uint64_t, std::string>> fields_;
};
} // namespace mindrecord
......
......@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thread>
#include "mindrecord/include/shard_index_generator.h"
#include "common/utils.h"
......@@ -26,7 +27,13 @@ using mindspore::MsLogLevel::INFO;
namespace mindspore {
namespace mindrecord {
ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
: file_path_(file_path), append_(append), page_size_(0), header_size_(0), schema_count_(0) {}
: file_path_(file_path),
append_(append),
page_size_(0),
header_size_(0),
schema_count_(0),
task_(0),
write_success_(true) {}
MSRStatus ShardIndexGenerator::Build() {
ShardHeader header = ShardHeader();
......@@ -284,7 +291,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL(
return {SUCCESS, sql};
}
MSRStatus ShardIndexGenerator::BindParamaterExecuteSQL(
MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
sqlite3 *db, const std::string &sql,
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
sqlite3_stmt *stmt = nullptr;
......@@ -471,9 +478,9 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
return {SUCCESS, std::move(fields)};
}
MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id) {
MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id) {
// Add index data to database
std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
if (shard_address.empty()) {
......@@ -493,7 +500,7 @@ MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std:
if (data.first != SUCCESS) {
return FAILED;
}
if (BindParamaterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
return FAILED;
}
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
......@@ -514,37 +521,62 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
page_size_ = shard_header_.get_page_size();
header_size_ = shard_header_.get_header_size();
schema_count_ = shard_header_.get_schema_count();
if (shard_header_.get_shard_count() <= kMaxShardCount) {
// Create one database per shard
for (int shard_no = 0; shard_no < shard_header_.get_shard_count(); ++shard_no) {
// Create database
auto db = CreateDatabase(shard_no);
if (db.first != SUCCESS || db.second == nullptr) {
return FAILED;
}
MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
// Pre-processing page information
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
std::map<int, int> blob_id_to_page_id;
std::vector<int> raw_page_ids;
for (uint64_t i = 0; i < total_pages; ++i) {
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
if (cur_page->get_page_type() == "RAW_DATA") {
raw_page_ids.push_back(i);
} else if (cur_page->get_page_type() == "BLOB_DATA") {
blob_id_to_page_id[cur_page->get_page_type_id()] = i;
}
}
if (shard_header_.get_shard_count() > kMaxShardCount) {
MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount;
return FAILED;
}
task_ = 0; // set two atomic vars to initial value
write_success_ = true;
if (ExcuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
return FAILED;
// spawn half the physical threads or total number of shards whichever is smaller
const unsigned int num_workers =
std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count()));
std::vector<std::thread> threads;
threads.reserve(num_workers);
for (size_t t = 0; t < threads.capacity(); t++) {
threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
}
for (size_t t = 0; t < threads.capacity(); t++) {
threads[t].join();
}
return write_success_ ? SUCCESS : FAILED;
}
void ShardIndexGenerator::DatabaseWriter() {
int shard_no = task_++;
while (shard_no < shard_header_.get_shard_count()) {
auto db = CreateDatabase(shard_no);
if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
write_success_ = false;
return;
}
MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
// Pre-processing page information
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
std::map<int, int> blob_id_to_page_id;
std::vector<int> raw_page_ids;
for (uint64_t i = 0; i < total_pages; ++i) {
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
if (cur_page->get_page_type() == "RAW_DATA") {
raw_page_ids.push_back(i);
} else if (cur_page->get_page_type() == "BLOB_DATA") {
blob_id_to_page_id[cur_page->get_page_type_id()] = i;
}
MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
}
if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
write_success_ = false;
return;
}
MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
shard_no = task_++;
}
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册