未验证 提交 127ec826 编写于 作者: X Xiangyu Zhang 提交者: GitHub

The MPC_Compare function may send the value to the incorrect party when...

The MPC_Compare function may send the value to the incorrect party when handling integer values, fix it. (#528)

* Use more user-friendly parameter names for arithmetic components, and ensure that the task context structure is properly populated within the PushTaskReply structure.

* The MPC_Compare function may send the value to the incorrect party when handling integer values, fix it.
上级 bb1c3fa8
......@@ -318,7 +318,7 @@ int MissingProcess::loadParams(primihub::rpc::Task &task) {
LOG(ERROR) << "no data set found for party name: " << this->party_name();
return -1;
}
const auto& dataset = it->second.data();
const auto &dataset = it->second.data();
auto iter = dataset.find("Data_File");
if (iter == dataset.end()) {
LOG(ERROR) << "no dataset found for dataset name Data_File";
......@@ -372,7 +372,7 @@ int MissingProcess::loadParams(primihub::rpc::Task &task) {
}
LOG(INFO) << "New id of new dataset is " << new_dataset_id_ << ". "
<< "new dataset path: " << new_dataset_path_;
<< "new dataset path: " << new_dataset_path_;
return 0;
}
......@@ -448,9 +448,9 @@ int MissingProcess::initPartyComm(void) {
party_id_ = this->party_config_.SelfPartyId();
LOG(INFO) << "local_id_local_id_: " << party_id_;
LOG(INFO) << "next_party: " << next_party_name
<< " detail: " << next_party_info.to_string();
<< " detail: " << next_party_info.to_string();
LOG(INFO) << "prev_party: " << prev_party_name
<< " detail: " << prev_party_info.to_string();
<< " detail: " << prev_party_info.to_string();
return 0;
}
#endif
......@@ -597,10 +597,7 @@ int MissingProcess::execute() {
// Detect string that can't convert into int64_t value.
int ret = 0;
int64_t i64_val = 0;
LOG(WARNING) << "str_array->length() :" << str_array->length()
<< ".";
for (int64_t j = 0; j < str_array->length(); j++) {
if (str_array->IsNull(j)) {
LOG(WARNING) << "Find missing value in column " << iter->first
......@@ -1248,7 +1245,7 @@ int MissingProcess::saveModel(void) {
// int pos = data_file_path_.rfind(delimiter);
// std::string new_path = data_file_path_.substr(0, pos) + "_missing.csv";
auto& new_path = new_dataset_path_;
auto &new_path = new_dataset_path_;
std::shared_ptr<DataDriver> driver =
DataDirverFactory::getDriver("CSV", dataset_service_->getNodeletAddr());
......@@ -1378,7 +1375,7 @@ int MissingProcess::_LoadDatasetFromDB(std::string &source) {
// auto cursor = driver->initCursor(new_path);
auto cursor = driver->read(source);
auto sql_cursor = dynamic_cast<SQLiteCursor*>(cursor.get());
auto sql_cursor = dynamic_cast<SQLiteCursor *>(cursor.get());
if (sql_cursor == nullptr) {
return -1;
}
......
......@@ -58,30 +58,30 @@ private:
void _spiltStr(string str, const string &split, std::vector<string> &strlist);
std::unique_ptr<MPCOperator> mpc_op_exec_;
std::unique_ptr<MPCOperator> mpc_op_exec_{nullptr};
std::string job_id_;
std::string task_id_;
std::string job_id_{""};
std::string task_id_{""};
#ifdef MPC_SOCKET_CHANNEL
IOService ios_;
Session ep_next_;
Session ep_prev_;
std::string next_ip_, prev_ip_;
uint16_t next_port_, prev_port_;
std::string next_ip_{""}, prev_ip_{""};
uint16_t next_port_{0}, prev_port_{0};
#else
ABY3PartyConfig party_config_;
uint16_t local_party_id_;
uint16_t next_party_id_;
uint16_t prev_party_id_;
uint16_t local_party_id_{0};
uint16_t next_party_id_{0};
uint16_t prev_party_id_{0};
primihub::Node local_node_;
std::shared_ptr<network::IChannel> base_channel_next_;
std::shared_ptr<network::IChannel> base_channel_prev_;
std::shared_ptr<network::IChannel> base_channel_next_{nullptr};
std::shared_ptr<network::IChannel> base_channel_prev_{nullptr};
std::shared_ptr<MpcChannel> mpc_channel_next_;
std::shared_ptr<MpcChannel> mpc_channel_prev_;
std::shared_ptr<MpcChannel> mpc_channel_next_{nullptr};
std::shared_ptr<MpcChannel> mpc_channel_prev_{nullptr};
std::map<uint16_t, primihub::Node> partyid_node_map_;
#endif
......@@ -89,20 +89,20 @@ private:
std::map<std::string, uint32_t> col_and_dtype_;
std::vector<std::string> local_col_names;
std::string data_file_path_;
std::string replace_type_;
std::string conn_info_;
std::shared_ptr<arrow::Table> table;
std::string data_file_path_{""};
std::string replace_type_{""};
std::string conn_info_{""};
std::shared_ptr<arrow::Table> table{nullptr};
std::map<std::string, std::vector<int>> db_both_index;
bool use_db;
std::string table_name;
std::string node_id_;
uint32_t party_id_;
bool use_db{false};
std::string table_name{""};
std::string node_id_{""};
uint32_t party_id_{0};
std::string new_dataset_id_;
std::string new_dataset_path_;
std::string platform_type_ = "";
std::string new_dataset_id_{""};
std::string new_dataset_path_{""};
std::string platform_type_{""};
template <class T>
void replaceValue(map<std::string, uint32_t>::iterator &iter,
......@@ -122,9 +122,8 @@ private:
std::shared_ptr<arrow::ChunkedArray> chunk_array =
std::make_shared<arrow::ChunkedArray>(new_array);
bool isDouble = std::is_same<T, double>::value;
std::shared_ptr<arrow::Field> field;
if (!isDouble) {
if (!need_double) {
field = std::make_shared<arrow::Field>(iter->first, arrow::int64());
} else {
field = std::make_shared<arrow::Field>(iter->first, arrow::float64());
......@@ -133,11 +132,6 @@ private:
LOG(INFO) << "Replace column " << iter->first
<< " with new array in table.";
LOG(INFO) << "col_index:" << col_index;
LOG(INFO) << "name:" << field->name();
LOG(INFO) << "type:" << field->type();
LOG(INFO) << "table->type:" << table->field(col_index)->type();
auto result = table->SetColumn(col_index, field, chunk_array);
if (!result.ok()) {
std::stringstream ss;
......
......@@ -341,7 +341,7 @@ void MPCOperator::MPC_Compare(i64Matrix &m, sbMatrix &sh_res) {
if (partyIdx == (i + 1) % 3)
mPrev->recv(shape);
else if (partyIdx == (i + 2) % 3)
mPrev->recv(shape);
mNext->recv(shape);
else
throw std::runtime_error("Message recv logic error.");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册