提交 cd90dfa1 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1834 sorts column names for tfreader during schema creation

Merge pull request !1834 from Peilin/tf-reader-column-order-fix
...@@ -15,14 +15,14 @@ ...@@ -15,14 +15,14 @@
*/ */
#include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/datasetops/source/tf_reader_op.h"
#include <cmath> #include <algorithm>
#include <condition_variable>
#include <future> #include <future>
#include <iomanip> #include <iomanip>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <string>
#include <utility> #include <utility>
#include <unordered_map> #include <vector>
#include "proto/example.pb.h" #include "proto/example.pb.h"
#include "./securec.h" #include "./securec.h"
...@@ -905,7 +905,7 @@ Status TFReaderOp::LoadIntList(const ColDescriptor &current_col, const dataengin ...@@ -905,7 +905,7 @@ Status TFReaderOp::LoadIntList(const ColDescriptor &current_col, const dataengin
return Status::OK(); return Status::OK();
} }
Status TFReaderOp::CreateSchema(const std::string tf_file, const std::vector<std::string> &columns_to_load) { Status TFReaderOp::CreateSchema(const std::string tf_file, std::vector<std::string> columns_to_load) {
std::ifstream reader; std::ifstream reader;
reader.open(tf_file); reader.open(tf_file);
...@@ -926,12 +926,14 @@ Status TFReaderOp::CreateSchema(const std::string tf_file, const std::vector<std ...@@ -926,12 +926,14 @@ Status TFReaderOp::CreateSchema(const std::string tf_file, const std::vector<std
const dataengine::Features &example_features = example.features(); const dataengine::Features &example_features = example.features();
const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature(); const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
std::vector<std::string> columns = columns_to_load;
if (columns_to_load.empty()) if (columns_to_load.empty()) {
(void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns), (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load),
[](const auto &it) -> std::string { return it.first; }); [](const auto &it) -> std::string { return it.first; });
for (const auto &curr_col_name : columns) { std::sort(columns_to_load.begin(), columns_to_load.end());
}
for (const auto &curr_col_name : columns_to_load) {
auto it = feature_map.find(curr_col_name); auto it = feature_map.find(curr_col_name);
if (it == feature_map.end()) { if (it == feature_map.end()) {
RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name);
......
...@@ -335,7 +335,7 @@ class TFReaderOp : public ParallelOp { ...@@ -335,7 +335,7 @@ class TFReaderOp : public ParallelOp {
// Reads one row of data from a tf file and creates a schema based on that row // Reads one row of data from a tf file and creates a schema based on that row
// @return Status - the error code returned. // @return Status - the error code returned.
Status CreateSchema(const std::string tf_file, const std::vector<std::string> &columns_to_load); Status CreateSchema(const std::string tf_file, std::vector<std::string> columns_to_load);
// Meant to be called async. Will read files in the range [begin, end) and return the total rows // Meant to be called async. Will read files in the range [begin, end) and return the total rows
// @param filenames - a list of tf data filenames. // @param filenames - a list of tf data filenames.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册