未验证 提交 47e65456 编写于 作者: PhoenixTree2013's avatar PhoenixTree2013 提交者: GitHub

pir refactor (#597)

上级 405e10f6
......@@ -59,6 +59,22 @@ enum class Visibility {
PRIVATE = 1,
};
class RoleValidation {
public:
static bool IsClient(const std::string& party_name) {
return party_name == PARTY_CLIENT;
}
static bool IsServer(const std::string& party_name) {
return party_name == PARTY_SERVER;
}
static bool IsTeeCompute(const std::string& party_name) {
return party_name == PARTY_TEE_COMPUTE;
}
};
struct Node {
Node() = default;
Node(const std::string& id, const std::string& ip,
......
package(default_visibility = ["//visibility:public"])
cc_library(
name = "common_def",
hdrs = ["common.h"],
)
\ No newline at end of file
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_COMMON_H_
#define SRC_PRIMIHUB_KERNEL_PIR_COMMON_H_
#include <unordered_map>
#include <vector>
#include <string>
namespace primihub::pir {
using PirDataType = std::unordered_map<std::string, std::vector<std::string>>;
enum class PirType {
ID_PIR = 0,
KEY_PIR,
};
} // namespace primihub::pir
#endif // SRC_PRIMIHUB_KERNEL_PIR_COMMON_H_
package(default_visibility = ["//visibility:public"])
cc_library(
name = "factory",
hdrs = ["factory.h"],
deps = [
"//src/primihub/kernel/pir:common_def",
":base_pir_operator",
":keyword_pir_operator",
]
)
cc_library(
name = "base_pir_operator",
hdrs = ["base_pir.h"],
srcs = ["base_pir.cc"],
deps = [
"//src/primihub/kernel/pir:common_def",
"//src/primihub/util:endian_util",
"//src/primihub/util:util_lib",
"//src/primihub/common:common_defination",
"//src/primihub/util/network:communication_lib",
],
)
cc_library(
name = "keyword_pir_operator",
hdrs = ["keyword_pir.h"],
srcs = ["keyword_pir.cc"],
copts = [
"-w",
"-D_ASPI",
],
deps = [
":base_pir_operator",
"//src/primihub/util:endian_util",
"//src/primihub/util:util_lib",
"//src/primihub/protos:worker_proto",
"@mircrosoft_apsi//:APSI",
]
)
// "Copyright [2023] <PrimiHub>"
#include "src/primihub/kernel/pir/operator/base_pir.h"
namespace primihub::pir {
retcode BasePirOperator::Execute(const PirDataType& input,
PirDataType* result) {
return OnExecute(input, result);
}
} // namespace primihub::pir
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_BASE_PIR_H_
#define SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_BASE_PIR_H_
#include <map>
#include <string>
#include "src/primihub/common/common.h"
#include "src/primihub/kernel/pir/common.h"
#include "src/primihub/util/network/link_context.h"
namespace primihub::pir {
using LinkContext = network::LinkContext;
struct Options {
LinkContext* link_ctx_ref;
std::map<std::string, Node> party_info;
std::string self_party;
std::string code;
// online
bool use_cache{false};
// offline task
bool generate_db{false};
std::string db_path;
Node peer_node;
};
class BasePirOperator {
public:
explicit BasePirOperator(const Options& options) : options_(options) {}
virtual ~BasePirOperator() = default;
/**
* PSI protocol
*/
retcode Execute(const PirDataType& input, PirDataType* result);
virtual retcode OnExecute(const PirDataType& input, PirDataType* result) = 0;
void set_stop() {stop_.store(true);}
protected:
bool has_stopped() {
return stop_.load(std::memory_order::memory_order_relaxed);
}
std::string PartyName() {return options_.self_party;}
LinkContext* GetLinkContext() {return options_.link_ctx_ref;}
Node& PeerNode() {return options_.peer_node;}
protected:
std::atomic<bool> stop_{false};
Options options_;
std::string key_{"pir_key"};
};
} // namespace primihub::pir
#endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_BASE_PIR_H_
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_FACTORY_H_
#define SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_FACTORY_H_
#include <glog/logging.h>
#include <memory>
#include "src/primihub/kernel/pir/common.h"
#include "src/primihub/kernel/pir/operator/keyword_pir.h"
namespace primihub::pir {
class Factory {
public:
static std::unique_ptr<BasePirOperator> Create(PirType pir_type,
const Options& options) {
std::unique_ptr<BasePirOperator> operator_ptr{nullptr};
switch (pir_type) {
case PirType::ID_PIR:
LOG(ERROR) << "Unimplement";
break;
case PirType::KEY_PIR:
operator_ptr = std::make_unique<KeywordPirOperator>(options);
break;
default:
LOG(ERROR) << "unknown pir operator: " << static_cast<int>(pir_type);
break;
}
return operator_ptr;
}
};
} // namespace primihub::pir
#endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_FACTORY_H_
// "Copyright [2023] <PrimiHub>"
#include "src/primihub/kernel/pir/operator/id_pir.h"
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_ID_PIR_H_
#define SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_ID_PIR_H_
#endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_ID_PIR_H_
此差异已折叠。
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_KEYWORD_PIR_H_
#define SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_KEYWORD_PIR_H_
#include <variant>
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include "src/primihub/kernel/pir/operator/base_pir.h"
#include "src/primihub/kernel/pir/common.h"
// APSI
#include "apsi/thread_pool_mgr.h"
#include "apsi/sender_db.h"
#include "apsi/oprf/oprf_sender.h"
#include "apsi/powers.h"
#include "apsi/util/common_utils.h"
#include "apsi/sender.h"
#include "apsi/bin_bundle.h"
#include "apsi/item.h"
#include "apsi/receiver.h"
// SEAL
#include "seal/context.h"
#include "seal/modulus.h"
#include "seal/util/common.h"
#include "seal/util/defines.h"
using namespace apsi; // NOLINT
using namespace apsi::sender; // NOLINT
using namespace apsi::oprf; // NOLINT
using namespace apsi::network; // NOLINT
using namespace seal; // NOLINT
using namespace seal::util; // NOLINT
namespace primihub::pir {
using UnlabeledData = std::vector<apsi::Item>;
using LabeledData = std::vector<std::pair<apsi::Item, apsi::Label>>;
using DBData = std::variant<UnlabeledData, LabeledData>;
class KeywordPirOperator : public BasePirOperator {
public:
enum class RequestType : uint8_t {
PsiParam = 0,
Oprf,
Query,
};
explicit KeywordPirOperator(const Options& options) :
BasePirOperator(options) {}
retcode OnExecute(const PirDataType& input, PirDataType* result) override;
protected:
retcode ExecuteAsClient(const PirDataType& input, PirDataType* result);
retcode ExecuteAsServer(const PirDataType& input);
protected:
// ------------------------Receiver----------------------------
/**
* Performs a parameter request from sender
*/
retcode RequestPSIParams();
/**
* Performs an OPRF request on a vector of items and returns a
vector of OPRF hashed items of the same size as the input vector.
*/
retcode RequestOprf(const std::vector<Item>& items,
std::vector<apsi::HashedItem>*, std::vector<apsi::LabelKey>*);
/**
* Performs a labeled PSI query. The query is a vector of items,
* and the result is a same-size vector of MatchRecord objects.
* If an item is in the intersection,
* the corresponding MatchRecord indicates it in the `found` field,
* and the `label` field may contain the corresponding
* label if a sender's data included it.
*/
retcode RequestQuery();
retcode ExtractResult(const std::vector<std::string>& orig_vec,
const std::vector<apsi::receiver::MatchRecord>& query_result,
PirDataType* result);
protected:
// ------------------------Sender----------------------------
std::unique_ptr<apsi::PSIParams> SetPsiParams();
/**
* process a Get Parameters request to the Sender.
*/
retcode ProcessPSIParams();
/**
process an OPRF query request to the Sender.
*/
retcode ProcessOprf();
/**
process a Query request to the Sender.
*/
retcode ProcessQuery(std::shared_ptr<apsi::sender::SenderDB> sender_db);
retcode ComputePowers(const shared_ptr<apsi::sender::SenderDB> &sender_db,
const apsi::CryptoContext &crypto_context,
std::vector<apsi::sender::CiphertextPowers> &all_powers,
const apsi::PowersDag &pd,
uint32_t bundle_idx,
seal::MemoryPoolHandle &pool);
auto ProcessBinBundleCache(
const shared_ptr<apsi::sender::SenderDB> &sender_db,
const apsi::CryptoContext &crypto_context,
reference_wrapper<const apsi::sender::BinBundleCache> cache,
std::vector<apsi::sender::CiphertextPowers> &all_powers,
uint32_t bundle_idx,
compr_mode_type compr_mode,
seal::MemoryPoolHandle &pool) ->
std::unique_ptr<apsi::network::ResultPackage>;
std::unique_ptr<DBData> CreateDb(const PirDataType& input);
retcode CreateDbDataCache(const DBData& db_data,
std::unique_ptr<apsi::PSIParams> psi_params,
apsi::oprf::OPRFKey &oprf_key,
size_t nonce_byte_count,
bool compress);
auto CreateSenderDb(const DBData &db_data,
std::unique_ptr<PSIParams> psi_params,
apsi::oprf::OPRFKey &oprf_key,
size_t nonce_byte_count,
bool compress) -> std::shared_ptr<SenderDB>;
bool DbCacheAvailable(const std::string& db_path);
std::shared_ptr<apsi::sender::SenderDB>
LoadDbFromCache(const std::string& db_path);
private:
std::string psi_params_str_;
std::unique_ptr<apsi::oprf::OPRFKey> oprf_key_{nullptr};
std::unique_ptr<apsi::receiver::Receiver> receiver_{nullptr};
std::unique_ptr<apsi::PSIParams> psi_params_{nullptr};
};
} // namespace primihub::pir
#endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_KEYWORD_PIR_H_
......@@ -90,10 +90,6 @@ retcode Worker::execute(const PushTaskRequest *pushTaskRequest) {
void Worker::kill_task() {
if (task_ptr) {
task_ptr->kill_task();
return;
}
if (task_server_ptr) {
task_server_ptr->kill_task();
}
}
......
......@@ -39,7 +39,6 @@
#include "src/primihub/node/nodelet.h"
#include "src/primihub/protos/worker.pb.h"
#include "src/primihub/task/semantic/task.h"
#include "src/primihub/task/semantic/private_server_base.h"
#include "src/primihub/common/common.h"
using primihub::rpc::PushTaskRequest;
......@@ -71,9 +70,6 @@ class Worker {
std::shared_ptr<primihub::task::TaskBase> getTask() {
return task_ptr;
}
std::shared_ptr<primihub::task::ServerTaskBase> getServerTask() {
return task_server_ptr;
}
retcode waitForTaskReady();
// scheduler method
......@@ -89,7 +85,6 @@ class Worker {
mutable absl::Mutex worker_map_mutex_;
std::shared_ptr<primihub::task::TaskBase> task_ptr{nullptr};
std::shared_ptr<primihub::task::ServerTaskBase> task_server_ptr{nullptr};
const std::string& node_id;
std::shared_ptr<Nodelet> nodelet;
std::string worker_id_;
......
......@@ -29,7 +29,6 @@ cc_library(
":pir_task",
":psi_task",
":tee_task",
":private_server_base",
],
)
cc_library(
......@@ -116,65 +115,14 @@ cc_library(
# pir task
cc_library(
name = "pir_task",
deps = [
":keyword_pir_task",
],
# deps = select({
# "microsoft-apsi" : [":keyword_pir_task"],
# "//conditions:default": [":id_pir_task"],
# }),
)
cc_library(
name = "keyword_pir_task",
hdrs = [
"keyword_pir_client_task.h",
"keyword_pir_server_task.h",
],
srcs = [
"keyword_pir_client_task.cc",
"keyword_pir_server_task.cc",
],
copts = [
"-w",
"-D_ASPI",
],
defines = ["USE_MICROSOFT_APSI"],
deps = [
":task_interface",
"//src/primihub/protos:common_proto",
"@mircrosoft_apsi//:APSI",
]
)
cc_library(
name = "private_server_base",
hdrs = ["private_server_base.h"],
srcs = ["private_server_base.cc"],
deps = [
":task_interface",
"//src/primihub/protos:worker_proto",
"//src/primihub/common:common_defination",
"//src/primihub/data_store:data_store_lib",
"//src/primihub/service:dataset_service",
],
)
cc_library(
name = "id_pir_task",
hdrs = [
"pir_client_task.h",
"pir_server_task.h",
],
srcs = [
"pir_client_task.cc",
"pir_server_task.cc",
],
deps = [
":private_server_base",
":task_interface",
"@org_openmined_pir//pir/cpp:pir",
],
name = "pir_task",
hdrs = ["pir_task.h"],
srcs = ["pir_task.cc"],
deps = [
":task_interface",
"//src/primihub/kernel/pir:common_def",
"//src/primihub/kernel/pir/operator:factory",
],
)
# task semantic parser
......
......@@ -24,15 +24,7 @@
#include "src/primihub/task/semantic/mpc_task.h"
#include "src/primihub/task/semantic/fl_task.h"
#include "src/primihub/task/semantic/psi_task.h"
#include "src/primihub/task/semantic/private_server_base.h"
#ifndef USE_MICROSOFT_APSI
#include "src/primihub/task/semantic/pir_client_task.h"
#include "src/primihub/task/semantic/pir_server_task.h"
#else
#include "src/primihub/task/semantic/keyword_pir_client_task.h"
#include "src/primihub/task/semantic/keyword_pir_server_task.h"
#endif
#include "src/primihub/task/semantic/pir_task.h"
#include "src/primihub/task/semantic/tee_task.h"
#include "src/primihub/service/dataset/service.h"
......@@ -42,8 +34,6 @@ using primihub::rpc::PushTaskRequest;
using primihub::rpc::Language;
using primihub::rpc::TaskType;
using primihub::service::DatasetService;
using primihub::rpc::PsiTag;
using primihub::rpc::PirType;
namespace primihub::task {
......@@ -123,38 +113,7 @@ class TaskFactory {
const PushTaskRequest& request,
std::shared_ptr<DatasetService> dataset_service) {
const auto& task_config = request.task();
const auto& param_map = task_config.params().param_map();
int pir_type = PirType::ID_PIR;
auto param_it = param_map.find("pirType");
if (param_it != param_map.end()) {
pir_type = param_it->second.value_int32();
}
#ifndef USE_MICROSOFT_APSI
const auto& job_id = request.task().task_info().job_id();
const auto& task_id = request.task().task_info().task_id();
if (pir_type == PirType::ID_PIR) {
return std::make_shared<PIRClientTask>(
node_id, job_id, task_id, &task_config, dataset_service);
} else {
// TODO, using condition compile, fix in future
LOG(WARNING) << "ID_PIR is not supported when MICROSOFT_APSI enabled";
return nullptr;
}
#else // KEYWORD PIR
if (pir_type == PirType::KEY_PIR) {
std::string party_name = task_config.party_name();
if (party_name == PARTY_SERVER) {
return std::make_shared<KeywordPIRServerTask>(&task_config,
dataset_service);
} else {
return std::make_shared<KeywordPIRClientTask>(&task_config,
dataset_service);
}
} else {
LOG(ERROR) << "Unsupported pir type: " << pir_type;
return nullptr;
}
#endif
return std::make_shared<PirTask>(&task_config, dataset_service);
}
static std::shared_ptr<TaskBase> CreateTEETask(const std::string& node_id,
......@@ -166,28 +125,7 @@ class TaskFactory {
dataset_service);
}
static std::shared_ptr<ServerTaskBase> Create(const std::string& node_id,
rpc::TaskType task_type,
const ExecuteTaskRequest& request,
ExecuteTaskResponse *response,
std::shared_ptr<DatasetService> dataset_service) {
if (task_type == rpc::TaskType::NODE_PIR_TASK) {
#ifdef USE_MICROSOFT_APSI
// TODO, using condition compile, fix in future
LOG(WARNING) << "ID_PIR is not supported when using MICROSOFT_APSI";
return nullptr;
#else
return std::make_shared<PIRServerTask>(node_id, request,
response, dataset_service);
#endif
} else {
LOG(ERROR) << "Unsupported task type at server node: "<< task_type <<".";
return nullptr;
}
}
};
} // namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_FACTORY_H_
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "src/primihub/task/semantic/keyword_pir_client_task.h"
#include <thread>
#include <chrono>
#include <sstream>
#include "src/primihub/util/util.h"
#include "apsi/item.h"
#include "apsi/util/common_utils.h"
#include "src/primihub/util/file_util.h"
#include "src/primihub/protos/worker.pb.h"
#include "seal/util/common.h"
using namespace apsi;
using namespace apsi::network;
namespace primihub::task {
KeywordPIRClientTask::KeywordPIRClientTask(
const TaskParam *task_param, std::shared_ptr<DatasetService> dataset_service)
: TaskBase(task_param, dataset_service) {
}
retcode KeywordPIRClientTask::_LoadParams(Task &task) {
CHECK_TASK_STOPPED(retcode::FAIL);
std::string party_name = task.party_name();
const auto& param_map = task.params().param_map();
try {
auto client_data_it = param_map.find("clientData");
if (client_data_it != param_map.end()) {
auto& client_data = client_data_it->second;
if (client_data.is_array()) {
recv_query_data_direct = true; // read query data from clientData key directly
const auto& items = client_data.value_string_array().value_string_array();
for (const auto& item : items) {
recv_data_.push_back(item);
}
} else {
dataset_path_ = client_data.value_string();
dataset_id_ = client_data.value_string();
}
} else {
// check client has dataset
const auto& party_datasets = task.party_datasets();
auto it = party_datasets.find(party_name);
if (it == party_datasets.end()) {
LOG(ERROR) << "no query data found for client, party_name: " << party_name;
return retcode::FAIL;
}
const auto& datasets_map = it->second.data();
auto iter = datasets_map.find(party_name);
if (iter == datasets_map.end()) {
LOG(ERROR) << "no query data found for client, party_name: " << party_name;
return retcode::FAIL;
}
dataset_id_ = iter->second;
}
VLOG(7) << "dataset_id: " << dataset_id_;
auto result_file_path_it = param_map.find("outputFullFilename");
if (result_file_path_it != param_map.end()) {
result_file_path_ = result_file_path_it->second.value_string();
VLOG(5) << "result_file_path_: " << result_file_path_;
} else {
LOG(ERROR) << "no keyword outputFullFilename match";
return retcode::FAIL;
}
// get server dataset id
do {
const auto& party_datasets = task.party_datasets();
auto it = party_datasets.find(PARTY_SERVER);
if (it == party_datasets.end()) {
LOG(WARNING) << "no dataset found for party_name: " << PARTY_SERVER;
break;
}
const auto& datasets_map = it->second.data();
auto iter = datasets_map.find(PARTY_SERVER);
if (iter == datasets_map.end()) {
LOG(WARNING) << "no dataset found for party_name: " << PARTY_SERVER;
break;
}
std::string server_dataset_id = iter->second;
auto& dataset_service = this->getDatasetService();
auto driver = dataset_service->getDriver(server_dataset_id);
if (driver == nullptr) {
LOG(WARNING) << "no dataset access info found for id: " << server_dataset_id;
break;
}
auto& access_info = driver->dataSetAccessInfo();
if (access_info == nullptr) {
LOG(WARNING) << "no dataset access info found for id: " << server_dataset_id;
break;
}
auto& schema = access_info->Schema();
for (const auto& field : schema) {
server_dataset_schema_.push_back(std::get<0>(field));
}
} while (0);
} catch (std::exception &e) {
LOG(ERROR) << "Failed to load params: " << e.what();
return retcode::FAIL;
}
const auto& party_info = task.party_access_info();
auto it = party_info.find(PARTY_SERVER);
if (it == party_info.end()) {
LOG(ERROR) << "client can not found access info to server";
return retcode::FAIL;
}
auto& pb_node = it->second;
pbNode2Node(pb_node, &peer_node_);
VLOG(5) << "peer_node: " << peer_node_.to_string();
return retcode::SUCCESS;
}
KeywordPIRClientTask::DatasetDBPair KeywordPIRClientTask::_LoadDataFromDataset() {
apsi::util::CSVReader::DBData db_data;
std::vector<std::string> orig_items;
auto driver = this->getDatasetService()->getDriver(this->dataset_id_);
if (driver == nullptr) {
LOG(ERROR) << "get driver for dataset: " << this->dataset_id_ << " failed";
return std::make_pair(nullptr, std::vector<std::string>());
}
auto access_info = dynamic_cast<CSVAccessInfo*>(driver->dataSetAccessInfo().get());
if (access_info == nullptr) {
LOG(ERROR) << "get data accessinfo for dataset: " << this->dataset_id_ << " failed";
return std::make_pair(nullptr, std::vector<std::string>());
}
dataset_path_ = access_info->file_path_;
try {
apsi::util::CSVReader reader(dataset_path_);
std::tie(db_data, orig_items) = reader.read();
} catch (const std::exception &ex) {
LOG(ERROR) << "Could not open or read file `"
<< dataset_path_ << "`: "
<< ex.what();
return std::make_pair(nullptr, orig_items);
}
return {std::make_unique<apsi::util::CSVReader::DBData>(std::move(db_data)), std::move(orig_items)};
}
KeywordPIRClientTask::DatasetDBPair KeywordPIRClientTask::_LoadDataFromRecvData() {
if (recv_data_.empty()) {
LOG(ERROR) << "query data is empty";
return std::make_pair(nullptr, std::vector<std::string>());
}
// build db_data;
// std::unqiue_ptr<apsi::util::CSVReader::DBData>
apsi::util::CSVReader::DBData db_data = apsi::util::CSVReader::UnlabeledData();
for(const auto& item_str : recv_data_) {
apsi::Item db_item = item_str;
std::get<apsi::util::CSVReader::UnlabeledData>(db_data).push_back(std::move(db_item));
}
return {std::make_unique<apsi::util::CSVReader::DBData>(std::move(db_data)), recv_data_};
// return std::make_pair(std::move(db_data), std::move(orig_items));
}
KeywordPIRClientTask::DatasetDBPair KeywordPIRClientTask::_LoadDataset(void) {
if (!recv_query_data_direct) {
return _LoadDataFromDataset();
} else {
return _LoadDataFromRecvData();
}
}
retcode KeywordPIRClientTask::saveResult(
const std::vector<std::string>& orig_items,
const std::vector<Item>& items,
const std::vector<MatchRecord>& intersection) {
CHECK_TASK_STOPPED(retcode::FAIL);
if (orig_items.size() != items.size()) {
LOG(ERROR) << "Keyword PIR orig_items must have the same size as items, detail: "
<< "orig_items size: " << orig_items.size() << " items size: " << items.size();
return retcode::FAIL;
}
std::vector<std::vector<std::string>> result_data;
result_data.resize(2);
for (auto& item : result_data) {
item.reserve(orig_items.size());
}
auto& key = result_data[0];
auto& result_value = result_data[1];
for (size_t i = 0; i < orig_items.size(); i++) {
if (!intersection[i].found) {
VLOG(0) << "no match result found for query: [" << orig_items[i] << "]";
continue;
}
if (intersection[i].label) {
std::string label_info = intersection[i].label.to_string();
std::vector<std::string> labels;
std::string sep = DATA_RECORD_SEP;
str_split(label_info, &labels, sep);
for (const auto& lable_ : labels) {
key.push_back(orig_items[i]);
result_value.push_back(lable_);
}
} else {
LOG(WARNING) << "no value found for query key: " << orig_items[i];
}
}
VLOG(0) << "save query result to : " << result_file_path_;
std::vector<std::shared_ptr<arrow::Field>> schema_vector;
std::vector<std::string> tmp_colums{"key", "value"};
for (const auto& col_name : tmp_colums) {
schema_vector.push_back(arrow::field(col_name, arrow::int64()));
}
std::vector<std::shared_ptr<arrow::Array>> arrow_array;
for (auto& item : result_data) {
arrow::StringBuilder builder;
builder.AppendValues(item);
std::shared_ptr<arrow::Array> array;
builder.Finish(&array);
arrow_array.push_back(std::move(array));
}
auto schema = std::make_shared<arrow::Schema>(schema_vector);
// std::shared_ptr<arrow::Table>
auto table = arrow::Table::Make(schema, arrow_array);
auto driver = DataDirverFactory::getDriver("CSV", "test address");
auto csv_driver = std::dynamic_pointer_cast<CSVDriver>(driver);
auto rtcode = csv_driver->Write(server_dataset_schema_, table, result_file_path_);
if (rtcode != retcode::SUCCESS) {
LOG(ERROR) << "save PIR data to file " << result_file_path_ << " failed.";
return retcode::FAIL;
}
return retcode::SUCCESS;
}
retcode KeywordPIRClientTask::requestPSIParams() {
CHECK_TASK_STOPPED(retcode::FAIL);
RequestType type = RequestType::PsiParam;
std::string request{reinterpret_cast<char*>(&type), sizeof(type)};
VLOG(5) << "send_data length: " << request.length();
std::string response_str;
auto& link_ctx = this->getTaskContext().getLinkContext();
CHECK_NULLPOINTER_WITH_ERROR_MSG(link_ctx, "LinkContext is empty");
auto channel = link_ctx->getChannel(peer_node_);
auto ret = channel->sendRecv(this->key, request, &response_str);
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "send requestPSIParams to peer: [" << peer_node_.to_string()
<< "] failed";
return ret;
}
if (VLOG_IS_ON(5)) {
std::string tmp_str;
for (const auto& chr : response_str) {
tmp_str.append(std::to_string(static_cast<int>(chr))).append(" ");
}
VLOG(5) << "recv_data size: " << response_str.size() << " "
<< "data content: " << tmp_str;
}
// create psi params
// static std::pair<PSIParams, std::size_t> Load(std::istream &in);
std::istringstream stream_in(response_str);
auto [parse_data, ret_size] = PSIParams::Load(stream_in);
psi_params_ = std::make_unique<PSIParams>(parse_data);
VLOG(5) << "parsed psi param, size: " << ret_size << " "
<< "content: " << psi_params_->to_string();
return retcode::SUCCESS;
}
static std::string to_hexstring(const Item &item) {
std::stringstream ss;
ss << std::hex;
auto item_string = item.to_string();
for(int i(0); i < 16; ++i)
ss << std::setw(2) << std::setfill('0') << (int)item_string[i];
return ss.str();
}
retcode KeywordPIRClientTask::requestOprf(const std::vector<Item>& items,
std::vector<apsi::HashedItem>* res_items_ptr,
std::vector<apsi::LabelKey>* res_label_keys_ptr) {
CHECK_TASK_STOPPED(retcode::FAIL);
RequestType type = RequestType::Oprf;
std::string oprf_response;
auto oprf_receiver = this->receiver_->CreateOPRFReceiver(items);
auto& res_items = *res_items_ptr;
auto& res_label_keys = *res_label_keys_ptr;
res_items.resize(oprf_receiver.item_count());
res_label_keys.resize(oprf_receiver.item_count());
auto oprf_request = oprf_receiver.query_data();
VLOG(5) << "oprf_request data length: " << oprf_request.size();
std::string_view oprf_request_sv{
reinterpret_cast<char*>(const_cast<unsigned char*>(oprf_request.data())), oprf_request.size()};
auto& link_ctx = this->getTaskContext().getLinkContext();
CHECK_NULLPOINTER_WITH_ERROR_MSG(link_ctx, "LinkContext is empty");
auto channel = link_ctx->getChannel(peer_node_);
// auto ret = channel->sendRecv(this->key, oprf_request_sv, &oprf_response);
auto ret = this->send(this->key, peer_node_, oprf_request_sv);
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "requestOprf to peer: [" << peer_node_.to_string()
<< "] failed";
return ret;
}
ret = this->recv(this->key, &oprf_response);
if (ret != retcode::SUCCESS || oprf_response.empty()) {
LOG(ERROR) << "receive oprf_response from peer: [" << peer_node_.to_string()
<< "] failed";
return retcode::FAIL;
}
VLOG(5) << "received oprf response length: " << oprf_response.length() << " ";
oprf_receiver.process_responses(oprf_response, res_items, res_label_keys);
return retcode::SUCCESS;
}
retcode KeywordPIRClientTask::requestQuery() {
RequestType type = RequestType::Query;
std::string send_data{reinterpret_cast<char*>(&type), sizeof(type)};
VLOG(5) << "send_data length: " << send_data.length();
return retcode::SUCCESS;
}
int KeywordPIRClientTask::execute() {
auto ret = _LoadParams(task_param_);
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "Pir client load task params failed.";
return -1;
}
VLOG(5) << "begin to request psi params";
ret = requestPSIParams();
CHECK_RETCODE_WITH_RETVALUE(ret, -1);
auto [query_data, orig_items] = _LoadDataset();
if (!query_data || !holds_alternative<CSVReader::UnlabeledData>(*query_data)) {
LOG(ERROR) << "Failed to read keyword PIR query file: terminating";
return -1;
}
auto& items = std::get<CSVReader::UnlabeledData>(*query_data);
std::vector<Item> items_vec(items.begin(), items.end());
std::vector<HashedItem> oprf_items;
std::vector<LabelKey> label_keys;
VLOG(5) << "begin to Receiver::RequestOPRF";
ret = requestOprf(items_vec, &oprf_items, &label_keys);
CHECK_RETCODE_WITH_RETVALUE(ret, -1);
CHECK_TASK_STOPPED(-1);
if (VLOG_IS_ON(5)) {
for (int i = 0; i < items_vec.size(); i++) {
VLOG(5) << "item[" << i << "]'s PRF value: " << to_hexstring(oprf_items[i]);
}
}
VLOG(5) << "Receiver::RequestOPRF end, begin to receiver.request_query";
// request query
this->receiver_ = std::make_unique<Receiver>(*psi_params_);
std::vector<MatchRecord> query_result;
try {
auto query = this->receiver_->create_query(oprf_items);
// chl.send(move(query.first));
auto request_query_data = std::move(query.first);
std::ostringstream string_ss;
request_query_data->save(string_ss);
std::string query_data_str = string_ss.str();
auto itt = move(query.second);
VLOG(5) << "query_data_str size: " << query_data_str.size();
ret = this->send(this->key, peer_node_, query_data_str);
CHECK_RETCODE_WITH_RETVALUE(ret, -1);
// receive package count
uint32_t package_count = 0;
ret = this->recv("package_count", reinterpret_cast<char*>(&package_count), sizeof(package_count));
CHECK_RETCODE_WITH_RETVALUE(ret, -1);
VLOG(5) << "received package count: " << package_count;
std::vector<apsi::ResultPart> result_packages;
for (size_t i = 0; i < package_count; i++) {
std::string recv_data;
ret = this->recv(this->key, &recv_data);
CHECK_RETCODE_WITH_RETVALUE(ret, -1);
VLOG(5) << "client received data length: " << recv_data.size();
std::istringstream stream_in(recv_data);
apsi::ResultPart result_part = std::make_unique<apsi::network::ResultPackage>();
auto seal_context = this->receiver_->get_seal_context();
result_part->load(stream_in, seal_context);
result_packages.push_back(std::move(result_part));
}
query_result = this->receiver_->process_result(label_keys, itt, result_packages);
VLOG(5) << "query_resultquery_resultquery_resultquery_result: " << query_result.size();
} catch (const std::exception &ex) {
LOG(ERROR) << "Failed sending keyword PIR query: " << ex.what();
return -1;
}
VLOG(5) << "receiver.request_query end";
ret = this->saveResult(orig_items, items, query_result);
CHECK_RETCODE_WITH_RETVALUE(ret, -1);
return 0;
}
} // namespace primihub::task
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_CLIENT_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_CLIENT_TASK_H_
#include <vector>
#include "apsi/item.h"
#include "apsi/match_record.h"
#include "apsi/util/csv_reader.h"
#include "src/primihub/task/semantic/task.h"
#include "src/primihub/common/common.h"
#include "apsi/receiver.h"
using apsi::Item;
using apsi::receiver::MatchRecord;
using apsi::util::CSVReader;
using namespace apsi::receiver;
namespace primihub::task {
class KeywordPIRClientTask : public TaskBase {
public:
using DBDataPtr = std::unique_ptr<apsi::util::CSVReader::DBData>;
using DatasetDBPair = std::pair<DBDataPtr, std::vector<std::string>>;
enum class RequestType : uint8_t {
PsiParam = 0,
Oprf,
Query,
};
explicit KeywordPIRClientTask(const TaskParam *task_param,
std::shared_ptr<DatasetService> dataset_service);
~KeywordPIRClientTask() = default;
int execute() override;
retcode saveResult(const std::vector<std::string>& orig_items,
const std::vector<apsi::Item>& items,
const std::vector<apsi::receiver::MatchRecord>& intersection);
protected:
/**
* Performs a parameter request from sender
*/
retcode requestPSIParams();
/**
* Performs an OPRF request on a vector of items and returns a
vector of OPRF hashed items of the same size as the input vector.
*/
retcode requestOprf(const std::vector<Item>& items,
std::vector<apsi::HashedItem>*, std::vector<apsi::LabelKey>*);
/**
* Performs a labeled PSI query. The query is a vector of
items, and the result is a same-size vector of MatchRecord objects. If an item is in the
intersection, the corresponding MatchRecord indicates it in the `found` field, and the
`label` field may contain the corresponding label if a sender's data included it.
*/
retcode requestQuery();
private:
retcode _LoadParams(Task &task);
DatasetDBPair _LoadDataset();
// load dataset according by url
DatasetDBPair _LoadDataFromDataset();
// load data from request directly
DatasetDBPair _LoadDataFromRecvData();
private:
std::string dataset_path_;
std::string dataset_id_;
std::string result_file_path_;
std::string server_address_;
bool recv_query_data_direct{false};
uint32_t server_data_port{2222};
primihub::Node peer_node_;
std::string key{"default"};
std::unique_ptr<apsi::PSIParams> psi_params_{nullptr};
std::unique_ptr<apsi::receiver::Receiver> receiver_{nullptr};
std::vector<std::string> recv_data_;
std::vector<std::string> server_dataset_schema_;
};
} // namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_CLIENT_TASK_H_
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_SERVER_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_SERVER_TASK_H_
#include <utility>
#include <vector>
#include <string>
#include <memory>
// APSI
#include "apsi/util/common_utils.h"
#include "apsi/sender.h"
#include "apsi/oprf/oprf_sender.h"
#include "apsi/bin_bundle.h"
#include "apsi/item.h"
// SEAL
#include "seal/context.h"
#include "seal/modulus.h"
#include "seal/util/common.h"
#include "seal/util/defines.h"
#include "src/primihub/protos/common.pb.h"
#include "src/primihub/task/semantic/task.h"
namespace primihub::task {
using UnlabeledData = std::vector<apsi::Item>;
using LabeledData = std::vector<std::pair<apsi::Item, apsi::Label>>;
using DBData = std::variant<UnlabeledData, LabeledData>;
class KeywordPIRServerTask : public TaskBase {
public:
enum class RequestType : uint8_t {
PsiParam = 0,
Opfr,
Query,
};
explicit KeywordPIRServerTask(const TaskParam* task_param,
std::shared_ptr<DatasetService> dataset_service);
~KeywordPIRServerTask() = default;
int execute() override;
protected:
/**
* process a Get Parameters request to the Sender.
*/
retcode processPSIParams();
/**
process an OPRF query request to the Sender.
*/
retcode processOprf();
/**
process a Query request to the Sender.
*/
retcode processQuery(std::shared_ptr<apsi::sender::SenderDB> sender_db);
retcode ComputePowers(const shared_ptr<apsi::sender::SenderDB> &sender_db,
const apsi::CryptoContext &crypto_context,
std::vector<apsi::sender::CiphertextPowers> &all_powers,
const apsi::PowersDag &pd,
uint32_t bundle_idx,
seal::MemoryPoolHandle &pool);
auto ProcessBinBundleCache(const shared_ptr<apsi::sender::SenderDB> &sender_db,
const apsi::CryptoContext &crypto_context,
reference_wrapper<const apsi::sender::BinBundleCache> cache,
std::vector<apsi::sender::CiphertextPowers> &all_powers,
uint32_t bundle_idx,
compr_mode_type compr_mode,
seal::MemoryPoolHandle &pool) ->
std::unique_ptr<apsi::network::ResultPackage>;
private:
retcode _LoadParams(Task &task);
std::unique_ptr<DBData> _LoadDataset(void);
std::unique_ptr<apsi::PSIParams> _SetPsiParams();
std::shared_ptr<apsi::sender::SenderDB> create_sender_db(const DBData& db_data,
std::unique_ptr<apsi::PSIParams> psi_params,
apsi::oprf::OPRFKey &oprf_key,
size_t nonce_byte_count,
bool compress);
std::shared_ptr<apsi::sender::SenderDB> LoadDbFromCache(const std::string& db_file_cache_);
std::unique_ptr<DBData> CreateDbData(std::shared_ptr<Dataset>& data);
std::vector<std::string> GetSelectedContent(std::shared_ptr<arrow::Table>& data_tbl,
const std::vector<int>& selected_col);
retcode CreateDbDataCache(const DBData& db_data,
std::unique_ptr<apsi::PSIParams> psi_params,
apsi::oprf::OPRFKey &oprf_key,
size_t nonce_byte_count,
bool compress);
bool DbCacheAvailable(const std::string& db_file_cache);
private:
std::string dataset_path_;
std::string dataset_id_;
std::string db_cache_dir_{"data/cache"};
std::string db_file_cache_;
uint32_t data_port{2222};
std::string client_address;
primihub::Node client_node_;
std::string key{"default"};
std::string psi_params_str_;
std::unique_ptr<apsi::oprf::OPRFKey> oprf_key_{nullptr};
bool generate_db_offline_{false};
};
} // namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_SERVER_TASK_H_
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "src/primihub/task/semantic/pir_client_task.h"
#include <string>
#include "src/primihub/data_store/factory.h"
#include "src/primihub/util/util.h"
using arrow::Table;
using arrow::StringArray;
using arrow::Int64Builder;
using primihub::rpc::VarType;
namespace primihub::task {
int validateDirection(std::string file_path) {
int pos = file_path.find_last_of('/');
std::string path;
if (pos > 0) {
path = file_path.substr(0, pos);
if (access(path.c_str(), 0) == -1) {
std::string cmd = "mkdir -p " + path;
int ret = system(cmd.c_str());
if (ret)
return -1;
}
}
return 0;
}
PIRClientTask::PIRClientTask(const std::string &node_id,
const std::string &job_id,
const std::string &task_id,
const TaskParam *task_param,
std::shared_ptr<DatasetService> dataset_service)
: TaskBase(task_param, dataset_service), node_id_(node_id),
job_id_(job_id), task_id_(task_id) {}
int PIRClientTask::_LoadParams(Task &task) {
auto param_map = task.params().param_map();
try {
result_file_path_ = param_map["outputFullFilename"].value_string();
server_address_ = param_map["serverAddress"].value_string();
server_dataset_ = param_map[server_address_].value_string();
db_size_ = stoi(param_map["databaseSize"].value_string()); // temperarily read db size direly from frontend
std::vector<std::string> tmp_indices;
str_split(param_map["queryIndeies"].value_string(), &tmp_indices, ',');
for (std::string &index : tmp_indices) {
int idx = stoi(index);
indices_.push_back(idx);
}
} catch (std::exception &e) {
LOG(ERROR) << "Failed to load params: " << e.what();
return -1;
}
return 0;
}
int PIRClientTask::_SetUpDB(size_t __dbsize, size_t dimensions, size_t elem_size,
uint32_t plain_mod_bit_size, uint32_t bits_per_coeff,
bool use_ciphertext_multiplication = false) {
// db_size_ = dbsize;
encryption_params_ = pir::GenerateEncryptionParams(POLY_MODULUS_DEGREE,
plain_mod_bit_size);
pir_params_ = *(pir::CreatePIRParameters(db_size_, elem_size, dimensions, encryption_params_,
use_ciphertext_multiplication, bits_per_coeff));
client_ = *(PIRClient::Create(pir_params_));
if (client_ == nullptr) {
LOG(ERROR) << "Failed to create pir client.";
return -1;
}
return 0;
}
int PIRClientTask::_ProcessResponse(const ExecuteTaskResponse &taskResponse) {
pir::Response response;
size_t num_reply =
static_cast<size_t>(taskResponse.pir_response().reply().size());
for (size_t i = 0; i < num_reply; i++) {
pir::Ciphertexts* ptr_reply = response.add_reply();
size_t num_ct =
static_cast<std::int64_t>(taskResponse.pir_response().reply()[i].ct().size());
for (size_t j = 0; j < num_ct; j++) {
ptr_reply->add_ct(taskResponse.pir_response().reply()[i].ct()[j]);
}
}
auto result = client_->ProcessResponse(indices_, response);
if (result.ok()) {
for (size_t i = 0; i < std::move(result).value().size(); i++) {
result_.push_back(std::move(result).value()[i]);
}
} else {
LOG(ERROR) << "Failed to process pir server response: "
<< result.status();
return -1;
}
return 0;
}
int PIRClientTask::saveResult() {
arrow::MemoryPool *pool = arrow::default_memory_pool();
arrow::StringBuilder builder(pool);
for (std::int64_t i = 0; i < result_.size(); i++) {
builder.Append(result_[i]);
}
std::shared_ptr<arrow::Array> array;
builder.Finish(&array);
std::vector<std::shared_ptr<arrow::Field>> schema_vector = {
arrow::field("reslut", arrow::utf8())};
auto schema = std::make_shared<arrow::Schema>(schema_vector);
std::shared_ptr<arrow::Table> table = arrow::Table::Make(schema, {array});
std::shared_ptr<DataDriver> driver =
DataDirverFactory::getDriver("CSV", "pir result");
std::shared_ptr<CSVDriver> csv_driver =
std::dynamic_pointer_cast<CSVDriver>(driver);
if (validateDirection(result_file_path_)) {
LOG(ERROR) << "can't access file path: "
<< result_file_path_;
return -1;
}
int ret = csv_driver->write(table, result_file_path_);
if (ret != 0) {
LOG(ERROR) << "Save PIR result to file " << result_file_path_ << " failed.";
return -1;
}
LOG(INFO) << "Save PIR result to " << result_file_path_ << ".";
return 0;
}
uint32_t compute_plain_mod_bit_size(size_t dbsize, size_t elem_size) {
uint32_t plain_mod_bit_size = PLAIN_MOD_BIT_SIZE_UPBOUND;
while (true) {
plain_mod_bit_size--;
uint64_t elem_per_plaintext = POLY_MODULUS_DEGREE \
* (plain_mod_bit_size - 1) / 8 / elem_size;
uint64_t num_plaintext = dbsize / elem_per_plaintext + 1;
if (num_plaintext <=
(uint64_t)1 << (NOISE_BUDGET_BASE - 2 * plain_mod_bit_size))
{
break;
}
}
return plain_mod_bit_size;
}
int PIRClientTask::execute() {
int ret = _LoadParams(task_param_);
if (ret) {
LOG(ERROR) << "Pir client load task params failed.";
return ret;
}
size_t dimensions = 1;
size_t elem_size = ELEM_SIZE;
uint32_t plain_mod_bit_size = compute_plain_mod_bit_size(db_size_, elem_size);
bool use_ciphertext_multiplication = true;
uint32_t poly_modulus_degree = POLY_MODULUS_DEGREE;
uint32_t bits_per_coeff = 0;
ret = _SetUpDB(0, dimensions, elem_size, // temperarily read db size direly from frontend
plain_mod_bit_size, bits_per_coeff,
use_ciphertext_multiplication);
if (ret) {
LOG(ERROR) << "Failed to initialize pir client.";
return -1;
}
//pir::Request request_proto = std::move(client_->CreateRequest(indices_)).value();
pir::Request request_proto;
auto request_or = client_->CreateRequest(indices_);
if (request_or.ok()) {
request_proto = std::move(request_or).value();
} else {
LOG(ERROR) << "Pir create request failed: "
<< request_or.status();
return -1;
}
grpc::ClientContext client_context;
grpc::ChannelArguments channel_args;
channel_args.SetMaxReceiveMessageSize(128*1024*1024);
std::shared_ptr<grpc::Channel> channel =
grpc::CreateCustomChannel(server_address_, grpc::InsecureChannelCredentials(), channel_args);
std::unique_ptr<VMNode::Stub> stub = VMNode::NewStub(channel);
using stream_t = std::shared_ptr<grpc::ClientReaderWriter<ExecuteTaskRequest, ExecuteTaskResponse>>;
stream_t client_stream(stub->ExecuteTask(&client_context));
size_t limited_size = 1 << 21;
size_t query_num = request_proto.query().size();
const auto& querys = request_proto.query();
size_t sended_index{0};
std::vector<ExecuteTaskRequest> send_requests;
do {
ExecuteTaskRequest taskRequest;
PirRequest * ptr_request = taskRequest.mutable_pir_request();
ptr_request->set_galois_keys(request_proto.galois_keys());
ptr_request->set_relin_keys(request_proto.relin_keys());
size_t pack_size = 0;
for (size_t i = sended_index; i < query_num; i++) {
// calculate length of query
size_t query_size = 0;
const auto& query = querys[i];
for (const auto& ct : query.ct()) {
query_size += ct.size();
}
if (pack_size + query_size > limited_size) {
break;
}
auto query_ptr = ptr_request->add_query();
for (const auto& ct : query.ct()) {
query_ptr->add_ct(ct);
}
sended_index++;
}
auto *ptr_params = taskRequest.mutable_params()->mutable_param_map();
ParamValue pv;
pv.set_var_type(VarType::STRING);
pv.set_value_string(server_dataset_);
(*ptr_params)["serverData"] = pv;
send_requests.push_back(std::move(taskRequest));
if (sended_index >= query_num) {
break;
}
} while (true);
// send request to server
for (const auto& request : send_requests) {
client_stream->Write(request);
}
client_stream->WritesDone();
ExecuteTaskResponse taskResponse;
ExecuteTaskResponse recv_response;
auto pir_response = taskResponse.mutable_pir_response();
bool is_initialized{false};
while (client_stream->Read(&recv_response)) {
const auto& recv_pir_response = recv_response.pir_response();
if (!is_initialized) {
pir_response->set_ret_code(recv_pir_response.ret_code());
is_initialized = true;
}
for (const auto& reply : recv_pir_response.reply()) {
auto reply_ptr = pir_response->add_reply();
for (const auto& ct : reply.ct()) {
reply_ptr->add_ct(ct);
}
}
}
Status status = client_stream->Finish();
if (status.ok()) {
if (taskResponse.psi_response().ret_code()) {
LOG(ERROR) << "Node pir server process request error.";
return -1;
}
int ret = _ProcessResponse(taskResponse);
if (ret) {
LOG(ERROR) << "Node pir client process response failed.";
return -1;
}
ret = saveResult();
if (ret) {
LOG(ERROR) << "Pir save result failed.";
return -1;
}
} else {
LOG(ERROR) << "Pir server return error: "
<< status.error_code() << " " << status.error_message().c_str();
return -1;
}
return 0;
}
}
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_PIR_CLIENT_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_PIR_CLIENT_TASK_H_
#include <grpc/grpc.h>
#include <grpcpp/channel.h>
#include <grpcpp/create_channel.h>
#include "pir/cpp/client.h"
#include "pir/cpp/database.h"
#include "pir/cpp/utils.h"
#include "pir/cpp/string_encoder.h"
#include <map>
#include <memory>
#include <string>
#include <set>
#include "src/primihub/protos/common.grpc.pb.h"
#include "src/primihub/protos/psi.grpc.pb.h"
#include "src/primihub/protos/worker.grpc.pb.h"
#include "src/primihub/task/semantic/task.h"
using pir::PIRParameters;
using pir::EncryptionParameters;
using pir::PIRClient;
// using grpc::ClientContext;
using grpc::Status;
using grpc::Channel;
using primihub::rpc::Ciphertexts;
using primihub::rpc::Task;
using primihub::rpc::ParamValue;
using primihub::rpc::PsiType;
using primihub::rpc::ExecuteTaskRequest;
using primihub::rpc::ExecuteTaskResponse;
using primihub::rpc::PirRequest;
using primihub::rpc::PirResponse;
using primihub::rpc::VMNode;
namespace primihub::task {
constexpr uint32_t POLY_MODULUS_DEGREE = 4096;
constexpr uint32_t ELEM_SIZE = 1024;
constexpr uint32_t PLAIN_MOD_BIT_SIZE_UPBOUND = 29;
constexpr uint32_t NOISE_BUDGET_BASE = 57;
class PIRClientTask : public TaskBase {
public:
explicit PIRClientTask(const std::string &node_id,
const std::string &job_id,
const std::string &task_id,
const TaskParam *task_param,
std::shared_ptr<DatasetService> dataset_service);
~PIRClientTask() {};
int execute() override;
int saveResult(void);
private:
int _LoadParams(Task &task);
int _SetUpDB(size_t dbsize, size_t dimensions, size_t elem_size,
uint32_t plain_mod_bit_size, uint32_t bits_per_coeff,
bool use_ciphertext_multiplication);
int _ProcessResponse(const ExecuteTaskResponse &taskResponse);
const std::string node_id_;
const std::string job_id_;
const std::string task_id_;
std::string server_address_;
std::string result_file_path_;
std::vector<size_t> indices_;
std::vector<std::string> result_;
std::string server_dataset_;
size_t db_size_;
std::shared_ptr<PIRParameters> pir_params_;
EncryptionParameters encryption_params_;
std::unique_ptr<PIRClient> client_;
};
} // namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_PIR_CLIENT_TASK_H_
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "src/primihub/task/semantic/pir_server_task.h"
namespace primihub::task {
void initRequest(const PirRequest * request, pir::Request & pir_request) {
pir_request.set_galois_keys(request->galois_keys());
pir_request.set_relin_keys(request->relin_keys());
const size_t num_query = static_cast<size_t>(request->query().size());
for (size_t i = 0; i < num_query; i++) {
pir::Ciphertexts* ptr_query = pir_request.add_query();
const size_t num_ct =
static_cast<size_t>(request->query()[i].ct().size());
for (size_t j = 0; j < num_ct; j++) {
ptr_query->add_ct(request->query()[i].ct()[j]);
}
}
}
PIRServerTask::PIRServerTask(const std::string &node_id,
const ExecuteTaskRequest& request,
ExecuteTaskResponse *response,
std::shared_ptr<DatasetService> dataset_service)
: ServerTaskBase(&(request.params()), dataset_service) {
request_ = &(request.pir_request());
response_ = response->mutable_pir_response();
}
int PIRServerTask::loadParams(Params & params) {
auto param_map = params.param_map();
try {
dataset_path_ = param_map["serverData"].value_string();
} catch (std::exception &e) {
LOG(ERROR) << "Failed to load pir server params: " << e.what();
return -1;
}
return 0;
}
int PIRServerTask::loadDataset() {
// int ret = loadDatasetFromCSV(dataset_path_, 0, elements_, db_size_);
int ret = loadDatasetFromTXT(dataset_path_, elements_);
// file reading error or file empty
if (ret <= 0) {
LOG(ERROR) << "Load dataset for psi client failed.";
return -1;
}
LOG(INFO) << "db size = " << ret;
// output the dataset length
return ret;
}
int PIRServerTask::_SetUpDB(size_t dbsize, size_t dimensions,
size_t elem_size, uint32_t poly_modulus_degree,
uint32_t plain_mod_bit_size, uint32_t bits_per_coeff,
bool use_ciphertext_multiplication) {
encryption_params_ =
pir::GenerateEncryptionParams(poly_modulus_degree, plain_mod_bit_size);
pir_params_ = *(pir::CreatePIRParameters(dbsize, elem_size, dimensions, encryption_params_,
use_ciphertext_multiplication, bits_per_coeff));
db_size_ = dbsize;
if (elements_.size() > dbsize) {
LOG(ERROR) << "Dataset size is not equal dbsize:" << elements_.size();
for (int i = 0; i < elements_.size(); i++) {
LOG(INFO) << "elem: " << elements_[i];
}
return -1;
} else if (elements_.size() < dbsize) {
uint32_t seed = 42;
auto prng =
seal::UniformRandomGeneratorFactory::DefaultFactory()->create({seed});
for (int64_t i = elements_.size(); i < dbsize; i++) {
int rand_num = rand() % 80;
std::string rand_str(rand_num, 0);
prng->generate(rand_str.size(), reinterpret_cast<seal::SEAL_BYTE*>(rand_str.data()));
elements_.push_back(std::to_string(i) + std::to_string(i) + std::to_string(i) + rand_str);
}
}
std::vector<std::string> string_db;
string_db.resize(dbsize, std::string(elem_size, 0));
for (size_t i = 0; i < dbsize; ++i) {
for (int j = 0; j < elements_[i].length(); j++) {
string_db[i][j] = elements_[i][j];
}
}
auto db_status = pir::PIRDatabase::Create(string_db, pir_params_);
if (!db_status.ok()) {
LOG(ERROR) << db_status.status();
return -1;
} else {
pir_db_ = std::move(db_status).value();
}
return 0;
}
uint32_t compute_plain_mod_bit_size_server(size_t dbsize, size_t elem_size) {
uint32_t plain_mod_bit_size = PLAIN_MOD_BIT_SIZE_UPBOUND_SVR;
while (true) {
plain_mod_bit_size--;
uint64_t elem_per_plaintext = POLY_MODULUS_DEGREE_SVR \
* (plain_mod_bit_size - 1) / 8 / elem_size;
uint64_t num_plaintext = dbsize / elem_per_plaintext + 1;
if (num_plaintext <=
(uint64_t)1 << (NOISE_BUDGET_BASE_SVR - 2 * plain_mod_bit_size))
{
break;
}
}
return plain_mod_bit_size;
}
int PIRServerTask::execute() {
LOG(INFO) << "load parameters";
int ret = loadParams(params_);
if (ret) {
LOG(ERROR) << "Load parameters for pir server fialed.";
return -1;
}
LOG(INFO) << "parameters loaded";
LOG(INFO) << "load dataset";
int db_size = loadDataset();
if (db_size <= 0) {
LOG(ERROR) << "Load dataset for pir server failed.";
return -1;
}
LOG(INFO) << "dataset loaded";
size_t dimensions = 1;
size_t elem_size = ELEM_SIZE_SVR;
uint32_t plain_mod_bit_size = compute_plain_mod_bit_size_server(db_size, elem_size);
bool use_ciphertext_multiplication = true;
uint32_t poly_modulus_degree = POLY_MODULUS_DEGREE_SVR;
uint32_t bits_per_coeff = 0;
LOG(INFO) << "create database";
ret = _SetUpDB(db_size, dimensions, elem_size, poly_modulus_degree,
plain_mod_bit_size, bits_per_coeff,
use_ciphertext_multiplication);
if (ret) {
LOG(ERROR) << "Create pir db failed.";
return -1;
}
LOG(INFO) << "database created";
pir::Request pir_request;
initRequest(request_, pir_request);
LOG(INFO) << "create server";
std::unique_ptr<pir::PIRServer> server =
*(pir::PIRServer::Create(pir_db_, pir_params_));
if (server == nullptr) {
LOG(ERROR) << "Failed to create pir server";
return -1;
}
LOG(INFO) << "server created";
LOG(INFO) << "process request";
auto result_status = server->ProcessRequest(pir_request);
if (!result_status.ok()) {
LOG(ERROR) << "Process pir request failed:"
<< result_status.status();
return -1;
}
LOG(INFO) << "request processed";
auto result_raw = std::move(result_status).value();
const size_t num_reply = static_cast<size_t>(result_raw.reply().size());
for (size_t i = 0; i < num_reply; i++) {
Ciphertexts* ptr_reply = response_->add_reply();
const size_t num_ct = static_cast<size_t>(result_raw.reply()[i].ct().size());
for (size_t j = 0; j < num_ct; j++) {
ptr_reply->add_ct(result_raw.reply()[i].ct()[j]);
}
}
return 0;
}
} // namespace primihub::task
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_PIR_SERVER_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_PIR_SERVER_TASK_H_
#include <map>
#include <memory>
#include <string>
#include <stdlib.h>
#include "pir/cpp/server.h"
#include "pir/cpp/database.h"
#include "pir/cpp/utils.h"
#include "pir/cpp/string_encoder.h"
#include "src/primihub/protos/common.grpc.pb.h"
#include "src/primihub/protos/psi.grpc.pb.h"
#include "src/primihub/protos/worker.grpc.pb.h"
#include "src/primihub/task/semantic/private_server_base.h"
using std::shared_ptr;
using primihub::rpc::Params;
using primihub::rpc::Ciphertexts;
using primihub::rpc::PirRequest;
using primihub::rpc::PirResponse;
using primihub::rpc::ExecuteTaskRequest;
using primihub::rpc::ExecuteTaskResponse;
namespace primihub::task {
constexpr uint32_t POLY_MODULUS_DEGREE_SVR = 4096;
constexpr uint32_t ELEM_SIZE_SVR = 1024;
constexpr uint32_t PLAIN_MOD_BIT_SIZE_UPBOUND_SVR = 29;
constexpr uint32_t NOISE_BUDGET_BASE_SVR = 57;
class PIRServerTask : public ServerTaskBase {
public:
explicit PIRServerTask(const std::string &node_id,
const ExecuteTaskRequest& request,
ExecuteTaskResponse *response,
std::shared_ptr<DatasetService> dataset_service);
~PIRServerTask(){}
int loadParams(Params & params) override;
int loadDataset(void) override;
int execute() override;
private:
int _SetUpDB(size_t dbsize, size_t dimensions,
size_t elem_size, uint32_t poly_modulus_degree,
uint32_t plain_mod_bit_size, uint32_t bits_per_coeff,
bool use_ciphertext_multiplication);
//int data_col_;
std::string dataset_path_;
size_t db_size_;
shared_ptr<pir::PIRParameters> pir_params_;
pir::EncryptionParameters encryption_params_;
std::vector<std::string> elements_;
shared_ptr<pir::PIRDatabase> pir_db_;
const PirRequest * request_;
PirResponse * response_;
};
} // namespace primihub::task
#endif SRC_PRIMIHUB_TASK_SEMANTIC_PIR_SERVER_TASK_H_
//
#include "src/primihub/task/semantic/pir_task.h"
#include <glog/logging.h>
#include "src/primihub/kernel/pir/operator/base_pir.h"
#include "src/primihub/kernel/pir/operator/factory.h"
namespace primihub::task {
PirTask::PirTask(const TaskParam* task_param,
std::shared_ptr<DatasetService> dataset_service)
: TaskBase(task_param, dataset_service) {}
retcode PirTask::BuildOptions(const rpc::Task& task, pir::Options* options) {
// build Options for operator
options->self_party = this->party_name();
options->link_ctx_ref = getTaskContext().getLinkContext().get();
options->code = task.code();
auto& party_info = options->party_info;
const auto& pb_party_info = task.party_access_info();
for (const auto& [_party_name, pb_node] : pb_party_info) {
if (_party_name == SCHEDULER_NODE) {
continue;
}
Node node_info;
pbNode2Node(pb_node, &node_info);
party_info[_party_name] = std::move(node_info);
}
if (RoleValidation::IsServer(this->party_name())) {
// paramater for offline generate db info
const auto& param_map = task.params().param_map();
auto iter = param_map.find("DbInfo");
if (iter != param_map.end()) {
options->db_path = iter->second.value_string();
LOG(INFO) << "db_file_cache path: " << options->db_path;
if (this->dataset_id_.empty()) {
LOG(ERROR) << "dataset id is empty for party: " << party_name();
return retcode::FAIL;
}
ValidateDir(options->db_path);
options->generate_db = true;
} else {
// paramater for online task
if (this->dataset_id_.empty()) {
LOG(ERROR) << "dataset id is empty for party: " << party_name();
return retcode::FAIL;
}
// check db cache exist or not
options->db_path = db_cache_dir_ + "/" + this->dataset_id_;
if (DbCacheAvailable(options->db_path)) {
options->use_cache = true;
}
}
}
// peer node info
std::string peer_party_name;
if (RoleValidation::IsServer(this->party_name())) {
peer_party_name = PARTY_CLIENT;
} else if (RoleValidation::IsClient(this->party_name())) {
peer_party_name = PARTY_SERVER;
} else {
LOG(ERROR) << "invalid party: " << this->party_name();
}
auto it = party_info.find(peer_party_name);
if (it != party_info.end()) {
options->peer_node = it->second;
} else {
LOG(WARNING) << "find peer node info failed for party: " << party_name();
}
// end of build Options
return retcode::SUCCESS;
}
retcode PirTask::LoadParams(const rpc::Task& task) {
const auto& param_map = task.params().param_map();
auto iter = param_map.find("pirType");
if (iter != param_map.end()) {
pir_type_ = iter->second.value_int32();
}
const auto& party_datasets = task.party_datasets();
auto dataset_it = party_datasets.find(party_name());
if (dataset_it != party_datasets.end()) {
const auto& datasets_map = dataset_it->second.data();
auto it = datasets_map.find(party_name());
if (it == datasets_map.end()) {
LOG(WARNING) << "no datasets is set for party: " << party_name();
} else {
dataset_id_ = it->second;
VLOG(5) << "data set id: " << dataset_id_;
}
}
auto ret = BuildOptions(task, &this->options_);
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "build operator options for party: "
<< party_name() << " failed";
return retcode::FAIL;
}
if (RoleValidation::IsClient(this->party_name())) {
VLOG(7) << "dataset_id: " << dataset_id_;
auto it = param_map.find("outputFullFilename");
if (it != param_map.end()) {
result_file_path_ = it->second.value_string();
VLOG(5) << "result_file_path_: " << result_file_path_;
} else {
LOG(ERROR) << "no keyword outputFullFilename match";
return retcode::FAIL;
}
GetServerDataSetSchema(task);
}
return retcode::SUCCESS;
}
retcode PirTask::GetServerDataSetSchema(const rpc::Task& task) {
// get server dataset id
const auto& party_datasets = task.party_datasets();
auto it = party_datasets.find(PARTY_SERVER);
if (it == party_datasets.end()) {
LOG(WARNING) << "no dataset found for party_name: " << PARTY_SERVER;
return retcode::FAIL;
}
const auto& datasets_map = it->second.data();
auto iter = datasets_map.find(PARTY_SERVER);
if (iter == datasets_map.end()) {
LOG(WARNING) << "no dataset found for party_name: " << PARTY_SERVER;
return retcode::FAIL;
}
auto& server_dataset_id = iter->second;
auto& dataset_service = this->getDatasetService();
auto driver = dataset_service->getDriver(server_dataset_id);
if (driver == nullptr) {
LOG(WARNING) << "no dataset access info found for id: "
<< server_dataset_id;
return retcode::FAIL;
}
auto& access_info = driver->dataSetAccessInfo();
if (access_info == nullptr) {
LOG(WARNING) << "no dataset access info found for id: "
<< server_dataset_id;
return retcode::FAIL;
}
auto& schema = access_info->Schema();
for (const auto& field : schema) {
server_dataset_schema_.push_back(std::get<0>(field));
}
return retcode::SUCCESS;
}
retcode PirTask::LoadDataset() {
CHECK_TASK_STOPPED(retcode::FAIL);
if (RoleValidation::IsClient(this->party_name())) {
return ClientLoadDataset();
} else if (RoleValidation::IsServer(this->party_name())) {
return ServerLoadDataset();
} else {
LOG(WARNING) << "party: " << this->party_name()
<< " does not load dataset";
return retcode::SUCCESS;
}
}
retcode PirTask::ClientLoadDataset() {
const auto& param_map = getTaskParam()->params().param_map();
auto client_data_it = param_map.find("clientData");
if (client_data_it != param_map.end()) {
auto& client_data = client_data_it->second;
if (client_data.is_array()) {
const auto& items = client_data.value_string_array().value_string_array();
for (const auto& item : items) {
elements_[item];
}
if (elements_.empty()) {
LOG(ERROR) << "no query data set by client";
return retcode::FAIL;
}
} else {
auto item = client_data.value_string();
elements_[item];
}
return retcode::SUCCESS;
}
if (this->dataset_id_.empty()) {
LOG(ERROR) << "no dataset found for client: " << party_name();
return retcode::FAIL;
}
VLOG(7) << "dataset_id: " << this->dataset_id_;
auto data_ptr = LoadDataSetInternal(this->dataset_id_);
if (data_ptr == nullptr) {
LOG(ERROR) << "read data for dataset id: "
<< this->dataset_id_ << " failed";
return retcode::FAIL;
}
auto& table = std::get<std::shared_ptr<arrow::Table>>(data_ptr->data);
std::vector<int> key_col = {0};
auto key_array = GetSelectedContent(table, key_col);
for (auto& item : key_array) {
elements_[item];
}
return retcode::SUCCESS;
}
retcode PirTask::ServerLoadDataset() {
if (this->options_.use_cache) {
VLOG(0) << "using cache data for party: " << party_name();
return retcode::SUCCESS;
}
auto data_ptr = LoadDataSetInternal(this->dataset_id_);
if (data_ptr == nullptr) {
LOG(ERROR) << "read data for dataset id: "
<< this->dataset_id_ << " failed";
return retcode::FAIL;
}
auto& table = std::get<std::shared_ptr<arrow::Table>>(data_ptr->data);
int col_count = table->num_columns();
size_t row_count = table->num_rows();
if (col_count < 2) {
LOG(ERROR) << "data for server must have lable";
return retcode::FAIL;
}
std::vector<int> key_col = {0};
auto key_array = GetSelectedContent(table, key_col);
// get label
std::vector<int> value_col;
for (int i = 1; i < col_count; i++) {
value_col.push_back(i);
}
if (value_col.empty()) {
LOG(ERROR) << "no selected colum for lable";
return retcode::FAIL;
}
auto value_array = GetSelectedContent(table, value_col);
elements_.reserve(key_array.size());
for (size_t i = 0; i < key_array.size(); ++i) {
auto& key = key_array[i];
auto& value = value_array[i];
auto it = elements_.find(key);
if (it != elements_.end()) {
it->second.push_back(value);
} else {
std::vector<std::string> vec;
vec.push_back(value);
elements_.insert({key, std::move(vec)});
}
}
return retcode::SUCCESS;
}
std::shared_ptr<Dataset> PirTask::LoadDataSetInternal(
const std::string& dataset_id) {
auto driver = this->getDatasetService()->getDriver(dataset_id);
if (driver == nullptr) {
LOG(ERROR) << "get driver for dataset: " << dataset_id << " failed";
return nullptr;
}
auto cursor = driver->GetCursor();
if (cursor == nullptr) {
LOG(ERROR) << "init cursor failed for dataset id: " << dataset_id;
return nullptr;
}
// maybe pass schema to get expected data type
// copy dataset schema, and change all filed to string
auto schema = driver->dataSetAccessInfo()->Schema();
for (auto& field : schema) {
auto& type = std::get<1>(field);
type = arrow::Type::type::STRING;
}
auto data = cursor->read(schema);
if (data == nullptr) {
LOG(ERROR) << "read data failed for dataset id: " << dataset_id;
return nullptr;
}
return data;
}
retcode PirTask::SaveResult() {
if (!NeedSaveResult()) {
return retcode::SUCCESS;
}
VLOG(0) << "save query result to : " << result_file_path_;
std::vector<std::shared_ptr<arrow::Field>> schema_vector;
std::vector<std::string> tmp_colums{"key", "value"};
for (const auto& col_name : tmp_colums) {
schema_vector.push_back(arrow::field(col_name, arrow::int64()));
}
std::vector<std::shared_ptr<arrow::Array>> arrow_array;
arrow::StringBuilder key_builder;
arrow::StringBuilder value_builder;
for (auto& [key, item_vec] : this->result_) {
for (const auto& item : item_vec) {
key_builder.Append(key);
value_builder.Append(item);
}
}
std::shared_ptr<arrow::Array> key_array;
key_builder.Finish(&key_array);
arrow_array.push_back(std::move(key_array));
std::shared_ptr<arrow::Array> value_array;
value_builder.Finish(&value_array);
arrow_array.push_back(std::move(value_array));
auto schema = std::make_shared<arrow::Schema>(schema_vector);
// std::shared_ptr<arrow::Table>
auto table = arrow::Table::Make(schema, arrow_array);
auto driver = DataDirverFactory::getDriver("CSV", "test address");
auto csv_driver = std::dynamic_pointer_cast<CSVDriver>(driver);
auto rtcode = csv_driver->Write(server_dataset_schema_, table, result_file_path_);
if (rtcode != retcode::SUCCESS) {
LOG(ERROR) << "save PIR data to file " << result_file_path_ << " failed.";
return retcode::FAIL;
}
return retcode::SUCCESS;
}
retcode PirTask::InitOperator() {
auto type = static_cast<primihub::pir::PirType>(pir_type_);
this->operator_ = primihub::pir::Factory::Create(type, options_);
if (this->operator_ == nullptr) {
LOG(ERROR) << "create pir operator failed";
return retcode::FAIL;
}
return retcode::SUCCESS;
}
retcode PirTask::ExecuteOperator() {
return operator_->Execute(elements_, &result_);
}
int PirTask::execute() {
SCopedTimer timer;
auto ret = LoadParams(task_param_);
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "Pir load task params failed.";
return -1;
}
auto load_params_ts = timer.timeElapse();
VLOG(5) << "LoadParams time cost(ms): " << load_params_ts;
ret = LoadDataset();
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "Pir load dataset failed.";
return -1;
}
auto load_dataset_ts = timer.timeElapse();
auto load_dataset_time_cost = load_dataset_ts - load_params_ts;
VLOG(5) << "LoadDataset time cost(ms): " << load_dataset_time_cost;
ret = InitOperator();
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "Pir init operator failed.";
return -1;
}
auto init_op_ts = timer.timeElapse();
auto init_op_time_cost = init_op_ts - load_dataset_ts;
VLOG(5) << "InitOperator time cost(ms): " << init_op_time_cost;
ret = ExecuteOperator();
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "Pir execute operator failed.";
return -1;
}
auto exec_op_ts = timer.timeElapse();
auto exec_op_time_cost = exec_op_ts - init_op_ts;
VLOG(5) << "ExecuteOperator time cost(ms): " << exec_op_time_cost;
ret = SaveResult();
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "Pir save result failed.";
return -1;
}
auto save_res_ts = timer.timeElapse();
auto save_res_time_cost = save_res_ts - exec_op_ts;
VLOG(5) << "SaveResult time cost(ms): " << save_res_time_cost;
return 0;
}
std::vector<std::string> PirTask::GetSelectedContent(
std::shared_ptr<arrow::Table>& data_tbl,
const std::vector<int>& selected_col) {
// return std::vector<std::string>();
int col_count = data_tbl->num_columns();
size_t row_count = data_tbl->num_rows();
if (selected_col.empty()) {
LOG(ERROR) << "no col selected for data";
return std::vector<std::string>();
}
std::vector<std::string> content_array;
auto lable_ptr = data_tbl->column(selected_col[0]);
auto chunk_size = lable_ptr->num_chunks();
size_t total_row_count = col_count * chunk_size;
content_array.reserve(total_row_count);
for (int i = 0; i < chunk_size; ++i) {
auto array = std::static_pointer_cast<arrow::StringArray>(lable_ptr->chunk(i));
for (int64_t j = 0; j < array->length(); j++) {
content_array.push_back(array->GetString(j));
}
}
// process left colums
for (size_t i = 1; i < selected_col.size(); ++i) {
size_t index{0};
int col_index = selected_col[i];
auto lable_ptr = data_tbl->column(col_index);
int chunk_size = lable_ptr->num_chunks();
for (int j = 0; j < chunk_size; ++j) {
auto array = std::static_pointer_cast<arrow::StringArray>(lable_ptr->chunk(j));
for (int64_t k = 0; k < array->length(); ++k) {
content_array[index++].append(",").append(array->GetString(k));
}
}
}
return content_array;
}
bool PirTask::NeedSaveResult() {
if (RoleValidation::IsClient(this->party_name())) {
return true;
}
return false;
}
} // namespace primihub::task
// Copyright 2023 <PrimiHub>
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_PIR_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_PIR_TASK_H_
#include <string>
#include "src/primihub/task/semantic/task.h"
#include "src/primihub/common/common.h"
#include "src/primihub/service/dataset/service.h"
#include "src/primihub/kernel/pir/common.h"
#include "src/primihub/kernel/pir/operator/base_pir.h"
#include "src/primihub/util/util.h"
#include "src/primihub/util/file_util.h"
namespace primihub::task {
using BasePirOperator = primihub::pir::BasePirOperator;
class PirTask : public TaskBase {
public:
PirTask(const TaskParam *task_param,
std::shared_ptr<DatasetService> dataset_service);
~PirTask() = default;
int execute() override;
protected:
retcode LoadParams(const rpc::Task& task);
retcode GetServerDataSetSchema(const rpc::Task& task);
retcode LoadDataset();
retcode ClientLoadDataset();
retcode ServerLoadDataset();
std::shared_ptr<Dataset> LoadDataSetInternal(const std::string& dataset_id);
bool DbCacheAvailable(const std::string& db_file_cache) {
return FileExists(db_file_cache);
}
std::vector<std::string> GetSelectedContent(
std::shared_ptr<arrow::Table>& data_tbl,
const std::vector<int>& selected_col);
retcode SaveResult();
retcode InitOperator();
retcode ExecuteOperator();
retcode BuildOptions(const rpc::Task& task,
primihub::pir::Options* option);
bool NeedSaveResult();
private:
int pir_type_{rpc::PirType::KEY_PIR};
std::string dataset_path_;
std::string dataset_id_;
std::string result_file_path_;
primihub::pir::PirDataType elements_;
primihub::pir::PirDataType result_;
primihub::pir::Options options_;
std::string db_cache_dir_{"data/cache"};
std::unique_ptr<BasePirOperator> operator_{nullptr};
std::vector<std::string> server_dataset_schema_;
// std::string dataset_path_;
// std::string dataset_id_;
// std::string db_file_cache_;
// primihub::Node client_node_;
// std::string key{"key_pir"};
// std::string psi_params_str_;
// std::unique_ptr<apsi::oprf::OPRFKey> oprf_key_{nullptr};
// bool generate_db_offline_{false};
};
} // namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_PIR_TASK_H_
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "src/primihub/task/semantic/private_server_base.h"
#include "src/primihub/data_store/factory.h"
#include <fstream>
using arrow::Array;
using arrow::StringArray;
using arrow::Table;
namespace primihub::task {
ServerTaskBase::ServerTaskBase(const Params *params,
std::shared_ptr<DatasetService> dataset_service)
: dataset_service_(dataset_service) {
setTaskParam(params);
}
Params* ServerTaskBase::getTaskParam() {
return &params_;
}
void ServerTaskBase::setTaskParam(const Params *params) {
params_.CopyFrom(*params);
}
int ServerTaskBase::loadDatasetFromSQLite(const std::string& conn_str, int data_col,
std::vector<std::string>& col_array, int64_t max_num) {
std::string nodeaddr("localhost"); // TODO
// std::shared_ptr<DataDriver>
auto driver = DataDirverFactory::getDriver("SQLITE", nodeaddr);
if (driver == nullptr) {
LOG(ERROR) << "create sqlite db driver failed";
return -1;
}
// std::shared_ptr<Cursor> &cursor
auto cursor = driver->read(conn_str);
// std::shared_ptr<Dataset>
auto ds = cursor->read();
if (ds == nullptr) {
return -1;
}
auto table = std::get<std::shared_ptr<Table>>(ds->data);
int num_col = table->num_columns();
if (num_col < data_col) {
LOG(ERROR) << "psi dataset colunum number is smaller than data_col";
return -1;
}
auto array = std::static_pointer_cast<StringArray>(table->column(data_col)->chunk(0));
for (int64_t i = 0; i < array->length(); i++) {
if (max_num > 0 && max_num == i) {
break;
}
col_array.push_back(array->GetString(i));
}
VLOG(5) << "psi server loaded data records: " << col_array.size();
return array->length();
}
int ServerTaskBase::loadDatasetFromCSV(const std::string& filename, int data_col,
std::vector<std::string> &col_array,
int64_t max_num) {
std::string nodeaddr("test address"); // TODO
std::shared_ptr<DataDriver> driver =
DataDirverFactory::getDriver("CSV", nodeaddr);
auto cursor = driver->read(filename);
auto ds = cursor->read();
std::shared_ptr<Table> table = std::get<std::shared_ptr<Table>>(ds->data);
int num_col = table->num_columns();
if (num_col < data_col) {
LOG(ERROR) << "psi dataset colunum number is smaller than data_col";
return -1;
}
int64_t num_rows = table->num_rows();
int64_t num_records = max_num > 0 ? max_num : num_rows;
col_array.reserve(num_records);
auto col_ptr = table->column(data_col);
int chunk_size = col_ptr->num_chunks();
for (int i = 0; i < chunk_size; i++) {
auto array = std::static_pointer_cast<StringArray>(col_ptr->chunk(i));
for (size_t j = 0; j < array->length(); j++) {
col_array.push_back(array->GetString(j));
if (max_num > 0 && max_num == col_array.size()) {
return col_array.size();
}
}
}
return col_array.size();
}
int ServerTaskBase::loadDatasetFromTXT(std::string &filename,
std::vector <std::string> &col_array) {
LOG(INFO) << "loading file ...";
std::ifstream infile;
infile.open(filename);
col_array.clear();
std::string tmp;
std::getline(infile, tmp); // ignore the first line
while (std::getline(infile, tmp)) {
col_array.push_back(tmp);
}
infile.close();
return col_array.size();
}
} // namespace primihub::task
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_PRIVATE_SERVER_BASE_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_PRIVATE_SERVER_BASE_H_
#include <map>
#include <memory>
#include <string>
#include <atomic>
#include <glog/logging.h>
#include "src/primihub/protos/common.pb.h"
#include "src/primihub/service/dataset/service.h"
#include "src/primihub/task/semantic/task.h"
using primihub::rpc::Params;
using primihub::service::DatasetService;
namespace primihub::task {
class ServerTaskBase {
public:
// using task_context_t = TaskContext<primihub::rpc::ExecuteTaskRequest, primihub::rpc::ExecuteTaskResponse>;
//
using task_context_t = TaskContext;
ServerTaskBase(const Params *params,
std::shared_ptr<DatasetService> dataset_service);
~ServerTaskBase(){}
virtual int execute() = 0;
virtual int loadParams(Params & params) = 0;
virtual int loadDataset(void) = 0;
virtual void kill_task() {
LOG(WARNING) << "task receives kill task request and stop stauts";
stop_.store(true);
task_context_.clean();
}
bool has_stopped() {
return stop_.load(std::memory_order_relaxed);
}
std::shared_ptr<DatasetService>& getDatasetService() {
return dataset_service_;
}
void setTaskParam(const Params *params);
Params* getTaskParam();
inline task_context_t& getTaskContext() {
return task_context_;
}
inline task_context_t* getMutableTaskContext() {
return &task_context_;
}
protected:
int loadDatasetFromCSV(const std::string &filename, int data_col,
std::vector <std::string> &col_array, int64_t max_num = 0);
int loadDatasetFromSQLite(const std::string& conn_str, int data_col,
std::vector<std::string>& col_array, int64_t max_num = 0);
int loadDatasetFromTXT(std::string &filename,
std::vector <std::string> &col_array);
std::atomic<bool> stop_{false};
Params params_;
std::shared_ptr<DatasetService> dataset_service_;
task_context_t task_context_;
};
} // namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_PRIVATE_SERVER_BASE_H_
/*
Copyright 2022 Primihub
Copyright 2022 PrimiHub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
//
#include "src/primihub/util/network/link_context.h"
namespace primihub::network {
void LinkContext::Clean() {
stop_.store(true);
LOG(WARNING) << "stop all in data queue";
{
std::lock_guard<std::mutex> lck(in_queue_mtx);
for(auto it = in_data_queue.begin(); it != in_data_queue.end(); ++it) {
it->second.shutdown();
}
}
LOG(WARNING) << "stop all out data queue";
{
std::lock_guard<std::mutex> lck(out_queue_mtx);
for(auto it = out_data_queue.begin(); it != out_data_queue.end(); ++it) {
it->second.shutdown();
}
}
LOG(WARNING) << "stop all complete queue";
{
std::lock_guard<std::mutex> lck(complete_queue_mtx);
for(auto it = complete_queue.begin(); it != complete_queue.end(); ++it) {
it->second.shutdown();
}
}
}
LinkContext::StringDataQueue& LinkContext::GetRecvQueue(
const std::string& key) {
std::unique_lock<std::mutex> lck(this->in_queue_mtx);
auto it = in_data_queue.find(key);
if (it != in_data_queue.end()) {
return it->second;
} else {
in_data_queue[key];
if (stop_.load(std::memory_order::memory_order_relaxed)) {
in_data_queue[key].shutdown();
}
return in_data_queue[key];
}
}
LinkContext::StringDataQueue& LinkContext::GetSendQueue(
const std::string& key) {
std::unique_lock<std::mutex> lck(this->out_queue_mtx);
auto it = out_data_queue.find(key);
if (it != out_data_queue.end()) {
return it->second;
} else {
return out_data_queue[key];
}
}
LinkContext::StatusDataQueue& LinkContext::GetCompleteQueue(
const std::string& key) {
std::unique_lock<std::mutex> lck(this->complete_queue_mtx);
auto it = complete_queue.find(key);
if (it != complete_queue.end()) {
return it->second;
} else {
return complete_queue[key];
}
}
retcode LinkContext::Send(const std::string& key,
const Node& dest_node,
const std::string& send_buf) {
std::string_view send_data_sv{send_buf.data(), send_buf.size()};
return Send(key, dest_node, send_data_sv);
}
retcode LinkContext::Send(const std::string& key,
const Node& dest_node,
std::string_view send_buf_sv) {
auto ch = getChannel(dest_node);
return ch->send(key, send_buf_sv);
}
retcode LinkContext::Send(const std::string& key,
const Node& dest_node,
char* send_buf, size_t send_size) {
std::string_view send_data_sv{send_buf, send_size};
return Send(key, dest_node, send_data_sv);
}
retcode LinkContext::Recv(const std::string& key, std::string* recv_buf) {
std::string recv_buf_tmp;
auto& recv_queue = GetRecvQueue(key);
recv_queue.wait_and_pop(recv_buf_tmp);
*recv_buf = std::move(recv_buf_tmp);
return retcode::SUCCESS;
}
retcode LinkContext::Recv(const std::string& key,
char* recv_buf, size_t recv_size) {
std::string recv_buf_tmp;
auto& recv_queue = GetRecvQueue(key);
recv_queue.wait_and_pop(recv_buf_tmp);
if (recv_size != recv_buf_tmp.size()) {
LOG(ERROR) << "recv data does not match, expected: " << recv_size
<< " but get: " << recv_buf_tmp.size();
return retcode::FAIL;
}
memcpy(recv_buf, recv_buf_tmp.data(), recv_size);
return retcode::SUCCESS;
}
retcode LinkContext::SendRecv(const std::string& key,
const Node& dest_node,
std::string_view send_buf,
std::string* recv_buf) {
auto channel = getChannel(dest_node);
auto ret = channel->sendRecv(key, send_buf, recv_buf);
if (ret != retcode::SUCCESS) {
LOG(ERROR) << "send data to peer: [" << dest_node.to_string()
<< "] failed";
return ret;
}
return retcode::SUCCESS;
}
retcode LinkContext::SendRecv(const std::string& key,
const Node& dest_node,
const std::string& send_buf,
std::string* recv_buf) {
auto send_buf_sv = std::string_view(send_buf.data(), send_buf.size());
return SendRecv(key, dest_node, send_buf_sv, recv_buf);
}
retcode LinkContext::SendRecv(const std::string& key,
const Node& dest_node,
const char* send_buf,
size_t length,
std::string* recv_buf) {
auto send_buf_sv = std::string_view(send_buf, length);
return SendRecv(key, dest_node, send_buf_sv, recv_buf);
}
retcode LinkContext::SendRecv(const std::string& key,
const std::string& send_buf,
std::string* recv_buf) {
std::string recv_buf_tmp;
auto& recv_queue = this->GetRecvQueue(key);
recv_queue.wait_and_pop(recv_buf_tmp);
*recv_buf = std::move(recv_buf_tmp);
if (HasStopped()) {
LOG(ERROR) << "link context has been closed";
return retcode::FAIL;
}
auto& send_queue = this->GetSendQueue(key);
send_queue.push(send_buf);
auto& complete_queue = this->GetCompleteQueue(key);
retcode complete_flag;
complete_queue.wait_and_pop(complete_flag);
return retcode::SUCCESS;
}
} // namespace primihub::network
......@@ -20,6 +20,10 @@ class IChannel;
*/
class LinkContext {
public:
using StringDataQueue = primihub::ThreadSafeQueue<std::string>;
using StringDataContainer = std::unordered_map<std::string, StringDataQueue>;
using StatusDataQueue = primihub::ThreadSafeQueue<retcode>;
using StatusDataContainer = std::unordered_map<std::string, StatusDataQueue>;
LinkContext() = default;
virtual ~LinkContext() = default;
inline void setTaskInfo(const std::string& job_id,
......@@ -43,11 +47,17 @@ class LinkContext {
* if channel is not exist, create
*/
virtual std::shared_ptr<IChannel> getChannel(const primihub::Node& node) = 0;
void setRecvTimeout(int32_t recv_timeout_ms) {recv_timeout_ms_ = recv_timeout_ms;}
void setSendTimeout(int32_t send_timeout_ms) {send_timeout_ms_ = send_timeout_ms;}
inline void setRecvTimeout(const int32_t recv_timeout_ms) {
recv_timeout_ms_ = recv_timeout_ms;
}
inline void setSendTimeout(const int32_t send_timeout_ms) {
send_timeout_ms_ = send_timeout_ms;
}
int32_t sendTimeout() const {return send_timeout_ms_;}
int32_t recvTimeout() const {return recv_timeout_ms_;}
primihub::common::CertificateConfig& getCertificateConfig() {return *cert_config_;}
primihub::common::CertificateConfig& getCertificateConfig() {
return *cert_config_;
}
void initCertificate(const std::string& root_ca_path,
const std::string& key_path,
......@@ -62,69 +72,46 @@ class LinkContext {
return retcode::SUCCESS;
}
primihub::ThreadSafeQueue<std::string>&
GetRecvQueue(const std::string& key = "default") {
std::unique_lock<std::mutex> lck(this->in_queue_mtx);
auto it = in_data_queue.find(key);
if (it != in_data_queue.end()) {
return it->second;
} else {
in_data_queue[key];
if (stop_.load(std::memory_order::memory_order_relaxed)) {
in_data_queue[key].shutdown();
}
return in_data_queue[key];
}
}
primihub::ThreadSafeQueue<std::string>&
GetSendQueue(const std::string& key = "default") {
std::unique_lock<std::mutex> lck(this->out_queue_mtx);
auto it = out_data_queue.find(key);
if (it != out_data_queue.end()) {
return it->second;
} else {
return out_data_queue[key];
}
}
StringDataQueue& GetRecvQueue(const std::string& key = "default");
StringDataQueue& GetSendQueue(const std::string& key = "default");
StatusDataQueue& GetCompleteQueue(const std::string& role = "default");
primihub::ThreadSafeQueue<retcode>&
GetCompleteQueue(const std::string& role = "default") {
std::unique_lock<std::mutex> lck(this->complete_queue_mtx);
auto it = complete_queue.find(role);
if (it != complete_queue.end()) {
return it->second;
} else {
return complete_queue[role];
}
}
void Clean();
retcode Send(const std::string& key,
const Node& dest_node, const std::string& send_buf);
retcode Send(const std::string& key,
const Node& dest_node, std::string_view send_buf);
retcode Send(const std::string& key,
const Node& dest_node, char* send_buf, size_t send_size);
retcode Recv(const std::string& key, std::string* recv_buf);
retcode Recv(const std::string& key, char* recv_buf, size_t recv_size);
/**
* sender to process send recv
*/
retcode SendRecv(const std::string& key,
const Node& dest_node,
const std::string& send_buf,
std::string* recv_buf);
retcode SendRecv(const std::string& key,
const Node& dest_node,
std::string_view send_buf,
std::string* recv_buf);
retcode SendRecv(const std::string& key,
const Node& dest_node,
const char* send_buf, size_t length,
std::string* recv_buf);
/**
* receiver to process send recv
*/
retcode SendRecv(const std::string& key,
const std::string& send_buf,
std::string* recv_buf);
void Clean() {
stop_.store(true);
LOG(ERROR) << "stop all in data queue";
{
std::lock_guard<std::mutex> lck(in_queue_mtx);
for(auto it = in_data_queue.begin(); it != in_data_queue.end(); ++it) {
it->second.shutdown();
}
}
LOG(ERROR) << "stop all out data queue";
{
std::lock_guard<std::mutex> lck(out_queue_mtx);
for(auto it = out_data_queue.begin(); it != out_data_queue.end(); ++it) {
it->second.shutdown();
}
}
LOG(ERROR) << "stop all complete queue";
{
std::lock_guard<std::mutex> lck(complete_queue_mtx);
for(auto it = complete_queue.begin(); it != complete_queue.end(); ++it) {
it->second.shutdown();
}
}
}
protected:
bool HasStopped() {
return stop_.load(std::memory_order::memory_order_relaxed);
}
int32_t recv_timeout_ms_{-1};
int32_t send_timeout_ms_{-1};
std::shared_mutex connection_mgr_mtx;
......@@ -135,11 +122,13 @@ class LinkContext {
std::unique_ptr<primihub::common::CertificateConfig> cert_config_{nullptr};
std::mutex in_queue_mtx;
std::unordered_map<std::string, primihub::ThreadSafeQueue<std::string>> in_data_queue;
StringDataContainer in_data_queue;
std::mutex out_queue_mtx;
std::unordered_map<std::string, primihub::ThreadSafeQueue<std::string>> out_data_queue;
StringDataContainer out_data_queue;
std::mutex complete_queue_mtx;
std::unordered_map<std::string, primihub::ThreadSafeQueue<retcode>> complete_queue;
StatusDataContainer complete_queue;
std::atomic<bool> stop_{false};
};
......@@ -150,19 +139,26 @@ class IChannel {
virtual ~IChannel() = default;
virtual retcode send(const std::string& key, const std::string& data) = 0;
virtual retcode send(const std::string& key, std::string_view sv_data) = 0;
virtual bool send_wrapper(const std::string& key, const std::string& data) = 0;
virtual bool send_wrapper(const std::string& key, std::string_view sv_data) = 0;
virtual bool send_wrapper(const std::string& key,
const std::string& data) = 0;
virtual bool send_wrapper(const std::string& key,
std::string_view sv_data) = 0;
virtual retcode sendRecv(const std::string& key,
const std::string& send_data,
std::string* recv_data) = 0;
virtual retcode sendRecv(const std::string& key,
std::string_view send_data,
std::string* recv_data) = 0;
virtual retcode submitTask(const rpc::PushTaskRequest& request, rpc::PushTaskReply* reply) = 0;
virtual retcode executeTask(const rpc::PushTaskRequest& request, rpc::PushTaskReply* reply) = 0;
virtual retcode killTask(const rpc::KillTaskRequest& request, rpc::KillTaskResponse* reply) = 0;
virtual retcode updateTaskStatus(const rpc::TaskStatus& request, rpc::Empty* reply) = 0;
virtual retcode fetchTaskStatus(const rpc::TaskContext& request, rpc::TaskStatusReply* reply) = 0;
virtual retcode submitTask(const rpc::PushTaskRequest& request,
rpc::PushTaskReply* reply) = 0;
virtual retcode executeTask(const rpc::PushTaskRequest& request,
rpc::PushTaskReply* reply) = 0;
virtual retcode killTask(const rpc::KillTaskRequest& request,
rpc::KillTaskResponse* reply) = 0;
virtual retcode updateTaskStatus(const rpc::TaskStatus& request,
rpc::Empty* reply) = 0;
virtual retcode fetchTaskStatus(const rpc::TaskContext& request,
rpc::TaskStatusReply* reply) = 0;
virtual std::string forwardRecv(const std::string& key) = 0;
LinkContext* getLinkContext() {
return link_ctx_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册