未验证 提交 5e316620 编写于 作者: PhoenixTree2013's avatar PhoenixTree2013 提交者: GitHub

fix read dataset error for mysql (#551)

上级 bf150d00
......@@ -87,7 +87,6 @@ class DataAlign:
batch_size = 1000
num_iter, reminder = divmod(end_index - start_index, batch_size)
result_list = []
for i in range(num_iter):
sql_fmt = "select {} from {} where {} in {}"
......@@ -97,7 +96,6 @@ class DataAlign:
query_ids = tuple(intersect_ids[inner_start_index:inner_end_index])
sql = sql_fmt.format(
column_str, db_info["tableName"], db_info["index_column"], query_ids)
try:
cursor.execute(sql)
except Exception as e:
......@@ -115,7 +113,6 @@ class DataAlign:
inner_end_index = inner_start_index + reminder
sql = sql_fmt.format(column_str, db_info["tableName"], db_info["index_column"],
tuple(intersect_ids[inner_start_index:inner_end_index]))
try:
cursor.execute(sql)
except Exception as e:
......@@ -186,7 +183,6 @@ class DataAlign:
def generate_new_datast_from_mysql(self, meta_info, query_thread_num):
if not self.has_data_rows(meta_info["psiPath"]):
raise Exception("PSI result is empty, no intersection is found")
db_info = meta_info["localdata_path"]
# Connect to mysql server and create cursor.
try:
......@@ -201,38 +197,45 @@ class DataAlign:
raise e
# Get column name except for id column.
sql_template = ("SELECT column_name, data_type FROM information_schema.COLUMNS "
"WHERE TABLE_NAME='{}' and TABLE_SCHEMA='{}' ORDER BY column_name ASC;")
sql = sql_template.format(db_info["tableName"], db_info["dbName"])
# sql_template = ("SELECT column_name, data_type FROM information_schema.COLUMNS "
# "WHERE TABLE_NAME='{}' and TABLE_SCHEMA='{}' ORDER BY column_name ASC;")
# sql = sql_template.format(db_info["tableName"], db_info["dbName"])
# selected_columns = []
# try:
# cursor.execute(sql)
# except Exception as e:
# logger.error("Run sql 'desc {}' failed.".format(db_info["tableName"]))
# logger.error(e)
# raise e
# else:
# table_columns = []
# for col_info in cursor.fetchall():
# table_columns.append(col_info[0])
# index_column = table_columns[meta_info["index"][0]]
# db_info["index_column"] = index_column
# logger.info("The column corresponds to index {} is {}.".format(
# meta_info["index"], index_column))
# selected_columns = []
# for col_name in table_columns:
# selected_columns.append(col_name)
# for mysql, just support for one selected index
index = meta_info["index"][0]
schema = json.loads(db_info["schema"])
selected_columns = []
try:
cursor.execute(sql)
except Exception as e:
logger.error("Run sql 'desc {}' failed.".format(db_info["tableName"]))
logger.error(e)
raise e
else:
table_columns = []
for col_info in cursor.fetchall():
table_columns.append(col_info[0])
index_column = table_columns[meta_info["index"][0]]
db_info["index_column"] = index_column
logger.info("The column corresponds to index {} is {}.".format(
meta_info["index"], index_column))
selected_columns = []
for col_name in table_columns:
selected_columns.append(col_name)
for field in schema:
for field_name in field:
selected_columns.append(field_name)
index_col_name = selected_columns[index]
db_info["index_column"] = index_col_name
logger.info("Column name of table {} is {}.".format(db_info["tableName"], selected_columns))
selected_column_str = "`{}`".format(selected_columns[0])
for i in range(len(selected_columns) - 1):
new_str = "`{}`".format(selected_columns[i+1])
selected_column_str = selected_column_str + "," + new_str
# Collect all ids that PSI output.
intersect_ids = []
try:
......
......@@ -86,6 +86,7 @@ std::string CSVAccessInfo::toString() {
nlohmann::json js;
js["type"] = "csv";
js["data_path"] = this->file_path_;
js["schema"] = SchemaToJsonString();
ss << js;
return ss.str();
}
......
......@@ -163,6 +163,23 @@ retcode DataSetAccessInfo::SetDatasetSchema(std::vector<FieldType>&& schema_info
this->schema = std::move(schema_info);
return MakeArrowSchema();
}
std::string DataSetAccessInfo::SchemaToJsonString() {
auto schema = this->ArrowSchema();
if (schema == nullptr) {
return std::string("");
}
nlohmann::json js_schema = nlohmann::json::array();
for (int col_index = 0; col_index < schema->num_fields(); ++col_index) {
auto field = schema->field(col_index);
nlohmann::json item;
item[field->name()] = field->type()->id();
js_schema.emplace_back(item);
}
return js_schema.dump();
}
// cursor
std::shared_ptr<arrow::Schema>
Cursor::MakeArrowSchema(const std::vector<FieldType>& data_schema) {
......
......@@ -66,6 +66,7 @@ struct DataSetAccessInfo {
protected:
std::shared_ptr<arrow::DataType> MakeArrowDataType(int type);
retcode MakeArrowSchema();
std::string SchemaToJsonString();
public:
std::vector<FieldType> schema;
......
......@@ -48,6 +48,7 @@ std::string MySQLAccessInfo::toString() {
}
js["query_index"] = std::move(quey_col_info);
}
js["schema"] = SchemaToJsonString();
ss << std::setw(4) << js;
return ss.str();
}
......@@ -260,9 +261,9 @@ retcode MySQLCursor::fetchData(const std::string& query_sql,
int schema_fields = table_schema->num_fields();
auto& all_select_index = this->SelectedColumnIndex();
for (size_t i = 0; i < selected_fields; i++) {
int index = all_select_index[i];
if (index < schema_fields) {
auto& field_ptr = table_schema->field(index);
// int index = all_select_index[i];
if (i < schema_fields) {
auto& field_ptr = table_schema->field(i);
int field_type = field_ptr->type()->id();
VLOG(5) << "field_name: " << field_ptr->name() << " type: " << field_type;
auto array = arrow_wrapper::util::MakeArrowArray(field_type, result_data[i]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册