未验证 提交 5551af2d 编写于 作者: X Xiangyu Zhang 提交者: GitHub

Fix segment fault in arithmetic component. (#483)

* Working on debug segment fault error in MPC ExpressExecutor.

* Added support: Column names can be strings of any length

* modified params name of cmp

* Replace the party ID with the party name for ease of use.

* Remove unnecessary checks and debug output.
上级 6eee25c1
#include "src/primihub/algorithm/arithmetic.h"
#include <arrow/api.h>
#include <arrow/array.h>
#include <arrow/result.h>
#include "src/primihub/algorithm/arithmetic.h"
#include "src/primihub/data_store/csv/csv_driver.h"
#include "src/primihub/data_store/factory.h"
using arrow::Array;
using arrow::DoubleArray;
using arrow::Int64Array;
using arrow::Table;
namespace primihub {
void spiltStr(string str, const string &split, std::vector<string> &strlist) {
strlist.clear();
......@@ -33,8 +34,7 @@ ArithmeticExecutor<Dbit>::ArithmeticExecutor(
: AlgorithmBase(dataset_service) {
this->algorithm_name_ = "arithmetic";
auto& node_map = config.node_map;
// LOG(INFO) << node_map.size();
auto &node_map = config.node_map;
std::map<uint16_t, rpc::Node> party_id_node_map;
for (auto iter = node_map.begin(); iter != node_map.end(); iter++) {
rpc::Node &node = iter->second;
......@@ -66,12 +66,6 @@ ArithmeticExecutor<Dbit>::ArithmeticExecutor(
// A local server addr.
uint16_t port = node.vm(0).next().port();
// next_addr_ = std::make_pair(node.ip(), port);
// // A remote server addr.
// prev_addr_ =
// std::make_pair(node.vm(0).prev().ip(), node.vm(0).prev().port());
next_ip_ = node.ip();
next_port_ = port;
......@@ -80,12 +74,6 @@ ArithmeticExecutor<Dbit>::ArithmeticExecutor(
} else {
rpc::Node &node = party_id_node_map[2];
// Two remote server addr.
// next_addr_ =
// std::make_pair(node.vm(0).next().ip(), node.vm(0).next().port());
// prev_addr_ =
// std::make_pair(node.vm(0).prev().ip(), node.vm(0).prev().port());
next_ip_ = node.vm(0).next().ip();
next_port_ = node.vm(0).next().port();
......@@ -94,23 +82,78 @@ ArithmeticExecutor<Dbit>::ArithmeticExecutor(
}
}
template <Decimal Dbit>
int ArithmeticExecutor<Dbit>::_getPartyIdWithPartyName(
const std::string &party_name, uint16_t &party_id) {
auto iter = party_info_.find(party_name);
if (iter == party_info_.end())
return -1;
party_id = iter->second;
return 0;
}
template <Decimal Dbit>
void ArithmeticExecutor<Dbit>::_fillPartyNameAndPartyId(
const primihub::rpc::Task &task) {
const auto &all_info = task.party_access_info();
for (const auto &party_info : all_info) {
if (party_info.first == SCHEDULER_NODE)
continue;
uint16_t party_id = party_info.second.vm(0).party_id();
std::string party_name = party_info.first;
party_info_.insert(std::make_pair(party_info.first, party_id));
LOG(INFO) << "Party " << party_name << " has party id " << party_id << ".";
}
}
template <Decimal Dbit>
int ArithmeticExecutor<Dbit>::loadParams(primihub::rpc::Task &task) {
auto param_map = task.params().param_map();
try {
data_file_path_ = param_map["Data_File"].value_string();
// col_and_owner
{
std::string party_name = task.party_name();
auto &datasets = task.party_datasets();
auto dataset_iter = datasets.find(party_name);
if (dataset_iter == datasets.end()) {
std::stringstream ss;
ss << "Can't get dataset id with party name " << party_name << ".";
throw std::runtime_error(ss.str());
}
auto &data = dataset_iter->second.data();
auto data_iter = data.begin();
dataset_id_ = data_iter->second;
LOG(INFO) << "Dataset id of party " << party_name << " is " << dataset_id_
<< ".";
}
_fillPartyNameAndPartyId(task);
std::string col_and_owner = param_map["Col_And_Owner"].value_string();
std::vector<string> tmp1, tmp2, tmp3;
spiltStr(col_and_owner, ";", tmp1);
int ret = 0;
for (auto itr = tmp1.begin(); itr != tmp1.end(); itr++) {
int pos = itr->find('-');
std::string col = itr->substr(0, pos);
int owner = std::atoi((itr->substr(pos + 1, itr->size())).c_str());
col_and_owner_.insert(make_pair(col, owner));
// LOG(INFO) << col << ":" << owner;
std::string party_name = itr->substr(pos + 1, itr->size());
uint16_t party_id = 0;
ret = _getPartyIdWithPartyName(party_name, party_id);
if (ret) {
std::stringstream ss;
ss << "Get party id with party name " << party_name << " failed.";
LOG(ERROR) << ss.str();
throw std::runtime_error(ss.str());
}
col_and_owner_.insert(make_pair(col, party_id));
LOG(INFO) << "Column " << col << " belong to party " << party_name << ".";
}
// LOG(INFO) << col_and_owner;
std::string col_and_dtype = param_map["Col_And_Dtype"].value_string();
spiltStr(col_and_dtype, ";", tmp2);
......@@ -119,9 +162,8 @@ int ArithmeticExecutor<Dbit>::loadParams(primihub::rpc::Task &task) {
std::string col = itr->substr(0, pos);
int dtype = std::atoi((itr->substr(pos + 1, itr->size())).c_str());
col_and_dtype_.insert(make_pair(col, dtype));
// LOG(INFO) << col << ":" << dtype;
LOG(INFO) << "Dtype of column " << col << " is " << dtype << ".";
}
// LOG(INFO) << col_and_dtype;
expr_ = param_map["Expr"].value_string();
is_cmp = false;
......@@ -144,15 +186,23 @@ int ArithmeticExecutor<Dbit>::loadParams(primihub::rpc::Task &task) {
} else {
mpc_exec_ = new MPCExpressExecutor<Dbit>();
}
// LOG(INFO) << expr_;
std::string parties = param_map["Parties"].value_string();
spiltStr(parties, ";", tmp3);
for (auto itr = tmp3.begin(); itr != tmp3.end(); itr++) {
uint32_t party = std::atoi((*itr).c_str());
parties_.push_back(party);
// LOG(INFO) << party;
uint16_t party_id = 0;
ret = _getPartyIdWithPartyName(*itr, party_id);
if (ret) {
std::stringstream ss;
ss << "Can't get party id with party name " << *itr << ".";
LOG(ERROR) << ss.str();
throw std::runtime_error(ss.str());
}
parties_.emplace_back(party_id);
LOG(INFO) << "Reveal result to party " << *itr << ", party id "
<< party_id << ".";
}
// LOG(INFO) << parties;
res_name_ = param_map["ResFileName"].value_string();
} catch (std::exception &e) {
......@@ -164,8 +214,7 @@ int ArithmeticExecutor<Dbit>::loadParams(primihub::rpc::Task &task) {
}
template <Decimal Dbit> int ArithmeticExecutor<Dbit>::loadDataset() {
int ret = _LoadDatasetFromCSV(data_file_path_);
// file reading error or file empty
int ret = _LoadDatasetFromCSV(dataset_id_);
if (ret <= 0) {
LOG(ERROR) << "Load dataset for train failed.";
return -1;
......@@ -174,6 +223,7 @@ template <Decimal Dbit> int ArithmeticExecutor<Dbit>::loadDataset() {
if (is_cmp) {
return 0;
}
mpc_exec_->initColumnConfig(party_id_);
for (auto &pair : col_and_owner_)
mpc_exec_->importColumnOwner(pair.first, pair.second);
......@@ -208,22 +258,28 @@ template <Decimal Dbit> int ArithmeticExecutor<Dbit>::execute() {
try {
sbMatrix sh_res;
f64Matrix<Dbit> m;
if (col_and_owner_[expr_.substr(4, 1)] == party_id_) {
m.resize(1, col_and_val_double[expr_.substr(4, 1)].size());
for (size_t i = 0; i < col_and_val_double[expr_.substr(4, 1)].size();
i++)
m(i) = col_and_val_double[expr_.substr(4, 1)][i];
// CMP(col0,col1)
int pos = expr_.find(',');
int pos_end = expr_.find(')');
std::string cmp_par_1 = expr_.substr(4, pos - 4);
LOG(INFO) << "cmp_par_1: " << cmp_par_1;
std::string cmp_par_2 = expr_.substr(pos + 1, pos_end - pos - 1);
LOG(INFO) << "cmp_par_2: " << cmp_par_2;
if (col_and_owner_[cmp_par_1] == party_id_) {
m.resize(1, col_and_val_double[cmp_par_1].size());
for (size_t i = 0; i < col_and_val_double[cmp_par_1].size(); i++)
m(i) = col_and_val_double[cmp_par_1][i];
mpc_op_exec_->MPC_Compare(m, sh_res);
} else if (col_and_owner_[expr_.substr(6, 1)] == party_id_) {
m.resize(1, col_and_val_double[expr_.substr(6, 1)].size());
for (size_t i = 0; i < col_and_val_double[expr_.substr(6, 1)].size();
i++)
m(i) = col_and_val_double[expr_.substr(6, 1)][i];
} else if (col_and_owner_[cmp_par_2] == party_id_) {
m.resize(1, col_and_val_double[cmp_par_2].size());
for (size_t i = 0; i < col_and_val_double[cmp_par_2].size(); i++)
m(i) = col_and_val_double[cmp_par_2][i];
mpc_op_exec_->MPC_Compare(m, sh_res);
} else
mpc_op_exec_->MPC_Compare(sh_res);
// reveal
for (const auto& party : parties_) {
for (const auto &party : parties_) {
if (party_id_ == party) {
i64Matrix tmp = mpc_op_exec_->reveal(sh_res);
for (size_t i = 0; i < tmp.rows(); i++)
......@@ -234,28 +290,21 @@ template <Decimal Dbit> int ArithmeticExecutor<Dbit>::execute() {
}
} catch (std::exception &e) {
LOG(ERROR) << "In party " << party_id_ << ":\n" << e.what() << ".";
return -1;
}
return 0;
}
try {
mpc_exec_->runMPCEvaluate();
if (mpc_exec_->isFP64RunMode()) {
mpc_exec_->revealMPCResult(parties_, final_val_double_);
// for (auto itr = final_val_double_.begin(); itr !=
// final_val_double_.end();
// itr++)
// LOG(INFO) << *itr;
} else {
mpc_exec_->revealMPCResult(parties_, final_val_int64_);
// for (auto itr = final_val_int64_.begin(); itr !=
// final_val_int64_.end();
// itr++)
// LOG(INFO) << *itr;
}
} catch (const std::exception &e) {
std::string msg = "In party 0, ";
msg = msg + e.what();
throw std::runtime_error(msg);
throw std::runtime_error(e.what());
}
return 0;
}
......@@ -317,7 +366,8 @@ template <Decimal Dbit> int ArithmeticExecutor<Dbit>::saveModel(void) {
std::shared_ptr<CSVDriver> csv_driver =
std::dynamic_pointer_cast<CSVDriver>(driver);
std::string filepath = "data/" + res_name_ + ".csv";
std::string filepath = res_name_;
int ret = 0;
if (col_and_val_double.size() != 0)
ret = csv_driver->write(table, filepath);
......@@ -334,20 +384,32 @@ template <Decimal Dbit> int ArithmeticExecutor<Dbit>::saveModel(void) {
template <Decimal Dbit>
int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
std::string nodeaddr("test address"); // TODO
std::shared_ptr<DataDriver> driver =
DataDirverFactory::getDriver("CSV", nodeaddr);
auto cursor = driver->read(filename);
auto ds = cursor->read();
auto driver = this->datasetService()->getDriver(dataset_id_);
if (driver == nullptr) {
LOG(ERROR) << "get dataset read driver for dataset id: " << dataset_id_
<< " failed";
return -1;
}
auto cursor = driver->read();
if (cursor == nullptr) {
LOG(ERROR) << "get read cursor for dataset id: " << dataset_id_
<< " failed";
return -1;
}
std::shared_ptr<Dataset> ds = cursor->read();
if (ds == nullptr) {
LOG(ERROR) << "get data for dataset failed";
return -1;
}
std::shared_ptr<Table> table = std::get<std::shared_ptr<Table>>(ds->data);
// Label column.
std::vector<std::string> col_names = table->ColumnNames();
// for (auto itr = col_names.begin(); itr != col_names.end(); itr++) {
// LOG(INFO) << *itr;
// }
bool errors = false;
int num_col = table->num_columns();
// 'array' include values in a column of csv file.
int chunk_num = table->column(num_col - 1)->chunks().size();
int64_t array_len = 0;
......@@ -359,10 +421,12 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
LOG(INFO) << "Label column '" << col_names[num_col - 1] << "' has "
<< array_len << " values.";
// Force the same value count in every column.
for (int i = 0; i < num_col; i++) {
int chunk_num = table->column(i)->chunks().size();
if (col_and_dtype_.count(col_names[i]) != 1)
continue;
if (col_and_dtype_[col_names[i]] == 0) {
if (table->schema()->GetFieldByName(col_names[i])->type()->id() != 9) {
LOG(ERROR) << "Local data type is inconsistent with the demand data "
......@@ -370,6 +434,7 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
"double!Please input consistent data type!";
return -1;
}
std::vector<int64_t> tmp_data;
int64_t tmp_len = 0;
for (int k = 0; k < chunk_num; k++) {
......@@ -378,7 +443,6 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
tmp_len += array->length();
for (int64_t j = 0; j < array->length(); j++) {
tmp_data.push_back(array->Value(j));
// LOG(INFO) << array->Value(j);
}
}
if (tmp_len != array_len) {
......@@ -389,13 +453,6 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
}
col_and_val_int.insert(
pair<string, std::vector<int64_t>>(col_names[i], tmp_data));
// for (auto itr = col_and_val_int.begin(); itr != col_and_val_int.end();
// itr++) {
// LOG(INFO) << itr->first;
// auto tmp_vec = itr->second;
// for (auto iter = tmp_vec.begin(); iter != tmp_vec.end(); iter++)
// LOG(INFO) << *iter;
// }
} else {
std::vector<double> tmp_data;
int64_t tmp_len = 0;
......@@ -406,7 +463,6 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
tmp_len += array->length();
for (int64_t j = 0; j < array->length(); j++) {
tmp_data.push_back(array->Value(j));
// LOG(INFO) << array->Value(j);
}
}
if (tmp_len != array_len) {
......@@ -423,7 +479,6 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
tmp_len += array->length();
for (int64_t j = 0; j < array->length(); j++) {
tmp_data.push_back(array->Value(j));
// LOG(INFO) << array->Value(j);
}
}
if (tmp_len != array_len) {
......@@ -436,13 +491,6 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
}
col_and_val_double.insert(
pair<string, std::vector<double>>(col_names[i], tmp_data));
// for (auto itr = col_and_val_double.begin();
// itr != col_and_val_double.end(); itr++) {
// LOG(INFO) << itr->first;
// auto tmp_vec = itr->second;
// for (auto iter = tmp_vec.begin(); iter != tmp_vec.end(); iter++)
// LOG(INFO) << *iter;
// }
}
}
if (errors)
......@@ -450,6 +498,7 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
return array_len;
}
template class ArithmeticExecutor<D32>;
template class ArithmeticExecutor<D16>;
......
......@@ -17,9 +17,8 @@
#include "src/primihub/data_store/driver.h"
namespace primihub {
template <Decimal Dbit>
class ArithmeticExecutor : public AlgorithmBase {
template <Decimal Dbit> class ArithmeticExecutor : public AlgorithmBase {
public:
explicit ArithmeticExecutor(PartyConfig &config,
std::shared_ptr<DatasetService> dataset_service);
......@@ -31,34 +30,40 @@ public:
int saveModel(void);
private:
// int _ConstructShares(sf64Matrix<D> &w, sf64Matrix<D> &train_data,
// sf64Matrix<D> &train_label, sf64Matrix<D>
// &test_data, sf64Matrix<D> &test_label);
int _LoadDatasetFromCSV(std::string &filename);
int _getPartyIdWithPartyName(const std::string &party_name,
uint16_t &party_id);
void _fillPartyNameAndPartyId(const primihub::rpc::Task &task);
bool is_cmp;
MPCExpressExecutor<Dbit> *mpc_exec_;
MPCOperator *mpc_op_exec_;
std::string res_name_;
uint16_t local_id_;
std::pair<std::string, uint16_t> next_addr_;
std::pair<std::string, uint16_t> prev_addr_;
std::map<std::string, uint16_t> party_info_;
std::vector<uint32_t> parties_;
Session ep_next_;
Session ep_prev_;
IOService ios_;
std::pair<std::string, uint16_t> next_addr_;
std::pair<std::string, uint16_t> prev_addr_;
std::string next_ip_, prev_ip_;
uint16_t next_port_, prev_port_;
std::string data_file_path_;
std::map<std::string, u32> col_and_owner_;
std::map<std::string, bool> col_and_dtype_;
std::vector<uint32_t> parties_;
std::string dataset_id_;
uint32_t party_id_;
MPCExpressExecutor<Dbit> *mpc_exec_;
MPCOperator *mpc_op_exec_;
std::vector<double> final_val_double_;
std::vector<int64_t> final_val_int64_;
std::vector<bool> cmp_res_;
std::string expr_;
std::map<std::string, u32> col_and_owner_;
std::map<std::string, bool> col_and_dtype_;
std::map<std::string, std::vector<double>> col_and_val_double;
std::map<std::string, std::vector<int64_t>> col_and_val_int;
};
} // namespace primihub
\ No newline at end of file
} // namespace primihub
......@@ -58,6 +58,7 @@ std::shared_ptr<arrow::Table> ReadCSVFile(const std::string& file_path,
<< "detail: " << maybe_table.status();
return nullptr;
}
return *maybe_table;
}
} // namespace csv
......@@ -340,12 +341,15 @@ std::unique_ptr<Cursor> CSVDriver::read() {
if (ret != retcode::SUCCESS) {
return nullptr;
}
read_options.skip_rows = 1; // skip title row
auto arrow_data = csv::ReadCSVFile(csv_access_info->file_path_, read_options,
parse_options, convert_options);
if (arrow_data == nullptr) {
return nullptr;
}
std::vector<FieldType> fileds;
auto arrow_fileds = arrow_data->schema()->fields();
for (const auto& field : arrow_fileds) {
......
......@@ -1287,6 +1287,8 @@ void MPCExpressExecutor<Dbit>::revealMPCResult(std::vector<uint32_t> &parties,
<< ".";
} else {
mpc_op_->reveal(*p_final_share, party);
LOG(INFO) << "Reveal MPC result to party "
<< static_cast<char>(party + '0') << ".";
}
}
......@@ -1316,6 +1318,8 @@ void MPCExpressExecutor<Dbit>::revealMPCResult(std::vector<uint32_t> &parties,
<< ".";
} else {
mpc_op_->reveal(*p_final_share, party);
LOG(INFO) << "Reveal MPC result to party "
<< static_cast<char>(party + '0') << ".";
}
}
......
......@@ -160,12 +160,16 @@ namespace primihub::task
LOG(ERROR) << "Algorithm is not initialized";
return -1;
}
// algorithm_->set_task_info(platform(),job_id(),task_id());
algorithm_->loadParams(task_param_);
int ret = 0;
do
{
ret = algorithm_->loadParams(task_param_);
if (ret) {
LOG(ERROR) << "Load params failed.";
break;
}
ret = algorithm_->loadDataset();
if (ret)
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册