提交 2831641e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1030 [MD] remove dead code in mindrecord

Merge pull request !1030 from liyong126/mindrecord_remove_dead_code
......@@ -280,9 +280,6 @@ class ShardReader {
/// \brief read one row by one task
TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id);
/// \brief get all the column names by schema
vector<std::string> GetAllColumns();
/// \brief get one row from buffer in block-reader mode
std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>> GetRowFromBuffer(int bufId, int rowId);
......@@ -308,7 +305,6 @@ class ShardReader {
uint64_t page_size_; // page size
int shard_count_; // number of shards
std::shared_ptr<ShardHeader> shard_header_; // shard header
bool nlp_ = false; // NLP data
std::vector<sqlite3 *> database_paths_; // sqlite handle list
std::vector<string> file_paths_; // file paths
......
......@@ -90,8 +90,6 @@ class ShardSegment : public ShardReader {
std::string CleanUp(std::string fieldName);
std::tuple<std::vector<uint8_t>, json> GetImageLabel(std::vector<uint8_t> images, json label);
std::pair<MSRStatus, std::vector<uint8_t>> PackImages(int group_id, int shard_id, std::vector<uint64_t> offset);
std::vector<std::string> candidate_category_fields_;
......
......@@ -433,7 +433,6 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
}
ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) {
std::lock_guard<std::mutex> lck(shard_locker_);
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
if (SUCCESS != ret.first) {
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
......@@ -455,7 +454,6 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const
ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
const std::pair<std::string, std::string> &criteria,
const std::vector<std::string> &columns) {
std::lock_guard<std::mutex> lck(shard_locker_);
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
if (SUCCESS != ret.first) {
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
......@@ -532,13 +530,6 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
return res;
}
void ShardReader::CheckNlp() {
nlp_ = false;
return;
}
bool ShardReader::GetNlpFlag() { return nlp_; }
std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
std::vector<std::string> blob_fields;
for (auto &p : GetShardHeader()->GetSchemas()) {
......@@ -547,7 +538,7 @@ std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
blob_fields.assign(fields.begin(), fields.end());
break;
}
return std::make_pair(nlp_ ? kNLP : kCV, blob_fields);
return std::make_pair(kCV, blob_fields);
}
void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns) {
......@@ -828,18 +819,11 @@ MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool loa
if (n_consumer < kMinConsumerCount) {
n_consumer = kMinConsumerCount;
}
CheckNlp();
// dead code
if (nlp_) {
selected_columns_ = selected_columns;
} else {
vector<std::string> blob_fields = GetBlobFields().second;
for (unsigned int i = 0; i < selected_columns.size(); ++i) {
if (!std::any_of(blob_fields.begin(), blob_fields.end(),
[&selected_columns, i](std::string item) { return selected_columns[i] == item; })) {
selected_columns_.push_back(selected_columns[i]);
}
vector<std::string> blob_fields = GetBlobFields().second;
for (unsigned int i = 0; i < selected_columns.size(); ++i) {
if (!std::any_of(blob_fields.begin(), blob_fields.end(),
[&selected_columns, i](std::string item) { return selected_columns[i] == item; })) {
selected_columns_.push_back(selected_columns[i]);
}
}
selected_columns_ = selected_columns;
......@@ -895,7 +879,6 @@ MSRStatus ShardReader::OpenPy(const std::vector<std::string> &file_paths, bool l
if (Open(n_consumer) == FAILED) {
return FAILED;
}
CheckNlp();
// Initialize argument
shard_count_ = static_cast<int>(file_paths_.size());
n_consumer_ = n_consumer;
......@@ -918,10 +901,7 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
interrupt_ = true;
return FAILED;
}
MS_LOG(INFO) << "Launching read threads.";
if (isSimpleReader) return SUCCESS;
// Start provider consumer threads
thread_set_ = std::vector<std::thread>(n_consumer_);
if (n_consumer_ <= 0 || n_consumer_ > kMaxConsumerCount) {
......@@ -940,29 +920,9 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
return SUCCESS;
}
vector<std::string> ShardReader::GetAllColumns() {
vector<std::string> columns;
if (nlp_) {
for (auto &c : selected_columns_) {
for (auto &p : GetShardHeader()->GetSchemas()) {
auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm.
for (auto it = schema.begin(); it != schema.end(); ++it) {
if (it.key() == c) {
columns.push_back(c);
}
}
}
}
} else {
columns = selected_columns_;
}
return columns;
}
MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
vector<std::string> columns = GetAllColumns();
CheckIfColumnInIndex(columns);
CheckIfColumnInIndex(selected_columns_);
for (const auto &rg : row_group_summary) {
auto shard_id = std::get<0>(rg);
auto group_id = std::get<1>(rg);
......@@ -974,9 +934,7 @@ MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int,
MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::shared_ptr<ShardOperator> &op) {
vector<std::string> columns = GetAllColumns();
CheckIfColumnInIndex(columns);
CheckIfColumnInIndex(selected_columns_);
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
auto categories = category_op->GetCategories();
int64_t num_elements = category_op->GetNumElements();
......@@ -1011,7 +969,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
auto shard_id = std::get<0>(rg);
auto group_id = std::get<1>(rg);
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], columns);
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_);
if (SUCCESS != std::get<0>(details)) {
return FAILED;
}
......@@ -1037,10 +995,9 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
vector<std::string> columns = GetAllColumns();
CheckIfColumnInIndex(columns);
CheckIfColumnInIndex(selected_columns_);
auto ret = ReadAllRowGroup(columns);
auto ret = ReadAllRowGroup(selected_columns_);
if (std::get<0>(ret) != SUCCESS) {
return FAILED;
}
......@@ -1202,28 +1159,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
// Deliver batch data to output map
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
if (nlp_) {
// dead code
json blob_fields = json::from_msgpack(images_with_exact_columns);
json merge;
if (selected_columns_.size() > 0) {
for (auto &col : selected_columns_) {
if (blob_fields.find(col) != blob_fields.end()) {
merge[col] = blob_fields[col];
}
}
} else {
merge = blob_fields;
}
auto label_json = std::get<2>(task);
if (label_json != nullptr) {
merge.update(label_json);
}
batch.emplace_back(std::vector<uint8_t>{}, std::move(merge));
} else {
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
}
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
return std::make_pair(SUCCESS, std::move(batch));
}
......
......@@ -296,8 +296,7 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS
if (SUCCESS != ret1.first) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
auto imageLabel = GetImageLabel(ret1.second, labels[i]);
page.emplace_back(std::move(std::get<0>(imageLabel)), std::move(std::get<1>(imageLabel)));
page.emplace_back(std::move(ret1.second), std::move(labels[i]));
}
}
}
......@@ -371,35 +370,7 @@ std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
blob_fields.assign(fields.begin(), fields.end());
break;
}
return std::make_pair(GetNlpFlag() ? kNLP : kCV, blob_fields);
}
std::tuple<std::vector<uint8_t>, json> ShardSegment::GetImageLabel(std::vector<uint8_t> images, json label) {
if (GetNlpFlag()) {
vector<std::string> columns;
for (auto &p : GetShardHeader()->GetSchemas()) {
auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm.
auto schema_items = schema.items();
using it_type = decltype(schema_items.begin());
std::transform(schema_items.begin(), schema_items.end(), std::back_inserter(columns),
[](it_type item) { return item.key(); });
}
json blob_fields = json::from_msgpack(images);
json merge;
if (columns.size() > 0) {
for (auto &col : columns) {
if (blob_fields.find(col) != blob_fields.end()) {
merge[col] = blob_fields[col];
}
}
} else {
merge = blob_fields;
}
merge.update(label);
return std::make_tuple(std::vector<uint8_t>{}, merge);
}
return std::make_tuple(images, label);
return std::make_pair(kCV, blob_fields);
}
std::string ShardSegment::CleanUp(std::string field_name) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册