提交 d4d236bc 编写于 作者: J jonyguo

fix: data error when reading with MindDataset by column_names in some situations

上级 d004ef22
......@@ -165,12 +165,22 @@ Status MindRecordOp::Init() {
// Build the mapping from requested columns to their position in the blob.
//
// columns_blob_       : every blob-backed field known to the shard reader.
// columns_blob_index_ : for each entry of columns_to_load_, the index of that
//                       column inside the (filtered) blob, or -1 if the column
//                       is not blob-backed or not requested.
Status MindRecordOp::SetColumnsBlob() {
  columns_blob_ = shard_reader_->get_blob_fields().second;
  // Keep only the blob fields that are actually requested in columns_to_load_,
  // preserving blob-field order (the order the data is laid out in the blob).
  std::vector<std::string> columns_blob_exact;
  for (auto &blob_field : columns_blob_) {
    for (auto &column : columns_to_load_) {
      if (column.compare(blob_field) == 0) {
        columns_blob_exact.push_back(blob_field);
        break;
      }
    }
  }
  // -1 marks columns that do not come from the blob.
  columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1);
  int32_t iBlob = 0;
  for (auto &blob_exact : columns_blob_exact) {
    columns_blob_index_[column_name_mapping_[blob_exact]] = iBlob++;
  }
  return Status::OK();
}
......
......@@ -294,6 +294,10 @@ class ShardReader {
/// \brief get number of classes
int64_t GetNumClasses(const std::string &file_path, const std::string &category_field);
/// \brief get exactly blob fields data by indices
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
std::vector<uint32_t> &ordered_selected_columns_index);
protected:
uint64_t header_size_; // header size
uint64_t page_size_; // page size
......
......@@ -790,6 +790,8 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
n_consumer = kMinConsumerCount;
}
CheckNlp();
// dead code
if (nlp_) {
selected_columns_ = selected_columns;
} else {
......@@ -801,6 +803,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
}
}
}
selected_columns_ = selected_columns;
if (CheckColumnList(selected_columns_) == FAILED) {
MS_LOG(ERROR) << "Illegal column list";
......@@ -1060,6 +1063,36 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
return SUCCESS;
}
/// \brief Extract from a packed blob only the records belonging to the
///        selected columns.
///
/// The blob is a sequence of records laid out as
/// [8-byte big-endian length][payload]. Records whose ordinal position is
/// listed in ordered_selected_columns_index are copied out verbatim
/// (length header included), preserving their order.
///
/// \param blob_fields_bytes             packed blob bytes for one row
/// \param ordered_selected_columns_index ordinals (in blob order) to keep
/// \return concatenated [length][payload] records of the selected columns
std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
  std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
  std::vector<uint8_t> exactly_blob_fields_bytes;
  // Decode the big-endian uint64 length header starting at byte `pos`.
  auto uint64_from_bytes = [&](uint64_t pos) {
    uint64_t result = 0;
    for (uint64_t n = 0; n < kInt64Len; n++) {
      result = (result << 8) + blob_fields_bytes[pos + n];
    }
    return result;
  };
  uint32_t current_index = 0;
  uint64_t current_offset = 0;
  // Only read a length header when a full one is available. The previous
  // version read the next header at the bottom of the loop, after advancing
  // past the final record, which indexed past the end of blob_fields_bytes
  // (out-of-bounds read / undefined behavior).
  while (current_offset + kInt64Len <= blob_fields_bytes.size()) {
    uint64_t data_len = uint64_from_bytes(current_offset);
    if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
                    [&current_index](uint32_t &index) { return index == current_index; })) {
      exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset,
                                       blob_fields_bytes.begin() + current_offset + kInt64Len + data_len);
    }
    current_index++;
    current_offset += kInt64Len + data_len;
  }
  return exactly_blob_fields_bytes;
}
TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
// All tasks are done
if (task_id >= static_cast<int>(tasks_.Size())) {
......@@ -1077,6 +1110,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
}
const std::shared_ptr<Page> &page = ret.second;
// Pack image list
std::vector<uint8_t> images(addr[1] - addr[0]);
auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0];
......@@ -1096,10 +1130,42 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
}
// extract the exactly blob bytes by selected columns
std::vector<uint8_t> images_with_exact_columns;
if (selected_columns_.size() == 0) {
images_with_exact_columns = images;
} else {
auto blob_fields = get_blob_fields();
std::vector<uint32_t> ordered_selected_columns_index;
uint32_t index = 0;
for (auto &blob_field : blob_fields.second) {
for (auto &field : selected_columns_) {
if (field.compare(blob_field) == 0) {
ordered_selected_columns_index.push_back(index);
break;
}
}
index++;
}
if (ordered_selected_columns_index.size() != 0) {
// extract the images
if (blob_fields.second.size() == 1) {
if (ordered_selected_columns_index.size() == 1) {
images_with_exact_columns = images;
}
} else {
images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index);
}
}
}
// Deliver batch data to output map
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
if (nlp_) {
json blob_fields = json::from_msgpack(images);
// dead code
json blob_fields = json::from_msgpack(images_with_exact_columns);
json merge;
if (selected_columns_.size() > 0) {
......@@ -1117,7 +1183,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
}
batch.emplace_back(std::vector<uint8_t>{}, std::move(merge));
} else {
batch.emplace_back(std::move(images), std::move(std::get<2>(task)));
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
}
return std::make_pair(SUCCESS, std::move(batch));
}
......
......@@ -92,15 +92,25 @@ def populate_data(raw, blob, columns, blob_fields, schema):
if raw:
# remove dummy fields
raw = {k: v for k, v in raw.items() if k in schema}
else:
raw = {}
if not blob_fields:
return raw
# Get the order preserving sequence of columns in blob
ordered_columns = []
if columns:
for blob_field in blob_fields:
if blob_field in columns:
ordered_columns.append(blob_field)
else:
ordered_columns = blob_fields
blob_bytes = bytes(blob)
def _render_raw(field, blob_data):
data_type = schema[field]['type']
data_shape = schema[field]['shape'] if 'shape' in schema[field] else []
if columns and field not in columns:
return
if data_shape:
try:
raw[field] = np.reshape(np.frombuffer(blob_data, dtype=data_type), data_shape)
......@@ -110,7 +120,9 @@ def populate_data(raw, blob, columns, blob_fields, schema):
raw[field] = blob_data
if len(blob_fields) == 1:
_render_raw(blob_fields[0], blob_bytes)
if len(ordered_columns) == 1:
_render_raw(blob_fields[0], blob_bytes)
return raw
return raw
def _int_from_bytes(xbytes: bytes) -> int:
......@@ -125,6 +137,6 @@ def populate_data(raw, blob, columns, blob_fields, schema):
start += 8
return blob_bytes[start : start + n_bytes]
for i, blob_field in enumerate(blob_fields):
for i, blob_field in enumerate(ordered_columns):
_render_raw(blob_field, _blob_at_position(i))
return raw
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册