Commit cf352d19 authored by: J jonyguo

format func name for mindrecord

Parent 93429aba
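The diff below is a mechanical rename of MindRecord accessors from snake_case to PascalCase (for example get_num_rows() becomes GetNumRows(), set_page_size() becomes SetPageSize()), while the Python-facing pybind11 names such as "get_blob_fields" keep their original snake_case strings and are simply re-pointed at the renamed C++ methods (see .def("get_blob_fields", &ShardReader::GetBlobFields)). The following minimal sketch, with a simplified stand-in class rather than the real ShardReader, illustrates the before/after pattern under that assumption:

// rename_pattern_sketch.cc -- illustrative only, not part of the MindRecord sources.
#include <cstdint>

class ShardReaderLike {  // hypothetical stand-in for ShardReader
 public:
  // Before this commit: int get_num_rows() const;
  // After this commit, the accessor is PascalCase:
  int GetNumRows() const { return num_rows_; }

  // Before: void set_all_in_index(bool all_in_index);
  // After:
  void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }

 private:
  int num_rows_ = 0;
  bool all_in_index_ = false;
};

// The pybind11 binding keeps the snake_case Python name and only updates the
// C++ target, matching the pattern visible in the diff:
//   .def("get_num_rows", &ShardReaderLike::GetNumRows)
// so existing Python callers are unaffected by the C++ rename.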
......@@ -108,7 +108,7 @@ Status MindRecordOp::Init() {
data_schema_ = std::make_unique<DataSchema>();
std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->get_shard_header()->get_schemas();
std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas();
// check whether schema exists, if so use the first one
CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found");
mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"];
......@@ -155,7 +155,7 @@ Status MindRecordOp::Init() {
column_name_mapping_[columns_to_load_[i]] = i;
}
num_rows_ = shard_reader_->get_num_rows();
num_rows_ = shard_reader_->GetNumRows();
// Compute how many buffers we would need to accomplish rowsPerBuffer
buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_;
RETURN_IF_NOT_OK(SetColumnsBlob());
......@@ -164,7 +164,7 @@ Status MindRecordOp::Init() {
}
Status MindRecordOp::SetColumnsBlob() {
columns_blob_ = shard_reader_->get_blob_fields().second;
columns_blob_ = shard_reader_->GetBlobFields().second;
// get the exactly blob fields by columns_to_load_
std::vector<std::string> columns_blob_exact;
......@@ -600,7 +600,7 @@ Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) {
// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
Status MindRecordOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadAndInitOp());
num_rows_ = shard_reader_->get_num_rows();
num_rows_ = shard_reader_->GetNumRows();
buffers_needed_ = num_rows_ / rows_per_buffer_;
if (num_rows_ % rows_per_buffer_ != 0) {
......
......@@ -39,18 +39,18 @@ namespace mindrecord {
void BindSchema(py::module *m) {
(void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local())
.def_static("build", (std::shared_ptr<Schema>(*)(std::string, py::handle)) & Schema::Build)
.def("get_desc", &Schema::get_desc)
.def("get_desc", &Schema::GetDesc)
.def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython)
.def("get_blob_fields", &Schema::get_blob_fields)
.def("get_schema_id", &Schema::get_schema_id);
.def("get_blob_fields", &Schema::GetBlobFields)
.def("get_schema_id", &Schema::GetSchemaID);
}
void BindStatistics(const py::module *m) {
(void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local())
.def_static("build", (std::shared_ptr<Statistics>(*)(std::string, py::handle)) & Statistics::Build)
.def("get_desc", &Statistics::get_desc)
.def("get_desc", &Statistics::GetDesc)
.def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython)
.def("get_statistics_id", &Statistics::get_statistics_id);
.def("get_statistics_id", &Statistics::GetStatisticsID);
}
void BindShardHeader(const py::module *m) {
......@@ -60,9 +60,9 @@ void BindShardHeader(const py::module *m) {
.def("add_statistics", &ShardHeader::AddStatistic)
.def("add_index_fields",
(MSRStatus(ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields)
.def("get_meta", &ShardHeader::get_schemas)
.def("get_statistics", &ShardHeader::get_statistics)
.def("get_fields", &ShardHeader::get_fields)
.def("get_meta", &ShardHeader::GetSchemas)
.def("get_statistics", &ShardHeader::GetStatistics)
.def("get_fields", &ShardHeader::GetFields)
.def("get_schema_by_id", &ShardHeader::GetSchemaByID)
.def("get_statistic_by_id", &ShardHeader::GetStatisticByID);
}
......@@ -72,8 +72,8 @@ void BindShardWriter(py::module *m) {
.def(py::init<>())
.def("open", &ShardWriter::Open)
.def("open_for_append", &ShardWriter::OpenForAppend)
.def("set_header_size", &ShardWriter::set_header_size)
.def("set_page_size", &ShardWriter::set_page_size)
.def("set_header_size", &ShardWriter::SetHeaderSize)
.def("set_page_size", &ShardWriter::SetPageSize)
.def("set_shard_header", &ShardWriter::SetShardHeader)
.def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
vector<vector<uint8_t>> &, bool, bool)) &
......@@ -88,8 +88,8 @@ void BindShardReader(const py::module *m) {
const std::vector<std::shared_ptr<ShardOperator>> &)) &
ShardReader::OpenPy)
.def("launch", &ShardReader::Launch)
.def("get_header", &ShardReader::get_shard_header)
.def("get_blob_fields", &ShardReader::get_blob_fields)
.def("get_header", &ShardReader::GetShardHeader)
.def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next",
(std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy)
.def("finish", &ShardReader::Finish)
......@@ -119,9 +119,9 @@ void BindShardSegment(py::module *m) {
.def("read_at_page_by_name", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>(
ShardSegment::*)(std::string, int64_t, int64_t)) &
ShardSegment::ReadAtPageByNamePy)
.def("get_header", &ShardSegment::get_shard_header)
.def("get_header", &ShardSegment::GetShardHeader)
.def("get_blob_fields",
(std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::get_blob_fields);
(std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetBlobFields);
}
void BindGlobalParams(py::module *m) {
......
......@@ -36,7 +36,7 @@ class ShardCategory : public ShardOperator {
~ShardCategory() override{};
const std::vector<std::pair<std::string, std::string>> &get_categories() const { return categories_; }
const std::vector<std::pair<std::string, std::string>> &GetCategories() const { return categories_; }
const std::string GetCategoryField() const { return category_field_; }
......@@ -46,7 +46,7 @@ class ShardCategory : public ShardOperator {
bool GetReplacement() const { return replacement_; }
MSRStatus execute(ShardTask &tasks) override;
MSRStatus Execute(ShardTask &tasks) override;
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
......
......@@ -58,19 +58,19 @@ class ShardHeader {
/// \brief get the schema
/// \return the schema
std::vector<std::shared_ptr<Schema>> get_schemas();
std::vector<std::shared_ptr<Schema>> GetSchemas();
/// \brief get Statistics
/// \return the Statistic
std::vector<std::shared_ptr<Statistics>> get_statistics();
std::vector<std::shared_ptr<Statistics>> GetStatistics();
/// \brief get the fields of the index
/// \return the fields of the index
std::vector<std::pair<uint64_t, std::string>> get_fields();
std::vector<std::pair<uint64_t, std::string>> GetFields();
/// \brief get the index
/// \return the index
std::shared_ptr<Index> get_index();
std::shared_ptr<Index> GetIndex();
/// \brief get the schema by schemaid
/// \param[in] schemaId the id of schema needs to be got
......@@ -80,7 +80,7 @@ class ShardHeader {
/// \brief get the filepath to shard by shardID
/// \param[in] shardID the id of shard which filepath needs to be obtained
/// \return the filepath obtained by shardID
std::string get_shard_address_by_id(int64_t shard_id);
std::string GetShardAddressByID(int64_t shard_id);
/// \brief get the statistic by statistic id
/// \param[in] statisticId the id of statistic needs to be get
......@@ -89,7 +89,7 @@ class ShardHeader {
MSRStatus InitByFiles(const std::vector<std::string> &file_paths);
void set_index(Index index) { index_ = std::make_shared<Index>(index); }
void SetIndex(Index index) { index_ = std::make_shared<Index>(index); }
std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id);
......@@ -103,21 +103,21 @@ class ShardHeader {
const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id);
std::vector<std::string> get_shard_addresses() const { return shard_addresses_; }
std::vector<std::string> GetShardAddresses() const { return shard_addresses_; }
int get_shard_count() const { return shard_count_; }
int GetShardCount() const { return shard_count_; }
int get_schema_count() const { return schema_.size(); }
int GetSchemaCount() const { return schema_.size(); }
uint64_t get_header_size() const { return header_size_; }
uint64_t GetHeaderSize() const { return header_size_; }
uint64_t get_page_size() const { return page_size_; }
uint64_t GetPageSize() const { return page_size_; }
void set_header_size(const uint64_t &header_size) { header_size_ = header_size; }
void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }
void set_page_size(const uint64_t &page_size) { page_size_ = page_size; }
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
const string get_version() { return version_; }
const string GetVersion() { return version_; }
std::vector<std::string> SerializeHeader();
......@@ -132,7 +132,7 @@ class ShardHeader {
/// \param[in] the shard data real path
/// \param[in] the headers which readed from the shard data
/// \return SUCCESS/FAILED
MSRStatus get_headers(const vector<string> &real_addresses, std::vector<json> &headers);
MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers);
MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
......
......@@ -52,7 +52,7 @@ class Index {
/// \brief get stored fields
/// \return fields stored
std::vector<std::pair<uint64_t, std::string> > get_fields();
std::vector<std::pair<uint64_t, std::string> > GetFields();
private:
std::vector<std::pair<uint64_t, std::string> > fields_;
......
......@@ -26,23 +26,23 @@ class ShardOperator {
virtual ~ShardOperator() = default;
MSRStatus operator()(ShardTask &tasks) {
if (SUCCESS != this->pre_execute(tasks)) {
if (SUCCESS != this->PreExecute(tasks)) {
return FAILED;
}
if (SUCCESS != this->execute(tasks)) {
if (SUCCESS != this->Execute(tasks)) {
return FAILED;
}
if (SUCCESS != this->suf_execute(tasks)) {
if (SUCCESS != this->SufExecute(tasks)) {
return FAILED;
}
return SUCCESS;
}
virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; }
virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
virtual MSRStatus execute(ShardTask &tasks) = 0;
virtual MSRStatus Execute(ShardTask &tasks) = 0;
virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; }
virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; }
};
......
......@@ -53,29 +53,29 @@ class Page {
/// \return the json format of the page and its description
json GetPage() const;
int get_page_id() const { return page_id_; }
int GetPageID() const { return page_id_; }
int get_shard_id() const { return shard_id_; }
int GetShardID() const { return shard_id_; }
int get_page_type_id() const { return page_type_id_; }
int GetPageTypeID() const { return page_type_id_; }
std::string get_page_type() const { return page_type_; }
std::string GetPageType() const { return page_type_; }
uint64_t get_page_size() const { return page_size_; }
uint64_t GetPageSize() const { return page_size_; }
uint64_t get_start_row_id() const { return start_row_id_; }
uint64_t GetStartRowID() const { return start_row_id_; }
uint64_t get_end_row_id() const { return end_row_id_; }
uint64_t GetEndRowID() const { return end_row_id_; }
void set_end_row_id(const uint64_t &end_row_id) { end_row_id_ = end_row_id; }
void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; }
void set_page_size(const uint64_t &page_size) { page_size_ = page_size; }
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
std::pair<int, uint64_t> get_last_row_group_id() const { return row_group_ids_.back(); }
std::pair<int, uint64_t> GetLastRowGroupID() const { return row_group_ids_.back(); }
std::vector<std::pair<int, uint64_t>> get_row_group_ids() const { return row_group_ids_; }
std::vector<std::pair<int, uint64_t>> GetRowGroupIds() const { return row_group_ids_; }
void set_row_group_ids(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) {
void SetRowGroupIds(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) {
row_group_ids_ = last_row_group_ids;
}
......
......@@ -37,7 +37,7 @@ class ShardPkSample : public ShardCategory {
~ShardPkSample() override{};
MSRStatus suf_execute(ShardTask &tasks) override;
MSRStatus SufExecute(ShardTask &tasks) override;
private:
bool shuffle_;
......
......@@ -107,11 +107,11 @@ class ShardReader {
/// \brief aim to get the meta data
/// \return the metadata
std::shared_ptr<ShardHeader> get_shard_header() const;
std::shared_ptr<ShardHeader> GetShardHeader() const;
/// \brief get the number of shards
/// \return # of shards
int get_shard_count() const;
int GetShardCount() const;
/// \brief get the number of rows in database
/// \param[in] file_path the path of ONE file, any file in dataset is fine
......@@ -126,7 +126,7 @@ class ShardReader {
/// \brief get the number of rows in database
/// \return # of rows
int get_num_rows() const;
int GetNumRows() const;
/// \brief Read the summary of row groups
/// \return the tuple of 4 elements
......@@ -185,7 +185,7 @@ class ShardReader {
/// \brief get blob filed list
/// \return blob field list
std::pair<ShardType, std::vector<std::string>> get_blob_fields();
std::pair<ShardType, std::vector<std::string>> GetBlobFields();
/// \brief reset reader
/// \return null
......@@ -193,10 +193,10 @@ class ShardReader {
/// \brief set flag of all-in-index
/// \return null
void set_all_in_index(bool all_in_index) { all_in_index_ = all_in_index; }
void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }
/// \brief get NLP flag
bool get_nlp_flag();
bool GetNlpFlag();
/// \brief get all classes
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
......
......@@ -38,11 +38,11 @@ class ShardSample : public ShardOperator {
~ShardSample() override{};
const std::pair<int, int> get_partitions() const;
const std::pair<int, int> GetPartitions() const;
MSRStatus execute(ShardTask &tasks) override;
MSRStatus Execute(ShardTask &tasks) override;
MSRStatus suf_execute(ShardTask &tasks) override;
MSRStatus SufExecute(ShardTask &tasks) override;
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
......
......@@ -51,7 +51,7 @@ class Schema {
/// \brief get the schema and its description
/// \return the json format of the schema and its description
std::string get_desc() const;
std::string GetDesc() const;
/// \brief get the schema and its description
/// \return the json format of the schema and its description
......@@ -63,15 +63,15 @@ class Schema {
/// set the schema id
/// \param[in] id the id need to be set
void set_schema_id(int64_t id);
void SetSchemaID(int64_t id);
/// get the schema id
/// \return the int64 schema id
int64_t get_schema_id() const;
int64_t GetSchemaID() const;
/// get the blob fields
/// \return the vector<string> blob fields
std::vector<std::string> get_blob_fields() const;
std::vector<std::string> GetBlobFields() const;
private:
Schema() = default;
......
......@@ -81,7 +81,7 @@ class ShardSegment : public ShardReader {
std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy(
std::string category_name, int64_t page_no, int64_t n_rows_of_page);
std::pair<ShardType, std::vector<std::string>> get_blob_fields();
std::pair<ShardType, std::vector<std::string>> GetBlobFields();
private:
std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo();
......
......@@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator {
~ShardShuffle() override{};
MSRStatus execute(ShardTask &tasks) override;
MSRStatus Execute(ShardTask &tasks) override;
private:
uint32_t shuffle_seed_;
......
......@@ -53,11 +53,11 @@ class Statistics {
/// \brief get the description
/// \return the description
std::string get_desc() const;
std::string GetDesc() const;
/// \brief get the statistic
/// \return json format of the statistic
json get_statistics() const;
json GetStatistics() const;
/// \brief get the statistic for python
/// \return the python object of statistics
......@@ -66,11 +66,11 @@ class Statistics {
/// \brief decode the bson statistics to json
/// \param[in] encodedStatistics the bson type of statistics
/// \return json type of statistic
void set_statistics_id(int64_t id);
void SetStatisticsID(int64_t id);
/// \brief get the statistics id
/// \return the int64 statistics id
int64_t get_statistics_id() const;
int64_t GetStatisticsID() const;
private:
/// \brief validate the statistic
......
......@@ -39,9 +39,9 @@ class ShardTask {
uint32_t SizeOfRows() const;
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_task_by_id(size_t id);
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetTaskByID(size_t id);
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_random_task();
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements);
......
......@@ -69,12 +69,12 @@ class ShardWriter {
/// \brief Set file size
/// \param[in] header_size the size of header, only (1<<N) is accepted
/// \return MSRStatus the status of MSRStatus
MSRStatus set_header_size(const uint64_t &header_size);
MSRStatus SetHeaderSize(const uint64_t &header_size);
/// \brief Set page size
/// \param[in] page_size the size of page, only (1<<N) is accepted
/// \return MSRStatus the status of MSRStatus
MSRStatus set_page_size(const uint64_t &page_size);
MSRStatus SetPageSize(const uint64_t &page_size);
/// \brief Set shard header
/// \param[in] header_data the info of header
......
......@@ -64,7 +64,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const str
}
// schema does not contain the field
auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"];
auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
if (schema.find(field) == schema.end()) {
MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
return {FAILED, ""};
......@@ -203,7 +203,7 @@ MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::stri
}
std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) {
std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
if (shard_address.empty()) {
MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no;
return {FAILED, nullptr};
......@@ -357,12 +357,12 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
const std::shared_ptr<Page> cur_blob_page,
uint64_t &cur_blob_page_offset, std::fstream &in) {
row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->get_page_id()));
row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID()));
// blob data start
row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset));
auto &io_seekg_blob =
in.seekg(page_size_ * cur_blob_page->get_page_id() + header_size_ + cur_blob_page_offset, std::ios::beg);
in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
MS_LOG(ERROR) << "File seekg failed";
in.close();
......@@ -405,7 +405,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first;
// related blob page
vector<pair<int, uint64_t>> row_group_list = cur_raw_page->get_row_group_ids();
vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds();
// pair: row_group id, offset in raw data page
for (pair<int, int> blob_ids : row_group_list) {
......@@ -415,18 +415,18 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
// offset in current raw data page
auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
uint64_t cur_blob_page_offset = 0;
for (unsigned int i = cur_blob_page->get_start_row_id(); i < cur_blob_page->get_end_row_id(); ++i) {
for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) {
std::vector<std::tuple<std::string, std::string, std::string>> row_data;
row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i));
row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->get_page_type_id()));
row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->get_page_id()));
row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID()));
row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID()));
// raw data start
row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset));
// calculate raw data end
auto &io_seekg =
in.seekg(page_size_ * (cur_raw_page->get_page_id()) + header_size_ + cur_raw_page_offset, std::ios::beg);
in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
MS_LOG(ERROR) << "File seekg failed";
in.close();
......@@ -473,7 +473,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) {
std::vector<std::tuple<std::string, std::string, std::string>> fields;
// index fields
std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.get_fields();
std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
for (const auto &field : index_fields) {
if (field.first >= schema_detail.size()) {
return {FAILED, {}};
......@@ -504,7 +504,7 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id) {
// Add index data to database
std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
if (shard_address.empty()) {
MS_LOG(ERROR) << "Shard address is null";
return FAILED;
......@@ -546,12 +546,12 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
}
MSRStatus ShardIndexGenerator::WriteToDatabase() {
fields_ = shard_header_.get_fields();
page_size_ = shard_header_.get_page_size();
header_size_ = shard_header_.get_header_size();
schema_count_ = shard_header_.get_schema_count();
if (shard_header_.get_shard_count() > kMaxShardCount) {
MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount;
fields_ = shard_header_.GetFields();
page_size_ = shard_header_.GetPageSize();
header_size_ = shard_header_.GetHeaderSize();
schema_count_ = shard_header_.GetSchemaCount();
if (shard_header_.GetShardCount() > kMaxShardCount) {
MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount;
return FAILED;
}
task_ = 0; // set two atomic vars to initial value
......@@ -559,7 +559,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
// spawn half the physical threads or total number of shards whichever is smaller
const unsigned int num_workers =
std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count()));
std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount()));
std::vector<std::thread> threads;
threads.reserve(num_workers);
......@@ -576,7 +576,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
void ShardIndexGenerator::DatabaseWriter() {
int shard_no = task_++;
while (shard_no < shard_header_.get_shard_count()) {
while (shard_no < shard_header_.GetShardCount()) {
auto db = CreateDatabase(shard_no);
if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
write_success_ = false;
......@@ -592,10 +592,10 @@ void ShardIndexGenerator::DatabaseWriter() {
std::vector<int> raw_page_ids;
for (uint64_t i = 0; i < total_pages; ++i) {
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
if (cur_page->get_page_type() == "RAW_DATA") {
if (cur_page->GetPageType() == "RAW_DATA") {
raw_page_ids.push_back(i);
} else if (cur_page->get_page_type() == "BLOB_DATA") {
blob_id_to_page_id[cur_page->get_page_type_id()] = i;
} else if (cur_page->GetPageType() == "BLOB_DATA") {
blob_id_to_page_id[cur_page->GetPageTypeID()] = i;
}
}
......
......@@ -56,9 +56,9 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
return FAILED;
}
shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->get_header_size();
page_size_ = shard_header_->get_page_size();
file_paths_ = shard_header_->get_shard_addresses();
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
file_paths_ = shard_header_->GetShardAddresses();
for (const auto &file : file_paths_) {
sqlite3 *db = nullptr;
......@@ -105,7 +105,7 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
MSRStatus ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
vector<int> inSchema(selected_columns.size(), 0);
for (auto &p : get_shard_header()->get_schemas()) {
for (auto &p : GetShardHeader()->GetSchemas()) {
auto schema = p->GetSchema()["schema"];
for (unsigned int i = 0; i < selected_columns.size(); ++i) {
if (schema.find(selected_columns[i]) != schema.end()) {
......@@ -183,15 +183,15 @@ void ShardReader::Close() {
FileStreamsOperator();
}
std::shared_ptr<ShardHeader> ShardReader::get_shard_header() const { return shard_header_; }
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }
int ShardReader::get_shard_count() const { return shard_header_->get_shard_count(); }
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }
int ShardReader::get_num_rows() const { return num_rows_; }
int ShardReader::GetNumRows() const { return num_rows_; }
std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary;
int shard_count = shard_header_->get_shard_count();
int shard_count = shard_header_->GetShardCount();
if (shard_count <= 0) {
return row_group_summary;
}
......@@ -205,13 +205,13 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
const auto &page_t = shard_header_->GetPage(shard_id, page_id);
const auto &page = page_t.first;
if (page->get_page_type() != kPageTypeBlob) continue;
uint64_t start_row_id = page->get_start_row_id();
if (start_row_id > page->get_end_row_id()) {
if (page->GetPageType() != kPageTypeBlob) continue;
uint64_t start_row_id = page->GetStartRowID();
if (start_row_id > page->GetEndRowID()) {
return std::vector<std::tuple<int, int, int, uint64_t>>();
}
uint64_t number_of_rows = page->get_end_row_id() - start_row_id;
row_group_summary.emplace_back(shard_id, page->get_page_type_id(), start_row_id, number_of_rows);
uint64_t number_of_rows = page->GetEndRowID() - start_row_id;
row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows);
}
}
}
......@@ -265,7 +265,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) {
// construct json "f1": value
auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
// convert the string to base type by schema
if (schema[columns[j]]["type"] == "int32") {
......@@ -317,7 +317,7 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
std::map<std::string, uint64_t> index_columns;
for (auto &field : get_shard_header()->get_fields()) {
for (auto &field : GetShardHeader()->GetFields()) {
index_columns[field.second] = field.first;
}
if (index_columns.find(category_field) == index_columns.end()) {
......@@ -400,11 +400,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const
}
const std::shared_ptr<Page> &page = ret.second;
std::string file_name = file_paths_[shard_id];
uint64_t page_length = page->get_page_size();
uint64_t page_offset = page_size_ * page->get_page_id() + header_size_;
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id);
uint64_t page_length = page->GetPageSize();
uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id);
auto status_labels = GetLabels(page->get_page_id(), shard_id, columns);
auto status_labels = GetLabels(page->GetPageID(), shard_id, columns);
if (status_labels.first != SUCCESS) {
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
}
......@@ -426,11 +426,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
}
const std::shared_ptr<Page> &page = ret.second;
std::string file_name = file_paths_[shard_id];
uint64_t page_length = page->get_page_size();
uint64_t page_offset = page_size_ * page->get_page_id() + header_size_;
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id, criteria);
uint64_t page_length = page->GetPageSize();
uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria);
auto status_labels = GetLabels(page->get_page_id(), shard_id, columns, criteria);
auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria);
if (status_labels.first != SUCCESS) {
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
}
......@@ -458,7 +458,7 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
// whether use index search
if (!criteria.first.empty()) {
auto schema = shard_header_->get_schemas()[0]->GetSchema();
auto schema = shard_header_->GetSchemas()[0]->GetSchema();
// not number field should add '' in sql
if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
......@@ -497,13 +497,13 @@ void ShardReader::CheckNlp() {
return;
}
bool ShardReader::get_nlp_flag() { return nlp_; }
bool ShardReader::GetNlpFlag() { return nlp_; }
std::pair<ShardType, std::vector<std::string>> ShardReader::get_blob_fields() {
std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
std::vector<std::string> blob_fields;
for (auto &p : get_shard_header()->get_schemas()) {
for (auto &p : GetShardHeader()->GetSchemas()) {
// assume one schema
const auto &fields = p->get_blob_fields();
const auto &fields = p->GetBlobFields();
blob_fields.assign(fields.begin(), fields.end());
break;
}
......@@ -516,7 +516,7 @@ void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns)
all_in_index_ = false;
return;
}
for (auto &field : get_shard_header()->get_fields()) {
for (auto &field : GetShardHeader()->GetFields()) {
column_schema_id_[field.second] = field.first;
}
for (auto &col : columns) {
......@@ -671,7 +671,7 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) {
// construct json "f1": value
auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
// convert the string to base type by schema
if (schema[columns[j]]["type"] == "int32") {
......@@ -719,9 +719,9 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
return -1;
}
auto header = std::make_shared<ShardHeader>(sh);
auto file_paths = header->get_shard_addresses();
auto file_paths = header->GetShardAddresses();
auto shard_count = file_paths.size();
auto index_fields = header->get_fields();
auto index_fields = header->GetFields();
std::map<std::string, int64_t> map_schema_id_fields;
for (auto &field : index_fields) {
......@@ -799,7 +799,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
if (nlp_) {
selected_columns_ = selected_columns;
} else {
vector<std::string> blob_fields = get_blob_fields().second;
vector<std::string> blob_fields = GetBlobFields().second;
for (unsigned int i = 0; i < selected_columns.size(); ++i) {
if (!std::any_of(blob_fields.begin(), blob_fields.end(),
[&selected_columns, i](std::string item) { return selected_columns[i] == item; })) {
......@@ -846,7 +846,7 @@ MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consume
}
// should remove blob field from selected_columns when call from python
std::vector<std::string> columns(selected_columns);
auto blob_fields = get_blob_fields().second;
auto blob_fields = GetBlobFields().second;
for (auto &blob_field : blob_fields) {
auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field);
if (it != selected_columns.end()) {
......@@ -909,7 +909,7 @@ vector<std::string> ShardReader::GetAllColumns() {
vector<std::string> columns;
if (nlp_) {
for (auto &c : selected_columns_) {
for (auto &p : get_shard_header()->get_schemas()) {
for (auto &p : GetShardHeader()->GetSchemas()) {
auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm.
for (auto it = schema.begin(); it != schema.end(); ++it) {
if (it.key() == c) {
......@@ -943,7 +943,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
CheckIfColumnInIndex(columns);
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
auto categories = category_op->get_categories();
auto categories = category_op->GetCategories();
int64_t num_elements = category_op->GetNumElements();
if (num_elements <= 0) {
MS_LOG(ERROR) << "Parameter num_element is not positive";
......@@ -1104,7 +1104,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
}
// Pick up task from task list
auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]);
auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
auto shard_id = std::get<0>(std::get<0>(task));
auto group_id = std::get<1>(std::get<0>(task));
......@@ -1117,7 +1117,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
// Pack image list
std::vector<uint8_t> images(addr[1] - addr[0]);
auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0];
auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0];
auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
......@@ -1139,7 +1139,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
if (selected_columns_.size() == 0) {
images_with_exact_columns = images;
} else {
auto blob_fields = get_blob_fields();
auto blob_fields = GetBlobFields();
std::vector<uint32_t> ordered_selected_columns_index;
uint32_t index = 0;
......@@ -1272,7 +1272,7 @@ MSRStatus ShardReader::ConsumerByBlock(int consumer_id) {
}
// Pick up task from task list
auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]);
auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
auto shard_id = std::get<0>(std::get<0>(task));
auto group_id = std::get<1>(std::get<0>(task));
......
......@@ -28,7 +28,7 @@ using mindspore::MsLogLevel::INFO;
namespace mindspore {
namespace mindrecord {
ShardSegment::ShardSegment() { set_all_in_index(false); }
ShardSegment::ShardSegment() { SetAllInIndex(false); }
std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() {
// Skip if already populated
......@@ -211,7 +211,7 @@ std::pair<MSRStatus, std::vector<uint8_t>> ShardSegment::PackImages(int group_id
// Pack image list
std::vector<uint8_t> images(offset[1] - offset[0]);
auto file_offset = header_size_ + page_size_ * (blob_page->get_page_id()) + offset[0];
auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0];
auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
MS_LOG(ERROR) << "File seekg failed";
......@@ -363,21 +363,21 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::obje
return {SUCCESS, std::move(json_data)};
}
std::pair<ShardType, std::vector<std::string>> ShardSegment::get_blob_fields() {
std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
std::vector<std::string> blob_fields;
for (auto &p : get_shard_header()->get_schemas()) {
for (auto &p : GetShardHeader()->GetSchemas()) {
// assume one schema
const auto &fields = p->get_blob_fields();
const auto &fields = p->GetBlobFields();
blob_fields.assign(fields.begin(), fields.end());
break;
}
return std::make_pair(get_nlp_flag() ? kNLP : kCV, blob_fields);
return std::make_pair(GetNlpFlag() ? kNLP : kCV, blob_fields);
}
std::tuple<std::vector<uint8_t>, json> ShardSegment::GetImageLabel(std::vector<uint8_t> images, json label) {
if (get_nlp_flag()) {
if (GetNlpFlag()) {
vector<std::string> columns;
for (auto &p : get_shard_header()->get_schemas()) {
for (auto &p : GetShardHeader()->GetSchemas()) {
auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm.
auto schema_items = schema.items();
using it_type = decltype(schema_items.begin());
......
......@@ -179,12 +179,12 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
return FAILED;
}
shard_header_ = std::make_shared<ShardHeader>(sh);
auto paths = shard_header_->get_shard_addresses();
MSRStatus ret = set_header_size(shard_header_->get_header_size());
auto paths = shard_header_->GetShardAddresses();
MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
if (ret == FAILED) {
return FAILED;
}
ret = set_page_size(shard_header_->get_page_size());
ret = SetPageSize(shard_header_->GetPageSize());
if (ret == FAILED) {
return FAILED;
}
......@@ -229,10 +229,10 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
}
// set fields in mindrecord when empty
std::vector<std::pair<uint64_t, std::string>> fields = header_data->get_fields();
std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields();
if (fields.empty()) {
MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields.";
std::vector<std::shared_ptr<Schema>> schemas = header_data->get_schemas();
std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas();
for (const auto &schema : schemas) {
json jsonSchema = schema->GetSchema()["schema"];
for (const auto &el : jsonSchema.items()) {
......@@ -241,7 +241,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
(el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) ||
(el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) ||
(el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) {
fields.emplace_back(std::make_pair(schema->get_schema_id(), el.key()));
fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key()));
}
}
}
......@@ -256,12 +256,12 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
}
shard_header_ = header_data;
shard_header_->set_header_size(header_size_);
shard_header_->set_page_size(page_size_);
shard_header_->SetHeaderSize(header_size_);
shard_header_->SetPageSize(page_size_);
return SUCCESS;
}
MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) {
MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) {
// header_size [16KB, 128MB]
if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) {
MS_LOG(ERROR) << "Header size should between 16KB and 128MB.";
......@@ -276,7 +276,7 @@ MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) {
return SUCCESS;
}
MSRStatus ShardWriter::set_page_size(const uint64_t &page_size) {
MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) {
// PageSize [32KB, 256MB]
if (page_size < kMinPageSize || page_size > kMaxPageSize) {
MS_LOG(ERROR) << "Page size should between 16KB and 256MB.";
......@@ -398,7 +398,7 @@ MSRStatus ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &ra
return FAILED;
}
json schema = result.first->GetSchema()["schema"];
for (const auto &field : result.first->get_blob_fields()) {
for (const auto &field : result.first->GetBlobFields()) {
(void)schema.erase(field);
}
std::vector<json> sub_raw_data = rawdata_iter->second;
......@@ -456,7 +456,7 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t,
MS_LOG(DEBUG) << "Schema count is " << schema_count_;
// Determine if the number of schemas is the same
if (shard_header_->get_schemas().size() != schema_count_) {
if (shard_header_->GetSchemas().size() != schema_count_) {
MS_LOG(ERROR) << "Data size is not equal with the schema size";
return failed;
}
......@@ -475,9 +475,9 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t,
}
(void)schema_ids.insert(rawdata_iter->first);
}
const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->get_schemas();
const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas();
if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr<Schema> &schema) {
return schema_ids.find(schema->get_schema_id()) == schema_ids.end();
return schema_ids.find(schema->GetSchemaID()) == schema_ids.end();
})) {
// There is not enough data which is not matching the number of schema
MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id.";
......@@ -810,10 +810,10 @@ MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector
std::vector<std::pair<int, int>> &rows_in_group,
const std::shared_ptr<Page> &last_raw_page,
const std::shared_ptr<Page> &last_blob_page) {
auto n_byte_blob = last_blob_page ? last_blob_page->get_page_size() : 0;
auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0;
auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0;
auto last_raw_offset = last_raw_page ? last_raw_page->get_last_row_group_id().second : 0;
auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0;
auto n_byte_raw = last_raw_page_size - last_raw_offset;
int page_start_row = start_row;
......@@ -849,8 +849,8 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std
if (blob_row.first == blob_row.second) return SUCCESS;
// Write disk
auto page_id = last_blob_page->get_page_id();
auto bytes_page = last_blob_page->get_page_size();
auto page_id = last_blob_page->GetPageID();
auto bytes_page = last_blob_page->GetPageSize();
auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg);
if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
MS_LOG(ERROR) << "File seekp failed";
......@@ -862,9 +862,9 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std
// Update last blob page
bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
last_blob_page->set_page_size(bytes_page);
uint64_t end_row = last_blob_page->get_end_row_id() + blob_row.second - blob_row.first;
last_blob_page->set_end_row_id(end_row);
last_blob_page->SetPageSize(bytes_page);
uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first;
last_blob_page->SetEndRowID(end_row);
(void)shard_header_->SetPage(last_blob_page);
return SUCCESS;
}
......@@ -873,8 +873,8 @@ MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector<std::v
const std::vector<std::pair<int, int>> &rows_in_group,
const std::shared_ptr<Page> &last_blob_page) {
auto page_id = shard_header_->GetLastPageId(shard_id);
auto page_type_id = last_blob_page ? last_blob_page->get_page_type_id() : -1;
auto current_row = last_blob_page ? last_blob_page->get_end_row_id() : 0;
auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1;
auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0;
// index(0) indicate appendBlobPage
for (uint32_t i = 1; i < rows_in_group.size(); ++i) {
auto blob_row = rows_in_group[i];
......@@ -905,15 +905,15 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::
std::shared_ptr<Page> &last_raw_page) {
auto blob_row = rows_in_group[0];
if (blob_row.first == blob_row.second) return SUCCESS;
auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0;
auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) +
last_raw_page_size <=
page_size_) {
return SUCCESS;
}
auto page_id = shard_header_->GetLastPageId(shard_id);
auto last_row_group_id_offset = last_raw_page->get_last_row_group_id().second;
auto last_raw_page_id = last_raw_page->get_page_id();
auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second;
auto last_raw_page_id = last_raw_page->GetPageID();
auto shift_size = last_raw_page_size - last_row_group_id_offset;
std::vector<uint8_t> buf(shift_size);
......@@ -956,10 +956,10 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::
(void)shard_header_->SetPage(last_raw_page);
// Refresh page info in header
int row_group_id = last_raw_page->get_last_row_group_id().first + 1;
int row_group_id = last_raw_page->GetLastRowGroupID().first + 1;
std::vector<std::pair<int, uint64_t>> row_group_ids;
row_group_ids.emplace_back(row_group_id, 0);
int page_type_id = last_raw_page->get_page_id();
int page_type_id = last_raw_page->GetPageID();
auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size);
(void)shard_header_->AddPage(std::make_shared<Page>(page));
......@@ -971,7 +971,7 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::
MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
std::shared_ptr<Page> &last_raw_page,
const std::vector<std::vector<uint8_t>> &bin_raw_data) {
int last_row_group_id = last_raw_page ? last_raw_page->get_last_row_group_id().first : -1;
int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1;
for (uint32_t i = 0; i < rows_in_group.size(); ++i) {
const auto &blob_row = rows_in_group[i];
if (blob_row.first == blob_row.second) continue;
......@@ -979,7 +979,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::
std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0);
if (!last_raw_page) {
EmptyRawPage(shard_id, last_raw_page);
} else if (last_raw_page->get_page_size() + raw_size > page_size_) {
} else if (last_raw_page->GetPageSize() + raw_size > page_size_) {
(void)shard_header_->SetPage(last_raw_page);
EmptyRawPage(shard_id, last_raw_page);
}
......@@ -994,7 +994,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::
void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
auto row_group_ids = std::vector<std::pair<int, uint64_t>>();
auto page_id = shard_header_->GetLastPageId(shard_id);
auto page_type_id = last_raw_page ? last_raw_page->get_page_id() : -1;
auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1;
auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0);
(void)shard_header_->AddPage(std::make_shared<Page>(page));
SetLastRawPage(shard_id, last_raw_page);
......@@ -1003,9 +1003,9 @@ void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_
MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page,
const std::vector<std::vector<uint8_t>> &bin_raw_data) {
std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->get_row_group_ids();
auto last_raw_page_id = last_raw_page->get_page_id();
auto n_bytes = last_raw_page->get_page_size();
std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->GetRowGroupIds();
auto last_raw_page_id = last_raw_page->GetPageID();
auto n_bytes = last_raw_page->GetPageSize();
// previous raw data page
auto &io_seekp =
......@@ -1022,8 +1022,8 @@ MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std:
(void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data);
// Update previous raw data page
last_raw_page->set_page_size(n_bytes);
last_raw_page->set_row_group_ids(row_group_ids);
last_raw_page->SetPageSize(n_bytes);
last_raw_page->SetRowGroupIds(row_group_ids);
(void)shard_header_->SetPage(last_raw_page);
return SUCCESS;
......
......@@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem
num_categories_(num_categories),
replacement_(replacement) {}
MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; }
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (dataset_size == 0) return dataset_size;
......
......@@ -343,7 +343,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
std::string ShardHeader::SerializeIndexFields() {
json j;
auto fields = index_->get_fields();
auto fields = index_->GetFields();
for (const auto &field : fields) {
j.push_back({{"schema_id", field.first}, {"index_field", field.second}});
}
......@@ -365,7 +365,7 @@ std::vector<std::string> ShardHeader::SerializePage() {
std::string ShardHeader::SerializeStatistics() {
json j;
for (const auto &stats : statistics_) {
j.emplace_back(stats->get_statistics());
j.emplace_back(stats->GetStatistics());
}
return j.dump();
}
......@@ -398,8 +398,8 @@ MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
if (new_page == nullptr) {
return FAILED;
}
int shard_id = new_page->get_shard_id();
int page_id = new_page->get_page_id();
int shard_id = new_page->GetShardID();
int page_id = new_page->GetPageID();
if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
pages_[shard_id][page_id] = new_page;
return SUCCESS;
......@@ -412,8 +412,8 @@ MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
if (new_page == nullptr) {
return FAILED;
}
int shard_id = new_page->get_shard_id();
int page_id = new_page->get_page_id();
int shard_id = new_page->GetShardID();
int page_id = new_page->GetPageID();
if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
pages_[shard_id].push_back(new_page);
return SUCCESS;
......@@ -435,8 +435,8 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag
}
int last_page_id = -1;
for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
if (pages_[shard_id][i - 1]->get_page_type() == page_type) {
last_page_id = pages_[shard_id][i - 1]->get_page_id();
if (pages_[shard_id][i - 1]->GetPageType() == page_type) {
last_page_id = pages_[shard_id][i - 1]->GetPageID();
return last_page_id;
}
}
......@@ -451,7 +451,7 @@ const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId(
}
for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
auto page = pages_[shard_id][i - 1];
if (page->get_page_type() == kPageTypeBlob && page->get_page_type_id() == group_id) {
if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
return {SUCCESS, page};
}
}
......@@ -470,10 +470,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
return -1;
}
int64_t schema_id = schema->get_schema_id();
int64_t schema_id = schema->GetSchemaID();
if (schema_id == -1) {
schema_id = schema_.size();
schema->set_schema_id(schema_id);
schema->SetSchemaID(schema_id);
}
schema_.push_back(schema);
return schema_id;
......@@ -481,10 +481,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) {
if (statistic) {
int64_t statistics_id = statistic->get_statistics_id();
int64_t statistics_id = statistic->GetStatisticsID();
if (statistics_id == -1) {
statistics_id = statistics_.size();
statistic->set_statistics_id(statistics_id);
statistic->SetStatisticsID(statistics_id);
}
statistics_.push_back(statistic);
}
......@@ -527,13 +527,13 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
return FAILED;
}
if (get_schemas().empty()) {
if (GetSchemas().empty()) {
MS_LOG(ERROR) << "No schema is set";
return FAILED;
}
for (const auto &schemaPtr : schema_) {
auto result = GetSchemaByID(schemaPtr->get_schema_id());
auto result = GetSchemaByID(schemaPtr->GetSchemaID());
if (result.second != SUCCESS) {
MS_LOG(ERROR) << "Could not get schema by id.";
return FAILED;
......@@ -548,7 +548,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
// checkout and add fields for each schema
std::set<std::string> field_set;
for (const auto &item : index->get_fields()) {
for (const auto &item : index->GetFields()) {
field_set.insert(item.second);
}
for (const auto &field : fields) {
......@@ -564,7 +564,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
field_set.insert(field);
// add field into index
index.get()->AddIndexField(schemaPtr->get_schema_id(), field);
index.get()->AddIndexField(schemaPtr->GetSchemaID(), field);
}
}
......@@ -575,12 +575,12 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
// get all schema id
for (const auto &schema : schema_) {
auto bucket_it = bucket_count.find(schema->get_schema_id());
auto bucket_it = bucket_count.find(schema->GetSchemaID());
if (bucket_it != bucket_count.end()) {
MS_LOG(ERROR) << "Schema duplication";
return FAILED;
} else {
bucket_count.insert(schema->get_schema_id());
bucket_count.insert(schema->GetSchemaID());
}
}
return SUCCESS;
......@@ -603,7 +603,7 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin
// check and add fields for each schema
std::set<std::pair<uint64_t, std::string>> field_set;
for (const auto &item : index->get_fields()) {
for (const auto &item : index->GetFields()) {
field_set.insert(item);
}
for (const auto &field : fields) {
......@@ -646,20 +646,20 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin
return SUCCESS;
}
std::string ShardHeader::get_shard_address_by_id(int64_t shard_id) {
std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
if (shard_id >= shard_addresses_.size()) {
return "";
}
return shard_addresses_.at(shard_id);
}
std::vector<std::shared_ptr<Schema>> ShardHeader::get_schemas() { return schema_; }
std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; }
std::vector<std::shared_ptr<Statistics>> ShardHeader::get_statistics() { return statistics_; }
std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; }
std::vector<std::pair<uint64_t, std::string>> ShardHeader::get_fields() { return index_->get_fields(); }
std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); }
std::shared_ptr<Index> ShardHeader::get_index() { return index_; }
std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }
std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) {
int64_t schemaSize = schema_.size();
......
......@@ -28,6 +28,6 @@ void Index::AddIndexField(const int64_t &schemaId, const std::string &field) {
}
// Get attribute list
std::vector<std::pair<uint64_t, std::string>> Index::get_fields() { return fields_; }
std::vector<std::pair<uint64_t, std::string>> Index::GetFields() { return fields_; }
} // namespace mindrecord
} // namespace mindspore
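
Every hunk in this commit follows the same mechanical pattern: snake_case accessors (get_fields, get_schema_id, set_schema_id, ...) are renamed to PascalCase with a Get/Set prefix and upper-cased acronyms (ID), while function bodies are left untouched. A minimal, self-contained sketch of the resulting convention — an illustrative class, not code from this repository:

#include <cstdint>
#include <string>

// Hypothetical class showing the naming convention the commit adopts.
class ExampleSchema {
 public:
  // Renamed from get_schema_id() under the old convention.
  int64_t GetSchemaID() const { return schema_id_; }
  // Renamed from set_schema_id().
  void SetSchemaID(int64_t id) { schema_id_ = id; }
  // Renamed from get_desc().
  std::string GetDesc() const { return desc_; }

 private:
  int64_t schema_id_ = -1;
  std::string desc_;
};

int main() {
  ExampleSchema schema;
  schema.SetSchemaID(2);
  return schema.GetSchemaID() == 2 ? 0 : 1;  // accessors behave exactly as before
}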
......@@ -34,7 +34,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement
}
MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) {
MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) {
if (shuffle_ == true) {
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;
......
......@@ -74,14 +74,14 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
return -1;
}
const std::pair<int, int> ShardSample::get_partitions() const {
const std::pair<int, int> ShardSample::GetPartitions() const {
if (numerator_ == 1 && denominator_ > 1) {
return std::pair<int, int>(denominator_, partition_id_);
}
return std::pair<int, int>(-1, -1);
}
MSRStatus ShardSample::execute(ShardTask &tasks) {
MSRStatus ShardSample::Execute(ShardTask &tasks) {
int no_of_categories = static_cast<int>(tasks.categories);
int total_no = static_cast<int>(tasks.Size());
......@@ -114,11 +114,11 @@ MSRStatus ShardSample::execute(ShardTask &tasks) {
if (sampler_type_ == kSubsetRandomSampler) {
for (int i = 0; i < indices_.size(); ++i) {
int index = ((indices_[i] % total_no) + total_no) % total_no;
new_tasks.InsertTask(tasks.get_task_by_id(index)); // different mod result between c and python
new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python
}
} else {
for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
new_tasks.InsertTask(tasks.get_task_by_id(i % total_no)); // rounding up. if overflow, go back to start
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
}
}
std::swap(tasks, new_tasks);
......@@ -129,14 +129,14 @@ MSRStatus ShardSample::execute(ShardTask &tasks) {
}
total_no = static_cast<int>(tasks.permutation_.size());
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
new_tasks.InsertTask(tasks.get_task_by_id(tasks.permutation_[i % total_no]));
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
}
std::swap(tasks, new_tasks);
}
return SUCCESS;
}
MSRStatus ShardSample::suf_execute(ShardTask &tasks) {
MSRStatus ShardSample::SufExecute(ShardTask &tasks) {
if (sampler_type_ == kSubsetRandomSampler) {
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;
......
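
The subset-sampler branch above wraps indices as ((index % total_no) + total_no) % total_no because C++'s % truncates toward zero, so a negative index keeps its sign, whereas Python's % floors — hence the inline comment about the "different mod result between c and python". A standalone illustration with hypothetical values, independent of ShardSample:

#include <iostream>

int main() {
  const int total_no = 10;
  const int index = -3;                // e.g. a negative index handed over from Python
  const int c_mod = index % total_no;  // -3 in C++: result takes the sign of the dividend
  const int wrapped = ((index % total_no) + total_no) % total_no;  // 7, matching Python's -3 % 10
  std::cout << c_mod << " " << wrapped << std::endl;
  return 0;
}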
......@@ -44,7 +44,7 @@ std::shared_ptr<Schema> Schema::Build(std::string desc, pybind11::handle schema)
return Build(std::move(desc), schema_json);
}
std::string Schema::get_desc() const { return desc_; }
std::string Schema::GetDesc() const { return desc_; }
json Schema::GetSchema() const {
json str_schema;
......@@ -60,11 +60,11 @@ pybind11::object Schema::GetSchemaForPython() const {
return schema_py;
}
void Schema::set_schema_id(int64_t id) { schema_id_ = id; }
void Schema::SetSchemaID(int64_t id) { schema_id_ = id; }
int64_t Schema::get_schema_id() const { return schema_id_; }
int64_t Schema::GetSchemaID() const { return schema_id_; }
std::vector<std::string> Schema::get_blob_fields() const { return blob_fields_; }
std::vector<std::string> Schema::GetBlobFields() const { return blob_fields_; }
std::vector<std::string> Schema::PopulateBlobFields(json schema) {
std::vector<std::string> blob_fields;
......@@ -155,7 +155,7 @@ bool Schema::Validate(json schema) {
}
bool Schema::operator==(const mindrecord::Schema &b) const {
if (this->get_desc() != b.get_desc() || this->GetSchema() != b.GetSchema()) {
if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) {
return false;
}
return true;
......
......@@ -23,7 +23,7 @@ namespace mindrecord {
ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
: shuffle_seed_(seed), shuffle_type_(shuffle_type) {}
MSRStatus ShardShuffle::execute(ShardTask &tasks) {
MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
if (tasks.categories < 1) {
return FAILED;
}
......
......@@ -48,9 +48,9 @@ std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle
return std::make_shared<Statistics>(object_statistics);
}
std::string Statistics::get_desc() const { return desc_; }
std::string Statistics::GetDesc() const { return desc_; }
json Statistics::get_statistics() const {
json Statistics::GetStatistics() const {
json str_statistics;
str_statistics["desc"] = desc_;
str_statistics["statistics"] = statistics_;
......@@ -58,13 +58,13 @@ json Statistics::get_statistics() const {
}
pybind11::object Statistics::GetStatisticsForPython() const {
json str_statistics = Statistics::get_statistics();
json str_statistics = Statistics::GetStatistics();
return nlohmann::detail::FromJsonImpl(str_statistics);
}
void Statistics::set_statistics_id(int64_t id) { statistics_id_ = id; }
void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; }
int64_t Statistics::get_statistics_id() const { return statistics_id_; }
int64_t Statistics::GetStatisticsID() const { return statistics_id_; }
bool Statistics::Validate(const json &statistics) {
if (statistics.size() != kInt1) {
......@@ -103,7 +103,7 @@ bool Statistics::LevelRecursive(json level) {
}
bool Statistics::operator==(const Statistics &b) const {
if (this->get_statistics() != b.get_statistics()) {
if (this->GetStatistics() != b.GetStatistics()) {
return false;
}
return true;
......
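
As GetStatistics() above shows, a Statistics object serializes to a JSON object with a "desc" field plus the raw statistics payload under "statistics". A minimal sketch of that shape using nlohmann::json; the description and payload values here are made up for illustration:

#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  using json = nlohmann::json;
  // Hypothetical payload standing in for the member statistics_.
  json statistics_ = {{"level", {{"key", "image"}, {"count", 10}}}};
  json str_statistics;
  str_statistics["desc"] = "statistics of dataset";  // hypothetical description
  str_statistics["statistics"] = statistics_;
  std::cout << str_statistics.dump() << std::endl;   // same two-key shape GetStatistics() returns
  return 0;
}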
......@@ -59,12 +59,12 @@ uint32_t ShardTask::SizeOfRows() const {
return nRows;
}
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_task_by_id(size_t id) {
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) {
MS_ASSERT(id < task_list_.size());
return task_list_[id];
}
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_random_task() {
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
......@@ -82,7 +82,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
}
for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
for (uint32_t i = 0; i < total_categories; i++) {
res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast<int>(task_no))));
res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
}
}
} else {
......@@ -95,7 +95,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
}
for (uint32_t i = 0; i < total_categories; i++) {
for (uint32_t j = 0; j < maxTasks; j++) {
res.InsertTask(category_tasks[i].get_random_task());
res.InsertTask(category_tasks[i].GetRandomTask());
}
}
}
......
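
GetRandomTask() above draws a uniformly random entry from task_list_ via std::random_device, std::mt19937 and std::uniform_int_distribution. A self-contained sketch of the same selection idiom, with a plain vector of integers standing in for the real task tuples:

#include <iostream>
#include <random>
#include <vector>

int main() {
  std::vector<int> task_list_ = {10, 20, 30, 40};  // stand-in for the real task list
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<> dis(0, static_cast<int>(task_list_.size()) - 1);  // inclusive bounds
  std::cout << task_list_[dis(gen)] << std::endl;  // one task chosen uniformly at random
  return 0;
}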
......@@ -52,7 +52,7 @@ TEST_F(TestShard, TestShardSchemaPart) {
std::shared_ptr<Schema> schema = Schema::Build(desc, j);
ASSERT_TRUE(schema != nullptr);
MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " <<
MS_LOG(INFO) << "schema description: " << schema->GetDesc() << ", schema: " <<
common::SafeCStr(schema->GetSchema().dump());
for (int i = 1; i <= 4; i++) {
string filename = std::string("./imagenet.shard0") + std::to_string(i);
......@@ -71,8 +71,8 @@ TEST_F(TestShard, TestStatisticPart) {
nlohmann::json statistic_json = json::parse(kStatistics[2]);
std::shared_ptr<Statistics> statistics = Statistics::Build(desc, statistic_json);
ASSERT_TRUE(statistics != nullptr);
MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc();
MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump();
MS_LOG(INFO) << "test get_desc(), result: " << statistics->GetDesc();
MS_LOG(INFO) << "test get_statistics, result: " << statistics->GetStatistics().dump();
std::string desc2 = "axis";
nlohmann::json statistic_json2 = R"({})";
......@@ -111,13 +111,13 @@ TEST_F(TestShard, TestShardHeaderPart) {
ASSERT_EQ(res, 0);
header_data.AddStatistic(statistics1);
std::vector<Schema> re_schemas;
for (auto &schema_ptr : header_data.get_schemas()) {
for (auto &schema_ptr : header_data.GetSchemas()) {
re_schemas.push_back(*schema_ptr);
}
ASSERT_EQ(re_schemas, validate_schema);
std::vector<Statistics> re_statistics;
for (auto &statistic : header_data.get_statistics()) {
for (auto &statistic : header_data.GetStatistics()) {
re_statistics.push_back(*statistic);
}
ASSERT_EQ(re_statistics, validate_statistics);
......@@ -129,7 +129,7 @@ TEST_F(TestShard, TestShardHeaderPart) {
std::pair<uint64_t, std::string> pair1(0, "name");
fields.push_back(pair1);
ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS);
std::vector<std::pair<uint64_t, std::string>> resFields = header_data.get_fields();
std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields();
ASSERT_EQ(resFields, fields);
}
......
......@@ -70,7 +70,7 @@ TEST_F(TestShardHeader, AddIndexFields) {
int schema_id1 = header_data.AddSchema(schema1);
int schema_id2 = header_data.AddSchema(schema2);
ASSERT_EQ(schema_id2, -1);
ASSERT_EQ(header_data.get_schemas().size(), 1);
ASSERT_EQ(header_data.GetSchemas().size(), 1);
// check out fields
std::vector<std::pair<uint64_t, std::string>> fields;
......@@ -81,35 +81,35 @@ TEST_F(TestShardHeader, AddIndexFields) {
fields.push_back(index_field2);
MSRStatus res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, SUCCESS);
ASSERT_EQ(header_data.get_fields().size(), 2);
ASSERT_EQ(header_data.GetFields().size(), 2);
fields.clear();
std::pair<uint64_t, std::string> index_field3(schema_id1, "name");
fields.push_back(index_field3);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
ASSERT_EQ(header_data.get_fields().size(), 2);
ASSERT_EQ(header_data.GetFields().size(), 2);
fields.clear();
std::pair<uint64_t, std::string> index_field4(schema_id1, "names");
fields.push_back(index_field4);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
ASSERT_EQ(header_data.get_fields().size(), 2);
ASSERT_EQ(header_data.GetFields().size(), 2);
fields.clear();
std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name");
fields.push_back(index_field5);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
ASSERT_EQ(header_data.get_fields().size(), 2);
ASSERT_EQ(header_data.GetFields().size(), 2);
fields.clear();
std::pair<uint64_t, std::string> index_field6(schema_id1, "label");
fields.push_back(index_field6);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
ASSERT_EQ(header_data.get_fields().size(), 2);
ASSERT_EQ(header_data.GetFields().size(), 2);
std::string desc_new = "this is a test1";
json schemaContent_new = R"({"name": {"type": "string"},
......@@ -121,7 +121,7 @@ TEST_F(TestShardHeader, AddIndexFields) {
mindrecord::ShardHeader header_data_new;
header_data_new.AddSchema(schema_new);
ASSERT_EQ(header_data_new.get_schemas().size(), 1);
ASSERT_EQ(header_data_new.GetSchemas().size(), 1);
// test add fields
std::vector<std::string> single_fields;
......@@ -131,25 +131,25 @@ TEST_F(TestShardHeader, AddIndexFields) {
single_fields.push_back("box");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, FAILED);
ASSERT_EQ(header_data_new.get_fields().size(), 1);
ASSERT_EQ(header_data_new.GetFields().size(), 1);
single_fields.push_back("name");
single_fields.push_back("box");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, FAILED);
ASSERT_EQ(header_data_new.get_fields().size(), 1);
ASSERT_EQ(header_data_new.GetFields().size(), 1);
single_fields.clear();
single_fields.push_back("names");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, FAILED);
ASSERT_EQ(header_data_new.get_fields().size(), 1);
ASSERT_EQ(header_data_new.GetFields().size(), 1);
single_fields.clear();
single_fields.push_back("box");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, SUCCESS);
ASSERT_EQ(header_data_new.get_fields().size(), 2);
ASSERT_EQ(header_data_new.GetFields().size(), 2);
}
} // namespace mindrecord
} // namespace mindspore
......@@ -139,7 +139,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
const int kPar = 2;
std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardSample>(kNum, kDen, kPar));
auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->get_partitions();
auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->GetPartitions();
ASSERT_TRUE(partitions.first == 4);
ASSERT_TRUE(partitions.second == 2);
......
......@@ -57,15 +57,15 @@ TEST_F(TestShardPage, TestBasic) {
Page page =
Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize);
EXPECT_EQ(kGoldenPageId, page.get_page_id());
EXPECT_EQ(kGoldenShardId, page.get_shard_id());
EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
ASSERT_TRUE(kGoldenType == page.get_page_type());
EXPECT_EQ(kGoldenSize, page.get_page_size());
EXPECT_EQ(kGoldenStart, page.get_start_row_id());
EXPECT_EQ(kGoldenEnd, page.get_end_row_id());
ASSERT_TRUE(std::make_pair(4, kOffset) == page.get_last_row_group_id());
ASSERT_TRUE(golden_row_group == page.get_row_group_ids());
EXPECT_EQ(kGoldenPageId, page.GetPageID());
EXPECT_EQ(kGoldenShardId, page.GetShardID());
EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
ASSERT_TRUE(kGoldenType == page.GetPageType());
EXPECT_EQ(kGoldenSize, page.GetPageSize());
EXPECT_EQ(kGoldenStart, page.GetStartRowID());
EXPECT_EQ(kGoldenEnd, page.GetEndRowID());
ASSERT_TRUE(std::make_pair(4, kOffset) == page.GetLastRowGroupID());
ASSERT_TRUE(golden_row_group == page.GetRowGroupIds());
}
TEST_F(TestShardPage, TestSetter) {
......@@ -86,43 +86,43 @@ TEST_F(TestShardPage, TestSetter) {
Page page =
Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize);
EXPECT_EQ(kGoldenPageId, page.get_page_id());
EXPECT_EQ(kGoldenShardId, page.get_shard_id());
EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
ASSERT_TRUE(kGoldenType == page.get_page_type());
EXPECT_EQ(kGoldenSize, page.get_page_size());
EXPECT_EQ(kGoldenStart, page.get_start_row_id());
EXPECT_EQ(kGoldenEnd, page.get_end_row_id());
ASSERT_TRUE(std::make_pair(4, kOffset1) == page.get_last_row_group_id());
ASSERT_TRUE(golden_row_group == page.get_row_group_ids());
EXPECT_EQ(kGoldenPageId, page.GetPageID());
EXPECT_EQ(kGoldenShardId, page.GetShardID());
EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
ASSERT_TRUE(kGoldenType == page.GetPageType());
EXPECT_EQ(kGoldenSize, page.GetPageSize());
EXPECT_EQ(kGoldenStart, page.GetStartRowID());
EXPECT_EQ(kGoldenEnd, page.GetEndRowID());
ASSERT_TRUE(std::make_pair(4, kOffset1) == page.GetLastRowGroupID());
ASSERT_TRUE(golden_row_group == page.GetRowGroupIds());
const int kNewEnd = 33;
const int kNewSize = 300;
std::vector<std::pair<int, uint64_t>> new_row_group = {{0, 100}, {100, 200}, {200, 3000}};
page.set_end_row_id(kNewEnd);
page.set_page_size(kNewSize);
page.set_row_group_ids(new_row_group);
EXPECT_EQ(kGoldenPageId, page.get_page_id());
EXPECT_EQ(kGoldenShardId, page.get_shard_id());
EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
ASSERT_TRUE(kGoldenType == page.get_page_type());
EXPECT_EQ(kNewSize, page.get_page_size());
EXPECT_EQ(kGoldenStart, page.get_start_row_id());
EXPECT_EQ(kNewEnd, page.get_end_row_id());
ASSERT_TRUE(std::make_pair(200, kOffset2) == page.get_last_row_group_id());
ASSERT_TRUE(new_row_group == page.get_row_group_ids());
page.SetEndRowID(kNewEnd);
page.SetPageSize(kNewSize);
page.SetRowGroupIds(new_row_group);
EXPECT_EQ(kGoldenPageId, page.GetPageID());
EXPECT_EQ(kGoldenShardId, page.GetShardID());
EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
ASSERT_TRUE(kGoldenType == page.GetPageType());
EXPECT_EQ(kNewSize, page.GetPageSize());
EXPECT_EQ(kGoldenStart, page.GetStartRowID());
EXPECT_EQ(kNewEnd, page.GetEndRowID());
ASSERT_TRUE(std::make_pair(200, kOffset2) == page.GetLastRowGroupID());
ASSERT_TRUE(new_row_group == page.GetRowGroupIds());
page.DeleteLastGroupId();
EXPECT_EQ(kGoldenPageId, page.get_page_id());
EXPECT_EQ(kGoldenShardId, page.get_shard_id());
EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
ASSERT_TRUE(kGoldenType == page.get_page_type());
EXPECT_EQ(3000, page.get_page_size());
EXPECT_EQ(kGoldenStart, page.get_start_row_id());
EXPECT_EQ(kNewEnd, page.get_end_row_id());
ASSERT_TRUE(std::make_pair(100, kOffset3) == page.get_last_row_group_id());
EXPECT_EQ(kGoldenPageId, page.GetPageID());
EXPECT_EQ(kGoldenShardId, page.GetShardID());
EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
ASSERT_TRUE(kGoldenType == page.GetPageType());
EXPECT_EQ(3000, page.GetPageSize());
EXPECT_EQ(kGoldenStart, page.GetStartRowID());
EXPECT_EQ(kNewEnd, page.GetEndRowID());
ASSERT_TRUE(std::make_pair(100, kOffset3) == page.GetLastRowGroupID());
new_row_group.pop_back();
ASSERT_TRUE(new_row_group == page.get_row_group_ids());
ASSERT_TRUE(new_row_group == page.GetRowGroupIds());
}
TEST_F(TestShardPage, TestJson) {
......
......@@ -107,15 +107,15 @@ TEST_F(TestShardSchema, TestFunction) {
std::shared_ptr<Schema> schema = Schema::Build(desc, schema_content);
ASSERT_NE(schema, nullptr);
ASSERT_EQ(schema->get_desc(), desc);
ASSERT_EQ(schema->GetDesc(), desc);
json schema_json = schema->GetSchema();
ASSERT_EQ(schema_json["desc"], desc);
ASSERT_EQ(schema_json["schema"], schema_content);
ASSERT_EQ(schema->get_schema_id(), -1);
schema->set_schema_id(2);
ASSERT_EQ(schema->get_schema_id(), 2);
ASSERT_EQ(schema->GetSchemaID(), -1);
schema->SetSchemaID(2);
ASSERT_EQ(schema->GetSchemaID(), 2);
}
TEST_F(TestStatistics, StatisticPart) {
......@@ -137,8 +137,8 @@ TEST_F(TestStatistics, StatisticPart) {
ASSERT_NE(statistics, nullptr);
MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc();
MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump();
MS_LOG(INFO) << "test GetDesc(), result: " << statistics->GetDesc();
MS_LOG(INFO) << "test GetStatistics, result: " << statistics->GetStatistics().dump();
statistic_json["test"] = "test";
statistics = Statistics::Build(desc, statistic_json);
......
......@@ -194,8 +194,8 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) {
fw.Open(file_names);
uint64_t header_size = 1 << 14;
uint64_t page_size = 1 << 15;
fw.set_header_size(header_size);
fw.set_page_size(page_size);
fw.SetHeaderSize(header_size);
fw.SetPageSize(page_size);
// set shardHeader
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
......@@ -331,8 +331,8 @@ TEST_F(TestShardWriter, TestShardWriterTrial) {
fw.Open(file_names);
uint64_t header_size = 1 << 14;
uint64_t page_size = 1 << 17;
fw.set_header_size(header_size);
fw.set_page_size(page_size);
fw.SetHeaderSize(header_size);
fw.SetPageSize(page_size);
// set shardHeader
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
......@@ -466,8 +466,8 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) {
fw.Open(file_names);
uint64_t header_size = 1 << 14;
uint64_t page_size = 1 << 17;
fw.set_header_size(header_size);
fw.set_page_size(page_size);
fw.SetHeaderSize(header_size);
fw.SetPageSize(page_size);
// set shardHeader
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
......@@ -567,8 +567,8 @@ TEST_F(TestShardWriter, DataCheck) {
fw.Open(file_names);
uint64_t header_size = 1 << 14;
uint64_t page_size = 1 << 17;
fw.set_header_size(header_size);
fw.set_page_size(page_size);
fw.SetHeaderSize(header_size);
fw.SetPageSize(page_size);
// set shardHeader
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
......@@ -668,8 +668,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) {
fw.Open(file_names);
uint64_t header_size = 1 << 14;
uint64_t page_size = 1 << 17;
fw.set_header_size(header_size);
fw.set_page_size(page_size);
fw.SetHeaderSize(header_size);
fw.SetPageSize(page_size);
// set shardHeader
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
......
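
The writer tests above configure the file layout through the renamed SetHeaderSize/SetPageSize calls with power-of-two constants. For reference, those shifts evaluate to a 16 KB header and 32 KB or 128 KB pages; a trivial check in plain C++, unrelated to ShardWriter itself:

#include <cstdint>
#include <iostream>

int main() {
  const uint64_t header_size = 1 << 14;      // 16384 bytes (16 KB), used by every test above
  const uint64_t page_size_small = 1 << 15;  // 32768 bytes (32 KB), TestShardWriterShiftRawPage
  const uint64_t page_size_large = 1 << 17;  // 131072 bytes (128 KB), the remaining writer tests
  std::cout << header_size << " " << page_size_small << " " << page_size_large << std::endl;
  return 0;
}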