Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Primihub
PrimiHub
提交
47e65456
P
PrimiHub
项目概览
Primihub
/
PrimiHub
9 个月 前同步成功
通知
21
Star
1
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PrimiHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
47e65456
编写于
8月 16, 2023
作者:
PhoenixTree2013
提交者:
GitHub
8月 16, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pir refactor (#597)
上级
405e10f6
变更
30
展开全部
隐藏空白更改
内联
并排
Showing
30 changed file
with
1817 addition
and
2503 deletion
+1817
-2503
src/primihub/common/common.h
src/primihub/common/common.h
+16
-0
src/primihub/kernel/pir/BUILD
src/primihub/kernel/pir/BUILD
+5
-0
src/primihub/kernel/pir/common.h
src/primihub/kernel/pir/common.h
+15
-0
src/primihub/kernel/pir/operator/BUILD
src/primihub/kernel/pir/operator/BUILD
+40
-0
src/primihub/kernel/pir/operator/base_pir.cc
src/primihub/kernel/pir/operator/base_pir.cc
+8
-0
src/primihub/kernel/pir/operator/base_pir.h
src/primihub/kernel/pir/operator/base_pir.h
+49
-0
src/primihub/kernel/pir/operator/factory.h
src/primihub/kernel/pir/operator/factory.h
+29
-0
src/primihub/kernel/pir/operator/id_pir.cc
src/primihub/kernel/pir/operator/id_pir.cc
+2
-0
src/primihub/kernel/pir/operator/id_pir.h
src/primihub/kernel/pir/operator/id_pir.h
+5
-0
src/primihub/kernel/pir/operator/keyword_pir.cc
src/primihub/kernel/pir/operator/keyword_pir.cc
+784
-0
src/primihub/kernel/pir/operator/keyword_pir.h
src/primihub/kernel/pir/operator/keyword_pir.h
+142
-0
src/primihub/node/worker/worker.cc
src/primihub/node/worker/worker.cc
+0
-4
src/primihub/node/worker/worker.h
src/primihub/node/worker/worker.h
+0
-5
src/primihub/task/semantic/BUILD
src/primihub/task/semantic/BUILD
+8
-60
src/primihub/task/semantic/factory.h
src/primihub/task/semantic/factory.h
+2
-64
src/primihub/task/semantic/keyword_pir_client_task.cc
src/primihub/task/semantic/keyword_pir_client_task.cc
+0
-414
src/primihub/task/semantic/keyword_pir_client_task.h
src/primihub/task/semantic/keyword_pir_client_task.h
+0
-95
src/primihub/task/semantic/keyword_pir_server_task.cc
src/primihub/task/semantic/keyword_pir_server_task.cc
+0
-777
src/primihub/task/semantic/keyword_pir_server_task.h
src/primihub/task/semantic/keyword_pir_server_task.h
+0
-124
src/primihub/task/semantic/pir_client_task.cc
src/primihub/task/semantic/pir_client_task.cc
+0
-298
src/primihub/task/semantic/pir_client_task.h
src/primihub/task/semantic/pir_client_task.h
+0
-99
src/primihub/task/semantic/pir_server_task.cc
src/primihub/task/semantic/pir_server_task.cc
+0
-202
src/primihub/task/semantic/pir_server_task.h
src/primihub/task/semantic/pir_server_task.h
+0
-84
src/primihub/task/semantic/pir_task.cc
src/primihub/task/semantic/pir_task.cc
+418
-0
src/primihub/task/semantic/pir_task.h
src/primihub/task/semantic/pir_task.h
+66
-0
src/primihub/task/semantic/private_server_base.cc
src/primihub/task/semantic/private_server_base.cc
+0
-125
src/primihub/task/semantic/private_server_base.h
src/primihub/task/semantic/private_server_base.h
+0
-79
src/primihub/task/semantic/task.h
src/primihub/task/semantic/task.h
+1
-1
src/primihub/util/network/link_context.cc
src/primihub/util/network/link_context.cc
+159
-0
src/primihub/util/network/link_context.h
src/primihub/util/network/link_context.h
+68
-72
未找到文件。
src/primihub/common/common.h
浏览文件 @
47e65456
...
...
@@ -59,6 +59,22 @@ enum class Visibility {
PRIVATE
=
1
,
};
class
RoleValidation
{
public:
static
bool
IsClient
(
const
std
::
string
&
party_name
)
{
return
party_name
==
PARTY_CLIENT
;
}
static
bool
IsServer
(
const
std
::
string
&
party_name
)
{
return
party_name
==
PARTY_SERVER
;
}
static
bool
IsTeeCompute
(
const
std
::
string
&
party_name
)
{
return
party_name
==
PARTY_TEE_COMPUTE
;
}
};
struct
Node
{
Node
()
=
default
;
Node
(
const
std
::
string
&
id
,
const
std
::
string
&
ip
,
...
...
src/primihub/kernel/pir/BUILD
0 → 100644
浏览文件 @
47e65456
package
(
default_visibility
=
[
"//visibility:public"
])
cc_library
(
name
=
"common_def"
,
hdrs
=
[
"common.h"
],
)
\ No newline at end of file
src/primihub/kernel/pir/common.h
0 → 100644
浏览文件 @
47e65456
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_COMMON_H_
#define SRC_PRIMIHUB_KERNEL_PIR_COMMON_H_
#include <unordered_map>
#include <vector>
#include <string>
namespace
primihub
::
pir
{
using
PirDataType
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
enum
class
PirType
{
ID_PIR
=
0
,
KEY_PIR
,
};
}
// namespace primihub::pir
#endif // SRC_PRIMIHUB_KERNEL_PIR_COMMON_H_
src/primihub/kernel/pir/operator/BUILD
0 → 100644
浏览文件 @
47e65456
package
(
default_visibility
=
[
"//visibility:public"
])
cc_library
(
name
=
"factory"
,
hdrs
=
[
"factory.h"
],
deps
=
[
"//src/primihub/kernel/pir:common_def"
,
":base_pir_operator"
,
":keyword_pir_operator"
,
]
)
cc_library
(
name
=
"base_pir_operator"
,
hdrs
=
[
"base_pir.h"
],
srcs
=
[
"base_pir.cc"
],
deps
=
[
"//src/primihub/kernel/pir:common_def"
,
"//src/primihub/util:endian_util"
,
"//src/primihub/util:util_lib"
,
"//src/primihub/common:common_defination"
,
"//src/primihub/util/network:communication_lib"
,
],
)
cc_library
(
name
=
"keyword_pir_operator"
,
hdrs
=
[
"keyword_pir.h"
],
srcs
=
[
"keyword_pir.cc"
],
copts
=
[
"-w"
,
"-D_ASPI"
,
],
deps
=
[
":base_pir_operator"
,
"//src/primihub/util:endian_util"
,
"//src/primihub/util:util_lib"
,
"//src/primihub/protos:worker_proto"
,
"@mircrosoft_apsi//:APSI"
,
]
)
src/primihub/kernel/pir/operator/base_pir.cc
0 → 100644
浏览文件 @
47e65456
// "Copyright [2023] <PrimiHub>"
#include "src/primihub/kernel/pir/operator/base_pir.h"
namespace
primihub
::
pir
{
retcode
BasePirOperator
::
Execute
(
const
PirDataType
&
input
,
PirDataType
*
result
)
{
return
OnExecute
(
input
,
result
);
}
}
// namespace primihub::pir
src/primihub/kernel/pir/operator/base_pir.h
0 → 100644
浏览文件 @
47e65456
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_BASE_PIR_H_
#define SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_BASE_PIR_H_
#include <map>
#include <string>
#include "src/primihub/common/common.h"
#include "src/primihub/kernel/pir/common.h"
#include "src/primihub/util/network/link_context.h"
namespace
primihub
::
pir
{
using
LinkContext
=
network
::
LinkContext
;
struct
Options
{
LinkContext
*
link_ctx_ref
;
std
::
map
<
std
::
string
,
Node
>
party_info
;
std
::
string
self_party
;
std
::
string
code
;
// online
bool
use_cache
{
false
};
// offline task
bool
generate_db
{
false
};
std
::
string
db_path
;
Node
peer_node
;
};
class
BasePirOperator
{
public:
explicit
BasePirOperator
(
const
Options
&
options
)
:
options_
(
options
)
{}
virtual
~
BasePirOperator
()
=
default
;
/**
* PSI protocol
*/
retcode
Execute
(
const
PirDataType
&
input
,
PirDataType
*
result
);
virtual
retcode
OnExecute
(
const
PirDataType
&
input
,
PirDataType
*
result
)
=
0
;
void
set_stop
()
{
stop_
.
store
(
true
);}
protected:
bool
has_stopped
()
{
return
stop_
.
load
(
std
::
memory_order
::
memory_order_relaxed
);
}
std
::
string
PartyName
()
{
return
options_
.
self_party
;}
LinkContext
*
GetLinkContext
()
{
return
options_
.
link_ctx_ref
;}
Node
&
PeerNode
()
{
return
options_
.
peer_node
;}
protected:
std
::
atomic
<
bool
>
stop_
{
false
};
Options
options_
;
std
::
string
key_
{
"pir_key"
};
};
}
// namespace primihub::pir
#endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_BASE_PIR_H_
src/primihub/kernel/pir/operator/factory.h
0 → 100644
浏览文件 @
47e65456
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_FACTORY_H_
#define SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_FACTORY_H_
#include <glog/logging.h>
#include <memory>
#include "src/primihub/kernel/pir/common.h"
#include "src/primihub/kernel/pir/operator/keyword_pir.h"
namespace
primihub
::
pir
{
class
Factory
{
public:
static
std
::
unique_ptr
<
BasePirOperator
>
Create
(
PirType
pir_type
,
const
Options
&
options
)
{
std
::
unique_ptr
<
BasePirOperator
>
operator_ptr
{
nullptr
};
switch
(
pir_type
)
{
case
PirType
::
ID_PIR
:
LOG
(
ERROR
)
<<
"Unimplement"
;
break
;
case
PirType
::
KEY_PIR
:
operator_ptr
=
std
::
make_unique
<
KeywordPirOperator
>
(
options
);
break
;
default:
LOG
(
ERROR
)
<<
"unknown pir operator: "
<<
static_cast
<
int
>
(
pir_type
);
break
;
}
return
operator_ptr
;
}
};
}
// namespace primihub::pir
#endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_FACTORY_H_
src/primihub/kernel/pir/operator/id_pir.cc
0 → 100644
浏览文件 @
47e65456
// "Copyright [2023] <PrimiHub>"
#include "src/primihub/kernel/pir/operator/id_pir.h"
src/primihub/kernel/pir/operator/id_pir.h
0 → 100644
浏览文件 @
47e65456
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_ID_PIR_H_
#define SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_ID_PIR_H_
#endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_ID_PIR_H_
src/primihub/kernel/pir/operator/keyword_pir.cc
0 → 100644
浏览文件 @
47e65456
此差异已折叠。
点击以展开。
src/primihub/kernel/pir/operator/keyword_pir.h
0 → 100644
浏览文件 @
47e65456
// "Copyright [2023] <PrimiHub>"
#ifndef SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_KEYWORD_PIR_H_
#define SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_KEYWORD_PIR_H_
#include <variant>
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include "src/primihub/kernel/pir/operator/base_pir.h"
#include "src/primihub/kernel/pir/common.h"
// APSI
#include "apsi/thread_pool_mgr.h"
#include "apsi/sender_db.h"
#include "apsi/oprf/oprf_sender.h"
#include "apsi/powers.h"
#include "apsi/util/common_utils.h"
#include "apsi/sender.h"
#include "apsi/bin_bundle.h"
#include "apsi/item.h"
#include "apsi/receiver.h"
// SEAL
#include "seal/context.h"
#include "seal/modulus.h"
#include "seal/util/common.h"
#include "seal/util/defines.h"
using
namespace
apsi
;
// NOLINT
using
namespace
apsi
::
sender
;
// NOLINT
using
namespace
apsi
::
oprf
;
// NOLINT
using
namespace
apsi
::
network
;
// NOLINT
using
namespace
seal
;
// NOLINT
using
namespace
seal
::
util
;
// NOLINT
namespace
primihub
::
pir
{
using
UnlabeledData
=
std
::
vector
<
apsi
::
Item
>
;
using
LabeledData
=
std
::
vector
<
std
::
pair
<
apsi
::
Item
,
apsi
::
Label
>>
;
using
DBData
=
std
::
variant
<
UnlabeledData
,
LabeledData
>
;
class
KeywordPirOperator
:
public
BasePirOperator
{
public:
enum
class
RequestType
:
uint8_t
{
PsiParam
=
0
,
Oprf
,
Query
,
};
explicit
KeywordPirOperator
(
const
Options
&
options
)
:
BasePirOperator
(
options
)
{}
retcode
OnExecute
(
const
PirDataType
&
input
,
PirDataType
*
result
)
override
;
protected:
retcode
ExecuteAsClient
(
const
PirDataType
&
input
,
PirDataType
*
result
);
retcode
ExecuteAsServer
(
const
PirDataType
&
input
);
protected:
// ------------------------Receiver----------------------------
/**
* Performs a parameter request from sender
*/
retcode
RequestPSIParams
();
/**
* Performs an OPRF request on a vector of items and returns a
vector of OPRF hashed items of the same size as the input vector.
*/
retcode
RequestOprf
(
const
std
::
vector
<
Item
>&
items
,
std
::
vector
<
apsi
::
HashedItem
>*
,
std
::
vector
<
apsi
::
LabelKey
>*
);
/**
* Performs a labeled PSI query. The query is a vector of items,
* and the result is a same-size vector of MatchRecord objects.
* If an item is in the intersection,
* the corresponding MatchRecord indicates it in the `found` field,
* and the `label` field may contain the corresponding
* label if a sender's data included it.
*/
retcode
RequestQuery
();
retcode
ExtractResult
(
const
std
::
vector
<
std
::
string
>&
orig_vec
,
const
std
::
vector
<
apsi
::
receiver
::
MatchRecord
>&
query_result
,
PirDataType
*
result
);
protected:
// ------------------------Sender----------------------------
std
::
unique_ptr
<
apsi
::
PSIParams
>
SetPsiParams
();
/**
* process a Get Parameters request to the Sender.
*/
retcode
ProcessPSIParams
();
/**
process an OPRF query request to the Sender.
*/
retcode
ProcessOprf
();
/**
process a Query request to the Sender.
*/
retcode
ProcessQuery
(
std
::
shared_ptr
<
apsi
::
sender
::
SenderDB
>
sender_db
);
retcode
ComputePowers
(
const
shared_ptr
<
apsi
::
sender
::
SenderDB
>
&
sender_db
,
const
apsi
::
CryptoContext
&
crypto_context
,
std
::
vector
<
apsi
::
sender
::
CiphertextPowers
>
&
all_powers
,
const
apsi
::
PowersDag
&
pd
,
uint32_t
bundle_idx
,
seal
::
MemoryPoolHandle
&
pool
);
auto
ProcessBinBundleCache
(
const
shared_ptr
<
apsi
::
sender
::
SenderDB
>
&
sender_db
,
const
apsi
::
CryptoContext
&
crypto_context
,
reference_wrapper
<
const
apsi
::
sender
::
BinBundleCache
>
cache
,
std
::
vector
<
apsi
::
sender
::
CiphertextPowers
>
&
all_powers
,
uint32_t
bundle_idx
,
compr_mode_type
compr_mode
,
seal
::
MemoryPoolHandle
&
pool
)
->
std
::
unique_ptr
<
apsi
::
network
::
ResultPackage
>
;
std
::
unique_ptr
<
DBData
>
CreateDb
(
const
PirDataType
&
input
);
retcode
CreateDbDataCache
(
const
DBData
&
db_data
,
std
::
unique_ptr
<
apsi
::
PSIParams
>
psi_params
,
apsi
::
oprf
::
OPRFKey
&
oprf_key
,
size_t
nonce_byte_count
,
bool
compress
);
auto
CreateSenderDb
(
const
DBData
&
db_data
,
std
::
unique_ptr
<
PSIParams
>
psi_params
,
apsi
::
oprf
::
OPRFKey
&
oprf_key
,
size_t
nonce_byte_count
,
bool
compress
)
->
std
::
shared_ptr
<
SenderDB
>
;
bool
DbCacheAvailable
(
const
std
::
string
&
db_path
);
std
::
shared_ptr
<
apsi
::
sender
::
SenderDB
>
LoadDbFromCache
(
const
std
::
string
&
db_path
);
private:
std
::
string
psi_params_str_
;
std
::
unique_ptr
<
apsi
::
oprf
::
OPRFKey
>
oprf_key_
{
nullptr
};
std
::
unique_ptr
<
apsi
::
receiver
::
Receiver
>
receiver_
{
nullptr
};
std
::
unique_ptr
<
apsi
::
PSIParams
>
psi_params_
{
nullptr
};
};
}
// namespace primihub::pir
#endif // SRC_PRIMIHUB_KERNEL_PIR_OPERATOR_KEYWORD_PIR_H_
src/primihub/node/worker/worker.cc
浏览文件 @
47e65456
...
...
@@ -90,10 +90,6 @@ retcode Worker::execute(const PushTaskRequest *pushTaskRequest) {
void
Worker
::
kill_task
()
{
if
(
task_ptr
)
{
task_ptr
->
kill_task
();
return
;
}
if
(
task_server_ptr
)
{
task_server_ptr
->
kill_task
();
}
}
...
...
src/primihub/node/worker/worker.h
浏览文件 @
47e65456
...
...
@@ -39,7 +39,6 @@
#include "src/primihub/node/nodelet.h"
#include "src/primihub/protos/worker.pb.h"
#include "src/primihub/task/semantic/task.h"
#include "src/primihub/task/semantic/private_server_base.h"
#include "src/primihub/common/common.h"
using
primihub
::
rpc
::
PushTaskRequest
;
...
...
@@ -71,9 +70,6 @@ class Worker {
std
::
shared_ptr
<
primihub
::
task
::
TaskBase
>
getTask
()
{
return
task_ptr
;
}
std
::
shared_ptr
<
primihub
::
task
::
ServerTaskBase
>
getServerTask
()
{
return
task_server_ptr
;
}
retcode
waitForTaskReady
();
// scheduler method
...
...
@@ -89,7 +85,6 @@ class Worker {
mutable
absl
::
Mutex
worker_map_mutex_
;
std
::
shared_ptr
<
primihub
::
task
::
TaskBase
>
task_ptr
{
nullptr
};
std
::
shared_ptr
<
primihub
::
task
::
ServerTaskBase
>
task_server_ptr
{
nullptr
};
const
std
::
string
&
node_id
;
std
::
shared_ptr
<
Nodelet
>
nodelet
;
std
::
string
worker_id_
;
...
...
src/primihub/task/semantic/BUILD
浏览文件 @
47e65456
...
...
@@ -29,7 +29,6 @@ cc_library(
":pir_task"
,
":psi_task"
,
":tee_task"
,
":private_server_base"
,
],
)
cc_library
(
...
...
@@ -116,65 +115,14 @@ cc_library(
# pir task
cc_library
(
name
=
"pir_task"
,
deps
=
[
":keyword_pir_task"
,
],
# deps = select({
# "microsoft-apsi" : [":keyword_pir_task"],
# "//conditions:default": [":id_pir_task"],
# }),
)
cc_library
(
name
=
"keyword_pir_task"
,
hdrs
=
[
"keyword_pir_client_task.h"
,
"keyword_pir_server_task.h"
,
],
srcs
=
[
"keyword_pir_client_task.cc"
,
"keyword_pir_server_task.cc"
,
],
copts
=
[
"-w"
,
"-D_ASPI"
,
],
defines
=
[
"USE_MICROSOFT_APSI"
],
deps
=
[
":task_interface"
,
"//src/primihub/protos:common_proto"
,
"@mircrosoft_apsi//:APSI"
,
]
)
cc_library
(
name
=
"private_server_base"
,
hdrs
=
[
"private_server_base.h"
],
srcs
=
[
"private_server_base.cc"
],
deps
=
[
":task_interface"
,
"//src/primihub/protos:worker_proto"
,
"//src/primihub/common:common_defination"
,
"//src/primihub/data_store:data_store_lib"
,
"//src/primihub/service:dataset_service"
,
],
)
cc_library
(
name
=
"id_pir_task"
,
hdrs
=
[
"pir_client_task.h"
,
"pir_server_task.h"
,
],
srcs
=
[
"pir_client_task.cc"
,
"pir_server_task.cc"
,
],
deps
=
[
":private_server_base"
,
":task_interface"
,
"@org_openmined_pir//pir/cpp:pir"
,
],
name
=
"pir_task"
,
hdrs
=
[
"pir_task.h"
],
srcs
=
[
"pir_task.cc"
],
deps
=
[
":task_interface"
,
"//src/primihub/kernel/pir:common_def"
,
"//src/primihub/kernel/pir/operator:factory"
,
],
)
# task semantic parser
...
...
src/primihub/task/semantic/factory.h
浏览文件 @
47e65456
...
...
@@ -24,15 +24,7 @@
#include "src/primihub/task/semantic/mpc_task.h"
#include "src/primihub/task/semantic/fl_task.h"
#include "src/primihub/task/semantic/psi_task.h"
#include "src/primihub/task/semantic/private_server_base.h"
#ifndef USE_MICROSOFT_APSI
#include "src/primihub/task/semantic/pir_client_task.h"
#include "src/primihub/task/semantic/pir_server_task.h"
#else
#include "src/primihub/task/semantic/keyword_pir_client_task.h"
#include "src/primihub/task/semantic/keyword_pir_server_task.h"
#endif
#include "src/primihub/task/semantic/pir_task.h"
#include "src/primihub/task/semantic/tee_task.h"
#include "src/primihub/service/dataset/service.h"
...
...
@@ -42,8 +34,6 @@ using primihub::rpc::PushTaskRequest;
using
primihub
::
rpc
::
Language
;
using
primihub
::
rpc
::
TaskType
;
using
primihub
::
service
::
DatasetService
;
using
primihub
::
rpc
::
PsiTag
;
using
primihub
::
rpc
::
PirType
;
namespace
primihub
::
task
{
...
...
@@ -123,38 +113,7 @@ class TaskFactory {
const
PushTaskRequest
&
request
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
)
{
const
auto
&
task_config
=
request
.
task
();
const
auto
&
param_map
=
task_config
.
params
().
param_map
();
int
pir_type
=
PirType
::
ID_PIR
;
auto
param_it
=
param_map
.
find
(
"pirType"
);
if
(
param_it
!=
param_map
.
end
())
{
pir_type
=
param_it
->
second
.
value_int32
();
}
#ifndef USE_MICROSOFT_APSI
const
auto
&
job_id
=
request
.
task
().
task_info
().
job_id
();
const
auto
&
task_id
=
request
.
task
().
task_info
().
task_id
();
if
(
pir_type
==
PirType
::
ID_PIR
)
{
return
std
::
make_shared
<
PIRClientTask
>
(
node_id
,
job_id
,
task_id
,
&
task_config
,
dataset_service
);
}
else
{
// TODO, using condition compile, fix in future
LOG
(
WARNING
)
<<
"ID_PIR is not supported when MICROSOFT_APSI enabled"
;
return
nullptr
;
}
#else // KEYWORD PIR
if
(
pir_type
==
PirType
::
KEY_PIR
)
{
std
::
string
party_name
=
task_config
.
party_name
();
if
(
party_name
==
PARTY_SERVER
)
{
return
std
::
make_shared
<
KeywordPIRServerTask
>
(
&
task_config
,
dataset_service
);
}
else
{
return
std
::
make_shared
<
KeywordPIRClientTask
>
(
&
task_config
,
dataset_service
);
}
}
else
{
LOG
(
ERROR
)
<<
"Unsupported pir type: "
<<
pir_type
;
return
nullptr
;
}
#endif
return
std
::
make_shared
<
PirTask
>
(
&
task_config
,
dataset_service
);
}
static
std
::
shared_ptr
<
TaskBase
>
CreateTEETask
(
const
std
::
string
&
node_id
,
...
...
@@ -166,28 +125,7 @@ class TaskFactory {
dataset_service
);
}
static
std
::
shared_ptr
<
ServerTaskBase
>
Create
(
const
std
::
string
&
node_id
,
rpc
::
TaskType
task_type
,
const
ExecuteTaskRequest
&
request
,
ExecuteTaskResponse
*
response
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
)
{
if
(
task_type
==
rpc
::
TaskType
::
NODE_PIR_TASK
)
{
#ifdef USE_MICROSOFT_APSI
// TODO, using condition compile, fix in future
LOG
(
WARNING
)
<<
"ID_PIR is not supported when using MICROSOFT_APSI"
;
return
nullptr
;
#else
return
std
::
make_shared
<
PIRServerTask
>
(
node_id
,
request
,
response
,
dataset_service
);
#endif
}
else
{
LOG
(
ERROR
)
<<
"Unsupported task type at server node: "
<<
task_type
<<
"."
;
return
nullptr
;
}
}
};
}
// namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_FACTORY_H_
src/primihub/task/semantic/keyword_pir_client_task.cc
已删除
100644 → 0
浏览文件 @
405e10f6
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "src/primihub/task/semantic/keyword_pir_client_task.h"
#include <thread>
#include <chrono>
#include <sstream>
#include "src/primihub/util/util.h"
#include "apsi/item.h"
#include "apsi/util/common_utils.h"
#include "src/primihub/util/file_util.h"
#include "src/primihub/protos/worker.pb.h"
#include "seal/util/common.h"
using
namespace
apsi
;
using
namespace
apsi
::
network
;
namespace
primihub
::
task
{
KeywordPIRClientTask
::
KeywordPIRClientTask
(
const
TaskParam
*
task_param
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
)
:
TaskBase
(
task_param
,
dataset_service
)
{
}
retcode
KeywordPIRClientTask
::
_LoadParams
(
Task
&
task
)
{
CHECK_TASK_STOPPED
(
retcode
::
FAIL
);
std
::
string
party_name
=
task
.
party_name
();
const
auto
&
param_map
=
task
.
params
().
param_map
();
try
{
auto
client_data_it
=
param_map
.
find
(
"clientData"
);
if
(
client_data_it
!=
param_map
.
end
())
{
auto
&
client_data
=
client_data_it
->
second
;
if
(
client_data
.
is_array
())
{
recv_query_data_direct
=
true
;
// read query data from clientData key directly
const
auto
&
items
=
client_data
.
value_string_array
().
value_string_array
();
for
(
const
auto
&
item
:
items
)
{
recv_data_
.
push_back
(
item
);
}
}
else
{
dataset_path_
=
client_data
.
value_string
();
dataset_id_
=
client_data
.
value_string
();
}
}
else
{
// check client has dataset
const
auto
&
party_datasets
=
task
.
party_datasets
();
auto
it
=
party_datasets
.
find
(
party_name
);
if
(
it
==
party_datasets
.
end
())
{
LOG
(
ERROR
)
<<
"no query data found for client, party_name: "
<<
party_name
;
return
retcode
::
FAIL
;
}
const
auto
&
datasets_map
=
it
->
second
.
data
();
auto
iter
=
datasets_map
.
find
(
party_name
);
if
(
iter
==
datasets_map
.
end
())
{
LOG
(
ERROR
)
<<
"no query data found for client, party_name: "
<<
party_name
;
return
retcode
::
FAIL
;
}
dataset_id_
=
iter
->
second
;
}
VLOG
(
7
)
<<
"dataset_id: "
<<
dataset_id_
;
auto
result_file_path_it
=
param_map
.
find
(
"outputFullFilename"
);
if
(
result_file_path_it
!=
param_map
.
end
())
{
result_file_path_
=
result_file_path_it
->
second
.
value_string
();
VLOG
(
5
)
<<
"result_file_path_: "
<<
result_file_path_
;
}
else
{
LOG
(
ERROR
)
<<
"no keyword outputFullFilename match"
;
return
retcode
::
FAIL
;
}
// get server dataset id
do
{
const
auto
&
party_datasets
=
task
.
party_datasets
();
auto
it
=
party_datasets
.
find
(
PARTY_SERVER
);
if
(
it
==
party_datasets
.
end
())
{
LOG
(
WARNING
)
<<
"no dataset found for party_name: "
<<
PARTY_SERVER
;
break
;
}
const
auto
&
datasets_map
=
it
->
second
.
data
();
auto
iter
=
datasets_map
.
find
(
PARTY_SERVER
);
if
(
iter
==
datasets_map
.
end
())
{
LOG
(
WARNING
)
<<
"no dataset found for party_name: "
<<
PARTY_SERVER
;
break
;
}
std
::
string
server_dataset_id
=
iter
->
second
;
auto
&
dataset_service
=
this
->
getDatasetService
();
auto
driver
=
dataset_service
->
getDriver
(
server_dataset_id
);
if
(
driver
==
nullptr
)
{
LOG
(
WARNING
)
<<
"no dataset access info found for id: "
<<
server_dataset_id
;
break
;
}
auto
&
access_info
=
driver
->
dataSetAccessInfo
();
if
(
access_info
==
nullptr
)
{
LOG
(
WARNING
)
<<
"no dataset access info found for id: "
<<
server_dataset_id
;
break
;
}
auto
&
schema
=
access_info
->
Schema
();
for
(
const
auto
&
field
:
schema
)
{
server_dataset_schema_
.
push_back
(
std
::
get
<
0
>
(
field
));
}
}
while
(
0
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"Failed to load params: "
<<
e
.
what
();
return
retcode
::
FAIL
;
}
const
auto
&
party_info
=
task
.
party_access_info
();
auto
it
=
party_info
.
find
(
PARTY_SERVER
);
if
(
it
==
party_info
.
end
())
{
LOG
(
ERROR
)
<<
"client can not found access info to server"
;
return
retcode
::
FAIL
;
}
auto
&
pb_node
=
it
->
second
;
pbNode2Node
(
pb_node
,
&
peer_node_
);
VLOG
(
5
)
<<
"peer_node: "
<<
peer_node_
.
to_string
();
return
retcode
::
SUCCESS
;
}
KeywordPIRClientTask
::
DatasetDBPair
KeywordPIRClientTask
::
_LoadDataFromDataset
()
{
apsi
::
util
::
CSVReader
::
DBData
db_data
;
std
::
vector
<
std
::
string
>
orig_items
;
auto
driver
=
this
->
getDatasetService
()
->
getDriver
(
this
->
dataset_id_
);
if
(
driver
==
nullptr
)
{
LOG
(
ERROR
)
<<
"get driver for dataset: "
<<
this
->
dataset_id_
<<
" failed"
;
return
std
::
make_pair
(
nullptr
,
std
::
vector
<
std
::
string
>
());
}
auto
access_info
=
dynamic_cast
<
CSVAccessInfo
*>
(
driver
->
dataSetAccessInfo
().
get
());
if
(
access_info
==
nullptr
)
{
LOG
(
ERROR
)
<<
"get data accessinfo for dataset: "
<<
this
->
dataset_id_
<<
" failed"
;
return
std
::
make_pair
(
nullptr
,
std
::
vector
<
std
::
string
>
());
}
dataset_path_
=
access_info
->
file_path_
;
try
{
apsi
::
util
::
CSVReader
reader
(
dataset_path_
);
std
::
tie
(
db_data
,
orig_items
)
=
reader
.
read
();
}
catch
(
const
std
::
exception
&
ex
)
{
LOG
(
ERROR
)
<<
"Could not open or read file `"
<<
dataset_path_
<<
"`: "
<<
ex
.
what
();
return
std
::
make_pair
(
nullptr
,
orig_items
);
}
return
{
std
::
make_unique
<
apsi
::
util
::
CSVReader
::
DBData
>
(
std
::
move
(
db_data
)),
std
::
move
(
orig_items
)};
}
KeywordPIRClientTask
::
DatasetDBPair
KeywordPIRClientTask
::
_LoadDataFromRecvData
()
{
if
(
recv_data_
.
empty
())
{
LOG
(
ERROR
)
<<
"query data is empty"
;
return
std
::
make_pair
(
nullptr
,
std
::
vector
<
std
::
string
>
());
}
// build db_data;
// std::unqiue_ptr<apsi::util::CSVReader::DBData>
apsi
::
util
::
CSVReader
::
DBData
db_data
=
apsi
::
util
::
CSVReader
::
UnlabeledData
();
for
(
const
auto
&
item_str
:
recv_data_
)
{
apsi
::
Item
db_item
=
item_str
;
std
::
get
<
apsi
::
util
::
CSVReader
::
UnlabeledData
>
(
db_data
).
push_back
(
std
::
move
(
db_item
));
}
return
{
std
::
make_unique
<
apsi
::
util
::
CSVReader
::
DBData
>
(
std
::
move
(
db_data
)),
recv_data_
};
// return std::make_pair(std::move(db_data), std::move(orig_items));
}
KeywordPIRClientTask
::
DatasetDBPair
KeywordPIRClientTask
::
_LoadDataset
(
void
)
{
if
(
!
recv_query_data_direct
)
{
return
_LoadDataFromDataset
();
}
else
{
return
_LoadDataFromRecvData
();
}
}
retcode
KeywordPIRClientTask
::
saveResult
(
const
std
::
vector
<
std
::
string
>&
orig_items
,
const
std
::
vector
<
Item
>&
items
,
const
std
::
vector
<
MatchRecord
>&
intersection
)
{
CHECK_TASK_STOPPED
(
retcode
::
FAIL
);
if
(
orig_items
.
size
()
!=
items
.
size
())
{
LOG
(
ERROR
)
<<
"Keyword PIR orig_items must have the same size as items, detail: "
<<
"orig_items size: "
<<
orig_items
.
size
()
<<
" items size: "
<<
items
.
size
();
return
retcode
::
FAIL
;
}
std
::
vector
<
std
::
vector
<
std
::
string
>>
result_data
;
result_data
.
resize
(
2
);
for
(
auto
&
item
:
result_data
)
{
item
.
reserve
(
orig_items
.
size
());
}
auto
&
key
=
result_data
[
0
];
auto
&
result_value
=
result_data
[
1
];
for
(
size_t
i
=
0
;
i
<
orig_items
.
size
();
i
++
)
{
if
(
!
intersection
[
i
].
found
)
{
VLOG
(
0
)
<<
"no match result found for query: ["
<<
orig_items
[
i
]
<<
"]"
;
continue
;
}
if
(
intersection
[
i
].
label
)
{
std
::
string
label_info
=
intersection
[
i
].
label
.
to_string
();
std
::
vector
<
std
::
string
>
labels
;
std
::
string
sep
=
DATA_RECORD_SEP
;
str_split
(
label_info
,
&
labels
,
sep
);
for
(
const
auto
&
lable_
:
labels
)
{
key
.
push_back
(
orig_items
[
i
]);
result_value
.
push_back
(
lable_
);
}
}
else
{
LOG
(
WARNING
)
<<
"no value found for query key: "
<<
orig_items
[
i
];
}
}
VLOG
(
0
)
<<
"save query result to : "
<<
result_file_path_
;
std
::
vector
<
std
::
shared_ptr
<
arrow
::
Field
>>
schema_vector
;
std
::
vector
<
std
::
string
>
tmp_colums
{
"key"
,
"value"
};
for
(
const
auto
&
col_name
:
tmp_colums
)
{
schema_vector
.
push_back
(
arrow
::
field
(
col_name
,
arrow
::
int64
()));
}
std
::
vector
<
std
::
shared_ptr
<
arrow
::
Array
>>
arrow_array
;
for
(
auto
&
item
:
result_data
)
{
arrow
::
StringBuilder
builder
;
builder
.
AppendValues
(
item
);
std
::
shared_ptr
<
arrow
::
Array
>
array
;
builder
.
Finish
(
&
array
);
arrow_array
.
push_back
(
std
::
move
(
array
));
}
auto
schema
=
std
::
make_shared
<
arrow
::
Schema
>
(
schema_vector
);
// std::shared_ptr<arrow::Table>
auto
table
=
arrow
::
Table
::
Make
(
schema
,
arrow_array
);
auto
driver
=
DataDirverFactory
::
getDriver
(
"CSV"
,
"test address"
);
auto
csv_driver
=
std
::
dynamic_pointer_cast
<
CSVDriver
>
(
driver
);
auto
rtcode
=
csv_driver
->
Write
(
server_dataset_schema_
,
table
,
result_file_path_
);
if
(
rtcode
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"save PIR data to file "
<<
result_file_path_
<<
" failed."
;
return
retcode
::
FAIL
;
}
return
retcode
::
SUCCESS
;
}
retcode
KeywordPIRClientTask
::
requestPSIParams
()
{
CHECK_TASK_STOPPED
(
retcode
::
FAIL
);
RequestType
type
=
RequestType
::
PsiParam
;
std
::
string
request
{
reinterpret_cast
<
char
*>
(
&
type
),
sizeof
(
type
)};
VLOG
(
5
)
<<
"send_data length: "
<<
request
.
length
();
std
::
string
response_str
;
auto
&
link_ctx
=
this
->
getTaskContext
().
getLinkContext
();
CHECK_NULLPOINTER_WITH_ERROR_MSG
(
link_ctx
,
"LinkContext is empty"
);
auto
channel
=
link_ctx
->
getChannel
(
peer_node_
);
auto
ret
=
channel
->
sendRecv
(
this
->
key
,
request
,
&
response_str
);
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"send requestPSIParams to peer: ["
<<
peer_node_
.
to_string
()
<<
"] failed"
;
return
ret
;
}
if
(
VLOG_IS_ON
(
5
))
{
std
::
string
tmp_str
;
for
(
const
auto
&
chr
:
response_str
)
{
tmp_str
.
append
(
std
::
to_string
(
static_cast
<
int
>
(
chr
))).
append
(
" "
);
}
VLOG
(
5
)
<<
"recv_data size: "
<<
response_str
.
size
()
<<
" "
<<
"data content: "
<<
tmp_str
;
}
// create psi params
// static std::pair<PSIParams, std::size_t> Load(std::istream &in);
std
::
istringstream
stream_in
(
response_str
);
auto
[
parse_data
,
ret_size
]
=
PSIParams
::
Load
(
stream_in
);
psi_params_
=
std
::
make_unique
<
PSIParams
>
(
parse_data
);
VLOG
(
5
)
<<
"parsed psi param, size: "
<<
ret_size
<<
" "
<<
"content: "
<<
psi_params_
->
to_string
();
return
retcode
::
SUCCESS
;
}
static
std
::
string
to_hexstring
(
const
Item
&
item
)
{
std
::
stringstream
ss
;
ss
<<
std
::
hex
;
auto
item_string
=
item
.
to_string
();
for
(
int
i
(
0
);
i
<
16
;
++
i
)
ss
<<
std
::
setw
(
2
)
<<
std
::
setfill
(
'0'
)
<<
(
int
)
item_string
[
i
];
return
ss
.
str
();
}
retcode
KeywordPIRClientTask
::
requestOprf
(
const
std
::
vector
<
Item
>&
items
,
std
::
vector
<
apsi
::
HashedItem
>*
res_items_ptr
,
std
::
vector
<
apsi
::
LabelKey
>*
res_label_keys_ptr
)
{
CHECK_TASK_STOPPED
(
retcode
::
FAIL
);
RequestType
type
=
RequestType
::
Oprf
;
std
::
string
oprf_response
;
auto
oprf_receiver
=
this
->
receiver_
->
CreateOPRFReceiver
(
items
);
auto
&
res_items
=
*
res_items_ptr
;
auto
&
res_label_keys
=
*
res_label_keys_ptr
;
res_items
.
resize
(
oprf_receiver
.
item_count
());
res_label_keys
.
resize
(
oprf_receiver
.
item_count
());
auto
oprf_request
=
oprf_receiver
.
query_data
();
VLOG
(
5
)
<<
"oprf_request data length: "
<<
oprf_request
.
size
();
std
::
string_view
oprf_request_sv
{
reinterpret_cast
<
char
*>
(
const_cast
<
unsigned
char
*>
(
oprf_request
.
data
())),
oprf_request
.
size
()};
auto
&
link_ctx
=
this
->
getTaskContext
().
getLinkContext
();
CHECK_NULLPOINTER_WITH_ERROR_MSG
(
link_ctx
,
"LinkContext is empty"
);
auto
channel
=
link_ctx
->
getChannel
(
peer_node_
);
// auto ret = channel->sendRecv(this->key, oprf_request_sv, &oprf_response);
auto
ret
=
this
->
send
(
this
->
key
,
peer_node_
,
oprf_request_sv
);
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"requestOprf to peer: ["
<<
peer_node_
.
to_string
()
<<
"] failed"
;
return
ret
;
}
ret
=
this
->
recv
(
this
->
key
,
&
oprf_response
);
if
(
ret
!=
retcode
::
SUCCESS
||
oprf_response
.
empty
())
{
LOG
(
ERROR
)
<<
"receive oprf_response from peer: ["
<<
peer_node_
.
to_string
()
<<
"] failed"
;
return
retcode
::
FAIL
;
}
VLOG
(
5
)
<<
"received oprf response length: "
<<
oprf_response
.
length
()
<<
" "
;
oprf_receiver
.
process_responses
(
oprf_response
,
res_items
,
res_label_keys
);
return
retcode
::
SUCCESS
;
}
retcode
KeywordPIRClientTask
::
requestQuery
()
{
RequestType
type
=
RequestType
::
Query
;
std
::
string
send_data
{
reinterpret_cast
<
char
*>
(
&
type
),
sizeof
(
type
)};
VLOG
(
5
)
<<
"send_data length: "
<<
send_data
.
length
();
return
retcode
::
SUCCESS
;
}
int
KeywordPIRClientTask
::
execute
()
{
auto
ret
=
_LoadParams
(
task_param_
);
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"Pir client load task params failed."
;
return
-
1
;
}
VLOG
(
5
)
<<
"begin to request psi params"
;
ret
=
requestPSIParams
();
CHECK_RETCODE_WITH_RETVALUE
(
ret
,
-
1
);
auto
[
query_data
,
orig_items
]
=
_LoadDataset
();
if
(
!
query_data
||
!
holds_alternative
<
CSVReader
::
UnlabeledData
>
(
*
query_data
))
{
LOG
(
ERROR
)
<<
"Failed to read keyword PIR query file: terminating"
;
return
-
1
;
}
auto
&
items
=
std
::
get
<
CSVReader
::
UnlabeledData
>
(
*
query_data
);
std
::
vector
<
Item
>
items_vec
(
items
.
begin
(),
items
.
end
());
std
::
vector
<
HashedItem
>
oprf_items
;
std
::
vector
<
LabelKey
>
label_keys
;
VLOG
(
5
)
<<
"begin to Receiver::RequestOPRF"
;
ret
=
requestOprf
(
items_vec
,
&
oprf_items
,
&
label_keys
);
CHECK_RETCODE_WITH_RETVALUE
(
ret
,
-
1
);
CHECK_TASK_STOPPED
(
-
1
);
if
(
VLOG_IS_ON
(
5
))
{
for
(
int
i
=
0
;
i
<
items_vec
.
size
();
i
++
)
{
VLOG
(
5
)
<<
"item["
<<
i
<<
"]'s PRF value: "
<<
to_hexstring
(
oprf_items
[
i
]);
}
}
VLOG
(
5
)
<<
"Receiver::RequestOPRF end, begin to receiver.request_query"
;
// request query
this
->
receiver_
=
std
::
make_unique
<
Receiver
>
(
*
psi_params_
);
std
::
vector
<
MatchRecord
>
query_result
;
try
{
auto
query
=
this
->
receiver_
->
create_query
(
oprf_items
);
// chl.send(move(query.first));
auto
request_query_data
=
std
::
move
(
query
.
first
);
std
::
ostringstream
string_ss
;
request_query_data
->
save
(
string_ss
);
std
::
string
query_data_str
=
string_ss
.
str
();
auto
itt
=
move
(
query
.
second
);
VLOG
(
5
)
<<
"query_data_str size: "
<<
query_data_str
.
size
();
ret
=
this
->
send
(
this
->
key
,
peer_node_
,
query_data_str
);
CHECK_RETCODE_WITH_RETVALUE
(
ret
,
-
1
);
// receive package count
uint32_t
package_count
=
0
;
ret
=
this
->
recv
(
"package_count"
,
reinterpret_cast
<
char
*>
(
&
package_count
),
sizeof
(
package_count
));
CHECK_RETCODE_WITH_RETVALUE
(
ret
,
-
1
);
VLOG
(
5
)
<<
"received package count: "
<<
package_count
;
std
::
vector
<
apsi
::
ResultPart
>
result_packages
;
for
(
size_t
i
=
0
;
i
<
package_count
;
i
++
)
{
std
::
string
recv_data
;
ret
=
this
->
recv
(
this
->
key
,
&
recv_data
);
CHECK_RETCODE_WITH_RETVALUE
(
ret
,
-
1
);
VLOG
(
5
)
<<
"client received data length: "
<<
recv_data
.
size
();
std
::
istringstream
stream_in
(
recv_data
);
apsi
::
ResultPart
result_part
=
std
::
make_unique
<
apsi
::
network
::
ResultPackage
>
();
auto
seal_context
=
this
->
receiver_
->
get_seal_context
();
result_part
->
load
(
stream_in
,
seal_context
);
result_packages
.
push_back
(
std
::
move
(
result_part
));
}
query_result
=
this
->
receiver_
->
process_result
(
label_keys
,
itt
,
result_packages
);
VLOG
(
5
)
<<
"query_resultquery_resultquery_resultquery_result: "
<<
query_result
.
size
();
}
catch
(
const
std
::
exception
&
ex
)
{
LOG
(
ERROR
)
<<
"Failed sending keyword PIR query: "
<<
ex
.
what
();
return
-
1
;
}
VLOG
(
5
)
<<
"receiver.request_query end"
;
ret
=
this
->
saveResult
(
orig_items
,
items
,
query_result
);
CHECK_RETCODE_WITH_RETVALUE
(
ret
,
-
1
);
return
0
;
}
}
// namespace primihub::task
src/primihub/task/semantic/keyword_pir_client_task.h
已删除
100644 → 0
浏览文件 @
405e10f6
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_CLIENT_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_CLIENT_TASK_H_
#include <vector>
#include "apsi/item.h"
#include "apsi/match_record.h"
#include "apsi/util/csv_reader.h"
#include "src/primihub/task/semantic/task.h"
#include "src/primihub/common/common.h"
#include "apsi/receiver.h"
using
apsi
::
Item
;
using
apsi
::
receiver
::
MatchRecord
;
using
apsi
::
util
::
CSVReader
;
using
namespace
apsi
::
receiver
;
namespace
primihub
::
task
{
class
KeywordPIRClientTask
:
public
TaskBase
{
public:
using
DBDataPtr
=
std
::
unique_ptr
<
apsi
::
util
::
CSVReader
::
DBData
>
;
using
DatasetDBPair
=
std
::
pair
<
DBDataPtr
,
std
::
vector
<
std
::
string
>>
;
enum
class
RequestType
:
uint8_t
{
PsiParam
=
0
,
Oprf
,
Query
,
};
explicit
KeywordPIRClientTask
(
const
TaskParam
*
task_param
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
);
~
KeywordPIRClientTask
()
=
default
;
int
execute
()
override
;
retcode
saveResult
(
const
std
::
vector
<
std
::
string
>&
orig_items
,
const
std
::
vector
<
apsi
::
Item
>&
items
,
const
std
::
vector
<
apsi
::
receiver
::
MatchRecord
>&
intersection
);
protected:
/**
* Performs a parameter request from sender
*/
retcode
requestPSIParams
();
/**
* Performs an OPRF request on a vector of items and returns a
vector of OPRF hashed items of the same size as the input vector.
*/
retcode
requestOprf
(
const
std
::
vector
<
Item
>&
items
,
std
::
vector
<
apsi
::
HashedItem
>*
,
std
::
vector
<
apsi
::
LabelKey
>*
);
/**
* Performs a labeled PSI query. The query is a vector of
items, and the result is a same-size vector of MatchRecord objects. If an item is in the
intersection, the corresponding MatchRecord indicates it in the `found` field, and the
`label` field may contain the corresponding label if a sender's data included it.
*/
retcode
requestQuery
();
private:
retcode
_LoadParams
(
Task
&
task
);
DatasetDBPair
_LoadDataset
();
// load dataset according by url
DatasetDBPair
_LoadDataFromDataset
();
// load data from request directly
DatasetDBPair
_LoadDataFromRecvData
();
private:
std
::
string
dataset_path_
;
std
::
string
dataset_id_
;
std
::
string
result_file_path_
;
std
::
string
server_address_
;
bool
recv_query_data_direct
{
false
};
uint32_t
server_data_port
{
2222
};
primihub
::
Node
peer_node_
;
std
::
string
key
{
"default"
};
std
::
unique_ptr
<
apsi
::
PSIParams
>
psi_params_
{
nullptr
};
std
::
unique_ptr
<
apsi
::
receiver
::
Receiver
>
receiver_
{
nullptr
};
std
::
vector
<
std
::
string
>
recv_data_
;
std
::
vector
<
std
::
string
>
server_dataset_schema_
;
};
}
// namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_CLIENT_TASK_H_
src/primihub/task/semantic/keyword_pir_server_task.cc
已删除
100644 → 0
浏览文件 @
405e10f6
此差异已折叠。
点击以展开。
src/primihub/task/semantic/keyword_pir_server_task.h
已删除
100644 → 0
浏览文件 @
405e10f6
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_SERVER_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_SERVER_TASK_H_
#include <utility>
#include <vector>
#include <string>
#include <memory>
// APSI
#include "apsi/util/common_utils.h"
#include "apsi/sender.h"
#include "apsi/oprf/oprf_sender.h"
#include "apsi/bin_bundle.h"
#include "apsi/item.h"
// SEAL
#include "seal/context.h"
#include "seal/modulus.h"
#include "seal/util/common.h"
#include "seal/util/defines.h"
#include "src/primihub/protos/common.pb.h"
#include "src/primihub/task/semantic/task.h"
namespace
primihub
::
task
{
using
UnlabeledData
=
std
::
vector
<
apsi
::
Item
>
;
using
LabeledData
=
std
::
vector
<
std
::
pair
<
apsi
::
Item
,
apsi
::
Label
>>
;
using
DBData
=
std
::
variant
<
UnlabeledData
,
LabeledData
>
;
class
KeywordPIRServerTask
:
public
TaskBase
{
public:
enum
class
RequestType
:
uint8_t
{
PsiParam
=
0
,
Opfr
,
Query
,
};
explicit
KeywordPIRServerTask
(
const
TaskParam
*
task_param
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
);
~
KeywordPIRServerTask
()
=
default
;
int
execute
()
override
;
protected:
/**
* process a Get Parameters request to the Sender.
*/
retcode
processPSIParams
();
/**
process an OPRF query request to the Sender.
*/
retcode
processOprf
();
/**
process a Query request to the Sender.
*/
retcode
processQuery
(
std
::
shared_ptr
<
apsi
::
sender
::
SenderDB
>
sender_db
);
retcode
ComputePowers
(
const
shared_ptr
<
apsi
::
sender
::
SenderDB
>
&
sender_db
,
const
apsi
::
CryptoContext
&
crypto_context
,
std
::
vector
<
apsi
::
sender
::
CiphertextPowers
>
&
all_powers
,
const
apsi
::
PowersDag
&
pd
,
uint32_t
bundle_idx
,
seal
::
MemoryPoolHandle
&
pool
);
auto
ProcessBinBundleCache
(
const
shared_ptr
<
apsi
::
sender
::
SenderDB
>
&
sender_db
,
const
apsi
::
CryptoContext
&
crypto_context
,
reference_wrapper
<
const
apsi
::
sender
::
BinBundleCache
>
cache
,
std
::
vector
<
apsi
::
sender
::
CiphertextPowers
>
&
all_powers
,
uint32_t
bundle_idx
,
compr_mode_type
compr_mode
,
seal
::
MemoryPoolHandle
&
pool
)
->
std
::
unique_ptr
<
apsi
::
network
::
ResultPackage
>
;
private:
retcode
_LoadParams
(
Task
&
task
);
std
::
unique_ptr
<
DBData
>
_LoadDataset
(
void
);
std
::
unique_ptr
<
apsi
::
PSIParams
>
_SetPsiParams
();
std
::
shared_ptr
<
apsi
::
sender
::
SenderDB
>
create_sender_db
(
const
DBData
&
db_data
,
std
::
unique_ptr
<
apsi
::
PSIParams
>
psi_params
,
apsi
::
oprf
::
OPRFKey
&
oprf_key
,
size_t
nonce_byte_count
,
bool
compress
);
std
::
shared_ptr
<
apsi
::
sender
::
SenderDB
>
LoadDbFromCache
(
const
std
::
string
&
db_file_cache_
);
std
::
unique_ptr
<
DBData
>
CreateDbData
(
std
::
shared_ptr
<
Dataset
>&
data
);
std
::
vector
<
std
::
string
>
GetSelectedContent
(
std
::
shared_ptr
<
arrow
::
Table
>&
data_tbl
,
const
std
::
vector
<
int
>&
selected_col
);
retcode
CreateDbDataCache
(
const
DBData
&
db_data
,
std
::
unique_ptr
<
apsi
::
PSIParams
>
psi_params
,
apsi
::
oprf
::
OPRFKey
&
oprf_key
,
size_t
nonce_byte_count
,
bool
compress
);
bool
DbCacheAvailable
(
const
std
::
string
&
db_file_cache
);
private:
std
::
string
dataset_path_
;
std
::
string
dataset_id_
;
std
::
string
db_cache_dir_
{
"data/cache"
};
std
::
string
db_file_cache_
;
uint32_t
data_port
{
2222
};
std
::
string
client_address
;
primihub
::
Node
client_node_
;
std
::
string
key
{
"default"
};
std
::
string
psi_params_str_
;
std
::
unique_ptr
<
apsi
::
oprf
::
OPRFKey
>
oprf_key_
{
nullptr
};
bool
generate_db_offline_
{
false
};
};
}
// namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_KEYWORD_PIR_SERVER_TASK_H_
src/primihub/task/semantic/pir_client_task.cc
已删除
100644 → 0
浏览文件 @
405e10f6
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "src/primihub/task/semantic/pir_client_task.h"
#include <string>
#include "src/primihub/data_store/factory.h"
#include "src/primihub/util/util.h"
using
arrow
::
Table
;
using
arrow
::
StringArray
;
using
arrow
::
Int64Builder
;
using
primihub
::
rpc
::
VarType
;
namespace
primihub
::
task
{
int
validateDirection
(
std
::
string
file_path
)
{
int
pos
=
file_path
.
find_last_of
(
'/'
);
std
::
string
path
;
if
(
pos
>
0
)
{
path
=
file_path
.
substr
(
0
,
pos
);
if
(
access
(
path
.
c_str
(),
0
)
==
-
1
)
{
std
::
string
cmd
=
"mkdir -p "
+
path
;
int
ret
=
system
(
cmd
.
c_str
());
if
(
ret
)
return
-
1
;
}
}
return
0
;
}
PIRClientTask
::
PIRClientTask
(
const
std
::
string
&
node_id
,
const
std
::
string
&
job_id
,
const
std
::
string
&
task_id
,
const
TaskParam
*
task_param
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
)
:
TaskBase
(
task_param
,
dataset_service
),
node_id_
(
node_id
),
job_id_
(
job_id
),
task_id_
(
task_id
)
{}
int
PIRClientTask
::
_LoadParams
(
Task
&
task
)
{
auto
param_map
=
task
.
params
().
param_map
();
try
{
result_file_path_
=
param_map
[
"outputFullFilename"
].
value_string
();
server_address_
=
param_map
[
"serverAddress"
].
value_string
();
server_dataset_
=
param_map
[
server_address_
].
value_string
();
db_size_
=
stoi
(
param_map
[
"databaseSize"
].
value_string
());
// temperarily read db size direly from frontend
std
::
vector
<
std
::
string
>
tmp_indices
;
str_split
(
param_map
[
"queryIndeies"
].
value_string
(),
&
tmp_indices
,
','
);
for
(
std
::
string
&
index
:
tmp_indices
)
{
int
idx
=
stoi
(
index
);
indices_
.
push_back
(
idx
);
}
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"Failed to load params: "
<<
e
.
what
();
return
-
1
;
}
return
0
;
}
int
PIRClientTask
::
_SetUpDB
(
size_t
__dbsize
,
size_t
dimensions
,
size_t
elem_size
,
uint32_t
plain_mod_bit_size
,
uint32_t
bits_per_coeff
,
bool
use_ciphertext_multiplication
=
false
)
{
// db_size_ = dbsize;
encryption_params_
=
pir
::
GenerateEncryptionParams
(
POLY_MODULUS_DEGREE
,
plain_mod_bit_size
);
pir_params_
=
*
(
pir
::
CreatePIRParameters
(
db_size_
,
elem_size
,
dimensions
,
encryption_params_
,
use_ciphertext_multiplication
,
bits_per_coeff
));
client_
=
*
(
PIRClient
::
Create
(
pir_params_
));
if
(
client_
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Failed to create pir client."
;
return
-
1
;
}
return
0
;
}
int
PIRClientTask
::
_ProcessResponse
(
const
ExecuteTaskResponse
&
taskResponse
)
{
pir
::
Response
response
;
size_t
num_reply
=
static_cast
<
size_t
>
(
taskResponse
.
pir_response
().
reply
().
size
());
for
(
size_t
i
=
0
;
i
<
num_reply
;
i
++
)
{
pir
::
Ciphertexts
*
ptr_reply
=
response
.
add_reply
();
size_t
num_ct
=
static_cast
<
std
::
int64_t
>
(
taskResponse
.
pir_response
().
reply
()[
i
].
ct
().
size
());
for
(
size_t
j
=
0
;
j
<
num_ct
;
j
++
)
{
ptr_reply
->
add_ct
(
taskResponse
.
pir_response
().
reply
()[
i
].
ct
()[
j
]);
}
}
auto
result
=
client_
->
ProcessResponse
(
indices_
,
response
);
if
(
result
.
ok
())
{
for
(
size_t
i
=
0
;
i
<
std
::
move
(
result
).
value
().
size
();
i
++
)
{
result_
.
push_back
(
std
::
move
(
result
).
value
()[
i
]);
}
}
else
{
LOG
(
ERROR
)
<<
"Failed to process pir server response: "
<<
result
.
status
();
return
-
1
;
}
return
0
;
}
int
PIRClientTask
::
saveResult
()
{
arrow
::
MemoryPool
*
pool
=
arrow
::
default_memory_pool
();
arrow
::
StringBuilder
builder
(
pool
);
for
(
std
::
int64_t
i
=
0
;
i
<
result_
.
size
();
i
++
)
{
builder
.
Append
(
result_
[
i
]);
}
std
::
shared_ptr
<
arrow
::
Array
>
array
;
builder
.
Finish
(
&
array
);
std
::
vector
<
std
::
shared_ptr
<
arrow
::
Field
>>
schema_vector
=
{
arrow
::
field
(
"reslut"
,
arrow
::
utf8
())};
auto
schema
=
std
::
make_shared
<
arrow
::
Schema
>
(
schema_vector
);
std
::
shared_ptr
<
arrow
::
Table
>
table
=
arrow
::
Table
::
Make
(
schema
,
{
array
});
std
::
shared_ptr
<
DataDriver
>
driver
=
DataDirverFactory
::
getDriver
(
"CSV"
,
"pir result"
);
std
::
shared_ptr
<
CSVDriver
>
csv_driver
=
std
::
dynamic_pointer_cast
<
CSVDriver
>
(
driver
);
if
(
validateDirection
(
result_file_path_
))
{
LOG
(
ERROR
)
<<
"can't access file path: "
<<
result_file_path_
;
return
-
1
;
}
int
ret
=
csv_driver
->
write
(
table
,
result_file_path_
);
if
(
ret
!=
0
)
{
LOG
(
ERROR
)
<<
"Save PIR result to file "
<<
result_file_path_
<<
" failed."
;
return
-
1
;
}
LOG
(
INFO
)
<<
"Save PIR result to "
<<
result_file_path_
<<
"."
;
return
0
;
}
uint32_t
compute_plain_mod_bit_size
(
size_t
dbsize
,
size_t
elem_size
)
{
uint32_t
plain_mod_bit_size
=
PLAIN_MOD_BIT_SIZE_UPBOUND
;
while
(
true
)
{
plain_mod_bit_size
--
;
uint64_t
elem_per_plaintext
=
POLY_MODULUS_DEGREE
\
*
(
plain_mod_bit_size
-
1
)
/
8
/
elem_size
;
uint64_t
num_plaintext
=
dbsize
/
elem_per_plaintext
+
1
;
if
(
num_plaintext
<=
(
uint64_t
)
1
<<
(
NOISE_BUDGET_BASE
-
2
*
plain_mod_bit_size
))
{
break
;
}
}
return
plain_mod_bit_size
;
}
int
PIRClientTask
::
execute
()
{
int
ret
=
_LoadParams
(
task_param_
);
if
(
ret
)
{
LOG
(
ERROR
)
<<
"Pir client load task params failed."
;
return
ret
;
}
size_t
dimensions
=
1
;
size_t
elem_size
=
ELEM_SIZE
;
uint32_t
plain_mod_bit_size
=
compute_plain_mod_bit_size
(
db_size_
,
elem_size
);
bool
use_ciphertext_multiplication
=
true
;
uint32_t
poly_modulus_degree
=
POLY_MODULUS_DEGREE
;
uint32_t
bits_per_coeff
=
0
;
ret
=
_SetUpDB
(
0
,
dimensions
,
elem_size
,
// temperarily read db size direly from frontend
plain_mod_bit_size
,
bits_per_coeff
,
use_ciphertext_multiplication
);
if
(
ret
)
{
LOG
(
ERROR
)
<<
"Failed to initialize pir client."
;
return
-
1
;
}
//pir::Request request_proto = std::move(client_->CreateRequest(indices_)).value();
pir
::
Request
request_proto
;
auto
request_or
=
client_
->
CreateRequest
(
indices_
);
if
(
request_or
.
ok
())
{
request_proto
=
std
::
move
(
request_or
).
value
();
}
else
{
LOG
(
ERROR
)
<<
"Pir create request failed: "
<<
request_or
.
status
();
return
-
1
;
}
grpc
::
ClientContext
client_context
;
grpc
::
ChannelArguments
channel_args
;
channel_args
.
SetMaxReceiveMessageSize
(
128
*
1024
*
1024
);
std
::
shared_ptr
<
grpc
::
Channel
>
channel
=
grpc
::
CreateCustomChannel
(
server_address_
,
grpc
::
InsecureChannelCredentials
(),
channel_args
);
std
::
unique_ptr
<
VMNode
::
Stub
>
stub
=
VMNode
::
NewStub
(
channel
);
using
stream_t
=
std
::
shared_ptr
<
grpc
::
ClientReaderWriter
<
ExecuteTaskRequest
,
ExecuteTaskResponse
>>
;
stream_t
client_stream
(
stub
->
ExecuteTask
(
&
client_context
));
size_t
limited_size
=
1
<<
21
;
size_t
query_num
=
request_proto
.
query
().
size
();
const
auto
&
querys
=
request_proto
.
query
();
size_t
sended_index
{
0
};
std
::
vector
<
ExecuteTaskRequest
>
send_requests
;
do
{
ExecuteTaskRequest
taskRequest
;
PirRequest
*
ptr_request
=
taskRequest
.
mutable_pir_request
();
ptr_request
->
set_galois_keys
(
request_proto
.
galois_keys
());
ptr_request
->
set_relin_keys
(
request_proto
.
relin_keys
());
size_t
pack_size
=
0
;
for
(
size_t
i
=
sended_index
;
i
<
query_num
;
i
++
)
{
// calculate length of query
size_t
query_size
=
0
;
const
auto
&
query
=
querys
[
i
];
for
(
const
auto
&
ct
:
query
.
ct
())
{
query_size
+=
ct
.
size
();
}
if
(
pack_size
+
query_size
>
limited_size
)
{
break
;
}
auto
query_ptr
=
ptr_request
->
add_query
();
for
(
const
auto
&
ct
:
query
.
ct
())
{
query_ptr
->
add_ct
(
ct
);
}
sended_index
++
;
}
auto
*
ptr_params
=
taskRequest
.
mutable_params
()
->
mutable_param_map
();
ParamValue
pv
;
pv
.
set_var_type
(
VarType
::
STRING
);
pv
.
set_value_string
(
server_dataset_
);
(
*
ptr_params
)[
"serverData"
]
=
pv
;
send_requests
.
push_back
(
std
::
move
(
taskRequest
));
if
(
sended_index
>=
query_num
)
{
break
;
}
}
while
(
true
);
// send request to server
for
(
const
auto
&
request
:
send_requests
)
{
client_stream
->
Write
(
request
);
}
client_stream
->
WritesDone
();
ExecuteTaskResponse
taskResponse
;
ExecuteTaskResponse
recv_response
;
auto
pir_response
=
taskResponse
.
mutable_pir_response
();
bool
is_initialized
{
false
};
while
(
client_stream
->
Read
(
&
recv_response
))
{
const
auto
&
recv_pir_response
=
recv_response
.
pir_response
();
if
(
!
is_initialized
)
{
pir_response
->
set_ret_code
(
recv_pir_response
.
ret_code
());
is_initialized
=
true
;
}
for
(
const
auto
&
reply
:
recv_pir_response
.
reply
())
{
auto
reply_ptr
=
pir_response
->
add_reply
();
for
(
const
auto
&
ct
:
reply
.
ct
())
{
reply_ptr
->
add_ct
(
ct
);
}
}
}
Status
status
=
client_stream
->
Finish
();
if
(
status
.
ok
())
{
if
(
taskResponse
.
psi_response
().
ret_code
())
{
LOG
(
ERROR
)
<<
"Node pir server process request error."
;
return
-
1
;
}
int
ret
=
_ProcessResponse
(
taskResponse
);
if
(
ret
)
{
LOG
(
ERROR
)
<<
"Node pir client process response failed."
;
return
-
1
;
}
ret
=
saveResult
();
if
(
ret
)
{
LOG
(
ERROR
)
<<
"Pir save result failed."
;
return
-
1
;
}
}
else
{
LOG
(
ERROR
)
<<
"Pir server return error: "
<<
status
.
error_code
()
<<
" "
<<
status
.
error_message
().
c_str
();
return
-
1
;
}
return
0
;
}
}
src/primihub/task/semantic/pir_client_task.h
已删除
100644 → 0
浏览文件 @
405e10f6
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_PIR_CLIENT_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_PIR_CLIENT_TASK_H_
#include <grpc/grpc.h>
#include <grpcpp/channel.h>
#include <grpcpp/create_channel.h>
#include "pir/cpp/client.h"
#include "pir/cpp/database.h"
#include "pir/cpp/utils.h"
#include "pir/cpp/string_encoder.h"
#include <map>
#include <memory>
#include <string>
#include <set>
#include "src/primihub/protos/common.grpc.pb.h"
#include "src/primihub/protos/psi.grpc.pb.h"
#include "src/primihub/protos/worker.grpc.pb.h"
#include "src/primihub/task/semantic/task.h"
using
pir
::
PIRParameters
;
using
pir
::
EncryptionParameters
;
using
pir
::
PIRClient
;
// using grpc::ClientContext;
using
grpc
::
Status
;
using
grpc
::
Channel
;
using
primihub
::
rpc
::
Ciphertexts
;
using
primihub
::
rpc
::
Task
;
using
primihub
::
rpc
::
ParamValue
;
using
primihub
::
rpc
::
PsiType
;
using
primihub
::
rpc
::
ExecuteTaskRequest
;
using
primihub
::
rpc
::
ExecuteTaskResponse
;
using
primihub
::
rpc
::
PirRequest
;
using
primihub
::
rpc
::
PirResponse
;
using
primihub
::
rpc
::
VMNode
;
namespace
primihub
::
task
{
constexpr
uint32_t
POLY_MODULUS_DEGREE
=
4096
;
constexpr
uint32_t
ELEM_SIZE
=
1024
;
constexpr
uint32_t
PLAIN_MOD_BIT_SIZE_UPBOUND
=
29
;
constexpr
uint32_t
NOISE_BUDGET_BASE
=
57
;
class
PIRClientTask
:
public
TaskBase
{
public:
explicit
PIRClientTask
(
const
std
::
string
&
node_id
,
const
std
::
string
&
job_id
,
const
std
::
string
&
task_id
,
const
TaskParam
*
task_param
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
);
~
PIRClientTask
()
{};
int
execute
()
override
;
int
saveResult
(
void
);
private:
int
_LoadParams
(
Task
&
task
);
int
_SetUpDB
(
size_t
dbsize
,
size_t
dimensions
,
size_t
elem_size
,
uint32_t
plain_mod_bit_size
,
uint32_t
bits_per_coeff
,
bool
use_ciphertext_multiplication
);
int
_ProcessResponse
(
const
ExecuteTaskResponse
&
taskResponse
);
const
std
::
string
node_id_
;
const
std
::
string
job_id_
;
const
std
::
string
task_id_
;
std
::
string
server_address_
;
std
::
string
result_file_path_
;
std
::
vector
<
size_t
>
indices_
;
std
::
vector
<
std
::
string
>
result_
;
std
::
string
server_dataset_
;
size_t
db_size_
;
std
::
shared_ptr
<
PIRParameters
>
pir_params_
;
EncryptionParameters
encryption_params_
;
std
::
unique_ptr
<
PIRClient
>
client_
;
};
}
// namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_PIR_CLIENT_TASK_H_
src/primihub/task/semantic/pir_server_task.cc
已删除
100644 → 0
浏览文件 @
405e10f6
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "src/primihub/task/semantic/pir_server_task.h"
namespace
primihub
::
task
{
void
initRequest
(
const
PirRequest
*
request
,
pir
::
Request
&
pir_request
)
{
pir_request
.
set_galois_keys
(
request
->
galois_keys
());
pir_request
.
set_relin_keys
(
request
->
relin_keys
());
const
size_t
num_query
=
static_cast
<
size_t
>
(
request
->
query
().
size
());
for
(
size_t
i
=
0
;
i
<
num_query
;
i
++
)
{
pir
::
Ciphertexts
*
ptr_query
=
pir_request
.
add_query
();
const
size_t
num_ct
=
static_cast
<
size_t
>
(
request
->
query
()[
i
].
ct
().
size
());
for
(
size_t
j
=
0
;
j
<
num_ct
;
j
++
)
{
ptr_query
->
add_ct
(
request
->
query
()[
i
].
ct
()[
j
]);
}
}
}
PIRServerTask
::
PIRServerTask
(
const
std
::
string
&
node_id
,
const
ExecuteTaskRequest
&
request
,
ExecuteTaskResponse
*
response
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
)
:
ServerTaskBase
(
&
(
request
.
params
()),
dataset_service
)
{
request_
=
&
(
request
.
pir_request
());
response_
=
response
->
mutable_pir_response
();
}
int
PIRServerTask
::
loadParams
(
Params
&
params
)
{
auto
param_map
=
params
.
param_map
();
try
{
dataset_path_
=
param_map
[
"serverData"
].
value_string
();
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"Failed to load pir server params: "
<<
e
.
what
();
return
-
1
;
}
return
0
;
}
int
PIRServerTask
::
loadDataset
()
{
// int ret = loadDatasetFromCSV(dataset_path_, 0, elements_, db_size_);
int
ret
=
loadDatasetFromTXT
(
dataset_path_
,
elements_
);
// file reading error or file empty
if
(
ret
<=
0
)
{
LOG
(
ERROR
)
<<
"Load dataset for psi client failed."
;
return
-
1
;
}
LOG
(
INFO
)
<<
"db size = "
<<
ret
;
// output the dataset length
return
ret
;
}
int
PIRServerTask
::
_SetUpDB
(
size_t
dbsize
,
size_t
dimensions
,
size_t
elem_size
,
uint32_t
poly_modulus_degree
,
uint32_t
plain_mod_bit_size
,
uint32_t
bits_per_coeff
,
bool
use_ciphertext_multiplication
)
{
encryption_params_
=
pir
::
GenerateEncryptionParams
(
poly_modulus_degree
,
plain_mod_bit_size
);
pir_params_
=
*
(
pir
::
CreatePIRParameters
(
dbsize
,
elem_size
,
dimensions
,
encryption_params_
,
use_ciphertext_multiplication
,
bits_per_coeff
));
db_size_
=
dbsize
;
if
(
elements_
.
size
()
>
dbsize
)
{
LOG
(
ERROR
)
<<
"Dataset size is not equal dbsize:"
<<
elements_
.
size
();
for
(
int
i
=
0
;
i
<
elements_
.
size
();
i
++
)
{
LOG
(
INFO
)
<<
"elem: "
<<
elements_
[
i
];
}
return
-
1
;
}
else
if
(
elements_
.
size
()
<
dbsize
)
{
uint32_t
seed
=
42
;
auto
prng
=
seal
::
UniformRandomGeneratorFactory
::
DefaultFactory
()
->
create
({
seed
});
for
(
int64_t
i
=
elements_
.
size
();
i
<
dbsize
;
i
++
)
{
int
rand_num
=
rand
()
%
80
;
std
::
string
rand_str
(
rand_num
,
0
);
prng
->
generate
(
rand_str
.
size
(),
reinterpret_cast
<
seal
::
SEAL_BYTE
*>
(
rand_str
.
data
()));
elements_
.
push_back
(
std
::
to_string
(
i
)
+
std
::
to_string
(
i
)
+
std
::
to_string
(
i
)
+
rand_str
);
}
}
std
::
vector
<
std
::
string
>
string_db
;
string_db
.
resize
(
dbsize
,
std
::
string
(
elem_size
,
0
));
for
(
size_t
i
=
0
;
i
<
dbsize
;
++
i
)
{
for
(
int
j
=
0
;
j
<
elements_
[
i
].
length
();
j
++
)
{
string_db
[
i
][
j
]
=
elements_
[
i
][
j
];
}
}
auto
db_status
=
pir
::
PIRDatabase
::
Create
(
string_db
,
pir_params_
);
if
(
!
db_status
.
ok
())
{
LOG
(
ERROR
)
<<
db_status
.
status
();
return
-
1
;
}
else
{
pir_db_
=
std
::
move
(
db_status
).
value
();
}
return
0
;
}
uint32_t
compute_plain_mod_bit_size_server
(
size_t
dbsize
,
size_t
elem_size
)
{
uint32_t
plain_mod_bit_size
=
PLAIN_MOD_BIT_SIZE_UPBOUND_SVR
;
while
(
true
)
{
plain_mod_bit_size
--
;
uint64_t
elem_per_plaintext
=
POLY_MODULUS_DEGREE_SVR
\
*
(
plain_mod_bit_size
-
1
)
/
8
/
elem_size
;
uint64_t
num_plaintext
=
dbsize
/
elem_per_plaintext
+
1
;
if
(
num_plaintext
<=
(
uint64_t
)
1
<<
(
NOISE_BUDGET_BASE_SVR
-
2
*
plain_mod_bit_size
))
{
break
;
}
}
return
plain_mod_bit_size
;
}
int
PIRServerTask
::
execute
()
{
LOG
(
INFO
)
<<
"load parameters"
;
int
ret
=
loadParams
(
params_
);
if
(
ret
)
{
LOG
(
ERROR
)
<<
"Load parameters for pir server fialed."
;
return
-
1
;
}
LOG
(
INFO
)
<<
"parameters loaded"
;
LOG
(
INFO
)
<<
"load dataset"
;
int
db_size
=
loadDataset
();
if
(
db_size
<=
0
)
{
LOG
(
ERROR
)
<<
"Load dataset for pir server failed."
;
return
-
1
;
}
LOG
(
INFO
)
<<
"dataset loaded"
;
size_t
dimensions
=
1
;
size_t
elem_size
=
ELEM_SIZE_SVR
;
uint32_t
plain_mod_bit_size
=
compute_plain_mod_bit_size_server
(
db_size
,
elem_size
);
bool
use_ciphertext_multiplication
=
true
;
uint32_t
poly_modulus_degree
=
POLY_MODULUS_DEGREE_SVR
;
uint32_t
bits_per_coeff
=
0
;
LOG
(
INFO
)
<<
"create database"
;
ret
=
_SetUpDB
(
db_size
,
dimensions
,
elem_size
,
poly_modulus_degree
,
plain_mod_bit_size
,
bits_per_coeff
,
use_ciphertext_multiplication
);
if
(
ret
)
{
LOG
(
ERROR
)
<<
"Create pir db failed."
;
return
-
1
;
}
LOG
(
INFO
)
<<
"database created"
;
pir
::
Request
pir_request
;
initRequest
(
request_
,
pir_request
);
LOG
(
INFO
)
<<
"create server"
;
std
::
unique_ptr
<
pir
::
PIRServer
>
server
=
*
(
pir
::
PIRServer
::
Create
(
pir_db_
,
pir_params_
));
if
(
server
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Failed to create pir server"
;
return
-
1
;
}
LOG
(
INFO
)
<<
"server created"
;
LOG
(
INFO
)
<<
"process request"
;
auto
result_status
=
server
->
ProcessRequest
(
pir_request
);
if
(
!
result_status
.
ok
())
{
LOG
(
ERROR
)
<<
"Process pir request failed:"
<<
result_status
.
status
();
return
-
1
;
}
LOG
(
INFO
)
<<
"request processed"
;
auto
result_raw
=
std
::
move
(
result_status
).
value
();
const
size_t
num_reply
=
static_cast
<
size_t
>
(
result_raw
.
reply
().
size
());
for
(
size_t
i
=
0
;
i
<
num_reply
;
i
++
)
{
Ciphertexts
*
ptr_reply
=
response_
->
add_reply
();
const
size_t
num_ct
=
static_cast
<
size_t
>
(
result_raw
.
reply
()[
i
].
ct
().
size
());
for
(
size_t
j
=
0
;
j
<
num_ct
;
j
++
)
{
ptr_reply
->
add_ct
(
result_raw
.
reply
()[
i
].
ct
()[
j
]);
}
}
return
0
;
}
}
// namespace primihub::task
src/primihub/task/semantic/pir_server_task.h
已删除
100644 → 0
浏览文件 @
405e10f6
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_PIR_SERVER_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_PIR_SERVER_TASK_H_
#include <map>
#include <memory>
#include <string>
#include <stdlib.h>
#include "pir/cpp/server.h"
#include "pir/cpp/database.h"
#include "pir/cpp/utils.h"
#include "pir/cpp/string_encoder.h"
#include "src/primihub/protos/common.grpc.pb.h"
#include "src/primihub/protos/psi.grpc.pb.h"
#include "src/primihub/protos/worker.grpc.pb.h"
#include "src/primihub/task/semantic/private_server_base.h"
using
std
::
shared_ptr
;
using
primihub
::
rpc
::
Params
;
using
primihub
::
rpc
::
Ciphertexts
;
using
primihub
::
rpc
::
PirRequest
;
using
primihub
::
rpc
::
PirResponse
;
using
primihub
::
rpc
::
ExecuteTaskRequest
;
using
primihub
::
rpc
::
ExecuteTaskResponse
;
namespace
primihub
::
task
{
constexpr
uint32_t
POLY_MODULUS_DEGREE_SVR
=
4096
;
constexpr
uint32_t
ELEM_SIZE_SVR
=
1024
;
constexpr
uint32_t
PLAIN_MOD_BIT_SIZE_UPBOUND_SVR
=
29
;
constexpr
uint32_t
NOISE_BUDGET_BASE_SVR
=
57
;
class
PIRServerTask
:
public
ServerTaskBase
{
public:
explicit
PIRServerTask
(
const
std
::
string
&
node_id
,
const
ExecuteTaskRequest
&
request
,
ExecuteTaskResponse
*
response
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
);
~
PIRServerTask
(){}
int
loadParams
(
Params
&
params
)
override
;
int
loadDataset
(
void
)
override
;
int
execute
()
override
;
private:
int
_SetUpDB
(
size_t
dbsize
,
size_t
dimensions
,
size_t
elem_size
,
uint32_t
poly_modulus_degree
,
uint32_t
plain_mod_bit_size
,
uint32_t
bits_per_coeff
,
bool
use_ciphertext_multiplication
);
//int data_col_;
std
::
string
dataset_path_
;
size_t
db_size_
;
shared_ptr
<
pir
::
PIRParameters
>
pir_params_
;
pir
::
EncryptionParameters
encryption_params_
;
std
::
vector
<
std
::
string
>
elements_
;
shared_ptr
<
pir
::
PIRDatabase
>
pir_db_
;
const
PirRequest
*
request_
;
PirResponse
*
response_
;
};
}
// namespace primihub::task
#endif SRC_PRIMIHUB_TASK_SEMANTIC_PIR_SERVER_TASK_H_
src/primihub/task/semantic/pir_task.cc
0 → 100644
浏览文件 @
47e65456
//
#include "src/primihub/task/semantic/pir_task.h"
#include <glog/logging.h>
#include "src/primihub/kernel/pir/operator/base_pir.h"
#include "src/primihub/kernel/pir/operator/factory.h"
namespace
primihub
::
task
{
PirTask
::
PirTask
(
const
TaskParam
*
task_param
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
)
:
TaskBase
(
task_param
,
dataset_service
)
{}
retcode
PirTask
::
BuildOptions
(
const
rpc
::
Task
&
task
,
pir
::
Options
*
options
)
{
// build Options for operator
options
->
self_party
=
this
->
party_name
();
options
->
link_ctx_ref
=
getTaskContext
().
getLinkContext
().
get
();
options
->
code
=
task
.
code
();
auto
&
party_info
=
options
->
party_info
;
const
auto
&
pb_party_info
=
task
.
party_access_info
();
for
(
const
auto
&
[
_party_name
,
pb_node
]
:
pb_party_info
)
{
if
(
_party_name
==
SCHEDULER_NODE
)
{
continue
;
}
Node
node_info
;
pbNode2Node
(
pb_node
,
&
node_info
);
party_info
[
_party_name
]
=
std
::
move
(
node_info
);
}
if
(
RoleValidation
::
IsServer
(
this
->
party_name
()))
{
// paramater for offline generate db info
const
auto
&
param_map
=
task
.
params
().
param_map
();
auto
iter
=
param_map
.
find
(
"DbInfo"
);
if
(
iter
!=
param_map
.
end
())
{
options
->
db_path
=
iter
->
second
.
value_string
();
LOG
(
INFO
)
<<
"db_file_cache path: "
<<
options
->
db_path
;
if
(
this
->
dataset_id_
.
empty
())
{
LOG
(
ERROR
)
<<
"dataset id is empty for party: "
<<
party_name
();
return
retcode
::
FAIL
;
}
ValidateDir
(
options
->
db_path
);
options
->
generate_db
=
true
;
}
else
{
// paramater for online task
if
(
this
->
dataset_id_
.
empty
())
{
LOG
(
ERROR
)
<<
"dataset id is empty for party: "
<<
party_name
();
return
retcode
::
FAIL
;
}
// check db cache exist or not
options
->
db_path
=
db_cache_dir_
+
"/"
+
this
->
dataset_id_
;
if
(
DbCacheAvailable
(
options
->
db_path
))
{
options
->
use_cache
=
true
;
}
}
}
// peer node info
std
::
string
peer_party_name
;
if
(
RoleValidation
::
IsServer
(
this
->
party_name
()))
{
peer_party_name
=
PARTY_CLIENT
;
}
else
if
(
RoleValidation
::
IsClient
(
this
->
party_name
()))
{
peer_party_name
=
PARTY_SERVER
;
}
else
{
LOG
(
ERROR
)
<<
"invalid party: "
<<
this
->
party_name
();
}
auto
it
=
party_info
.
find
(
peer_party_name
);
if
(
it
!=
party_info
.
end
())
{
options
->
peer_node
=
it
->
second
;
}
else
{
LOG
(
WARNING
)
<<
"find peer node info failed for party: "
<<
party_name
();
}
// end of build Options
return
retcode
::
SUCCESS
;
}
retcode
PirTask
::
LoadParams
(
const
rpc
::
Task
&
task
)
{
const
auto
&
param_map
=
task
.
params
().
param_map
();
auto
iter
=
param_map
.
find
(
"pirType"
);
if
(
iter
!=
param_map
.
end
())
{
pir_type_
=
iter
->
second
.
value_int32
();
}
const
auto
&
party_datasets
=
task
.
party_datasets
();
auto
dataset_it
=
party_datasets
.
find
(
party_name
());
if
(
dataset_it
!=
party_datasets
.
end
())
{
const
auto
&
datasets_map
=
dataset_it
->
second
.
data
();
auto
it
=
datasets_map
.
find
(
party_name
());
if
(
it
==
datasets_map
.
end
())
{
LOG
(
WARNING
)
<<
"no datasets is set for party: "
<<
party_name
();
}
else
{
dataset_id_
=
it
->
second
;
VLOG
(
5
)
<<
"data set id: "
<<
dataset_id_
;
}
}
auto
ret
=
BuildOptions
(
task
,
&
this
->
options_
);
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"build operator options for party: "
<<
party_name
()
<<
" failed"
;
return
retcode
::
FAIL
;
}
if
(
RoleValidation
::
IsClient
(
this
->
party_name
()))
{
VLOG
(
7
)
<<
"dataset_id: "
<<
dataset_id_
;
auto
it
=
param_map
.
find
(
"outputFullFilename"
);
if
(
it
!=
param_map
.
end
())
{
result_file_path_
=
it
->
second
.
value_string
();
VLOG
(
5
)
<<
"result_file_path_: "
<<
result_file_path_
;
}
else
{
LOG
(
ERROR
)
<<
"no keyword outputFullFilename match"
;
return
retcode
::
FAIL
;
}
GetServerDataSetSchema
(
task
);
}
return
retcode
::
SUCCESS
;
}
retcode
PirTask
::
GetServerDataSetSchema
(
const
rpc
::
Task
&
task
)
{
// get server dataset id
const
auto
&
party_datasets
=
task
.
party_datasets
();
auto
it
=
party_datasets
.
find
(
PARTY_SERVER
);
if
(
it
==
party_datasets
.
end
())
{
LOG
(
WARNING
)
<<
"no dataset found for party_name: "
<<
PARTY_SERVER
;
return
retcode
::
FAIL
;
}
const
auto
&
datasets_map
=
it
->
second
.
data
();
auto
iter
=
datasets_map
.
find
(
PARTY_SERVER
);
if
(
iter
==
datasets_map
.
end
())
{
LOG
(
WARNING
)
<<
"no dataset found for party_name: "
<<
PARTY_SERVER
;
return
retcode
::
FAIL
;
}
auto
&
server_dataset_id
=
iter
->
second
;
auto
&
dataset_service
=
this
->
getDatasetService
();
auto
driver
=
dataset_service
->
getDriver
(
server_dataset_id
);
if
(
driver
==
nullptr
)
{
LOG
(
WARNING
)
<<
"no dataset access info found for id: "
<<
server_dataset_id
;
return
retcode
::
FAIL
;
}
auto
&
access_info
=
driver
->
dataSetAccessInfo
();
if
(
access_info
==
nullptr
)
{
LOG
(
WARNING
)
<<
"no dataset access info found for id: "
<<
server_dataset_id
;
return
retcode
::
FAIL
;
}
auto
&
schema
=
access_info
->
Schema
();
for
(
const
auto
&
field
:
schema
)
{
server_dataset_schema_
.
push_back
(
std
::
get
<
0
>
(
field
));
}
return
retcode
::
SUCCESS
;
}
retcode
PirTask
::
LoadDataset
()
{
CHECK_TASK_STOPPED
(
retcode
::
FAIL
);
if
(
RoleValidation
::
IsClient
(
this
->
party_name
()))
{
return
ClientLoadDataset
();
}
else
if
(
RoleValidation
::
IsServer
(
this
->
party_name
()))
{
return
ServerLoadDataset
();
}
else
{
LOG
(
WARNING
)
<<
"party: "
<<
this
->
party_name
()
<<
" does not load dataset"
;
return
retcode
::
SUCCESS
;
}
}
retcode
PirTask
::
ClientLoadDataset
()
{
const
auto
&
param_map
=
getTaskParam
()
->
params
().
param_map
();
auto
client_data_it
=
param_map
.
find
(
"clientData"
);
if
(
client_data_it
!=
param_map
.
end
())
{
auto
&
client_data
=
client_data_it
->
second
;
if
(
client_data
.
is_array
())
{
const
auto
&
items
=
client_data
.
value_string_array
().
value_string_array
();
for
(
const
auto
&
item
:
items
)
{
elements_
[
item
];
}
if
(
elements_
.
empty
())
{
LOG
(
ERROR
)
<<
"no query data set by client"
;
return
retcode
::
FAIL
;
}
}
else
{
auto
item
=
client_data
.
value_string
();
elements_
[
item
];
}
return
retcode
::
SUCCESS
;
}
if
(
this
->
dataset_id_
.
empty
())
{
LOG
(
ERROR
)
<<
"no dataset found for client: "
<<
party_name
();
return
retcode
::
FAIL
;
}
VLOG
(
7
)
<<
"dataset_id: "
<<
this
->
dataset_id_
;
auto
data_ptr
=
LoadDataSetInternal
(
this
->
dataset_id_
);
if
(
data_ptr
==
nullptr
)
{
LOG
(
ERROR
)
<<
"read data for dataset id: "
<<
this
->
dataset_id_
<<
" failed"
;
return
retcode
::
FAIL
;
}
auto
&
table
=
std
::
get
<
std
::
shared_ptr
<
arrow
::
Table
>>
(
data_ptr
->
data
);
std
::
vector
<
int
>
key_col
=
{
0
};
auto
key_array
=
GetSelectedContent
(
table
,
key_col
);
for
(
auto
&
item
:
key_array
)
{
elements_
[
item
];
}
return
retcode
::
SUCCESS
;
}
retcode
PirTask
::
ServerLoadDataset
()
{
if
(
this
->
options_
.
use_cache
)
{
VLOG
(
0
)
<<
"using cache data for party: "
<<
party_name
();
return
retcode
::
SUCCESS
;
}
auto
data_ptr
=
LoadDataSetInternal
(
this
->
dataset_id_
);
if
(
data_ptr
==
nullptr
)
{
LOG
(
ERROR
)
<<
"read data for dataset id: "
<<
this
->
dataset_id_
<<
" failed"
;
return
retcode
::
FAIL
;
}
auto
&
table
=
std
::
get
<
std
::
shared_ptr
<
arrow
::
Table
>>
(
data_ptr
->
data
);
int
col_count
=
table
->
num_columns
();
size_t
row_count
=
table
->
num_rows
();
if
(
col_count
<
2
)
{
LOG
(
ERROR
)
<<
"data for server must have lable"
;
return
retcode
::
FAIL
;
}
std
::
vector
<
int
>
key_col
=
{
0
};
auto
key_array
=
GetSelectedContent
(
table
,
key_col
);
// get label
std
::
vector
<
int
>
value_col
;
for
(
int
i
=
1
;
i
<
col_count
;
i
++
)
{
value_col
.
push_back
(
i
);
}
if
(
value_col
.
empty
())
{
LOG
(
ERROR
)
<<
"no selected colum for lable"
;
return
retcode
::
FAIL
;
}
auto
value_array
=
GetSelectedContent
(
table
,
value_col
);
elements_
.
reserve
(
key_array
.
size
());
for
(
size_t
i
=
0
;
i
<
key_array
.
size
();
++
i
)
{
auto
&
key
=
key_array
[
i
];
auto
&
value
=
value_array
[
i
];
auto
it
=
elements_
.
find
(
key
);
if
(
it
!=
elements_
.
end
())
{
it
->
second
.
push_back
(
value
);
}
else
{
std
::
vector
<
std
::
string
>
vec
;
vec
.
push_back
(
value
);
elements_
.
insert
({
key
,
std
::
move
(
vec
)});
}
}
return
retcode
::
SUCCESS
;
}
std
::
shared_ptr
<
Dataset
>
PirTask
::
LoadDataSetInternal
(
const
std
::
string
&
dataset_id
)
{
auto
driver
=
this
->
getDatasetService
()
->
getDriver
(
dataset_id
);
if
(
driver
==
nullptr
)
{
LOG
(
ERROR
)
<<
"get driver for dataset: "
<<
dataset_id
<<
" failed"
;
return
nullptr
;
}
auto
cursor
=
driver
->
GetCursor
();
if
(
cursor
==
nullptr
)
{
LOG
(
ERROR
)
<<
"init cursor failed for dataset id: "
<<
dataset_id
;
return
nullptr
;
}
// maybe pass schema to get expected data type
// copy dataset schema, and change all filed to string
auto
schema
=
driver
->
dataSetAccessInfo
()
->
Schema
();
for
(
auto
&
field
:
schema
)
{
auto
&
type
=
std
::
get
<
1
>
(
field
);
type
=
arrow
::
Type
::
type
::
STRING
;
}
auto
data
=
cursor
->
read
(
schema
);
if
(
data
==
nullptr
)
{
LOG
(
ERROR
)
<<
"read data failed for dataset id: "
<<
dataset_id
;
return
nullptr
;
}
return
data
;
}
retcode
PirTask
::
SaveResult
()
{
if
(
!
NeedSaveResult
())
{
return
retcode
::
SUCCESS
;
}
VLOG
(
0
)
<<
"save query result to : "
<<
result_file_path_
;
std
::
vector
<
std
::
shared_ptr
<
arrow
::
Field
>>
schema_vector
;
std
::
vector
<
std
::
string
>
tmp_colums
{
"key"
,
"value"
};
for
(
const
auto
&
col_name
:
tmp_colums
)
{
schema_vector
.
push_back
(
arrow
::
field
(
col_name
,
arrow
::
int64
()));
}
std
::
vector
<
std
::
shared_ptr
<
arrow
::
Array
>>
arrow_array
;
arrow
::
StringBuilder
key_builder
;
arrow
::
StringBuilder
value_builder
;
for
(
auto
&
[
key
,
item_vec
]
:
this
->
result_
)
{
for
(
const
auto
&
item
:
item_vec
)
{
key_builder
.
Append
(
key
);
value_builder
.
Append
(
item
);
}
}
std
::
shared_ptr
<
arrow
::
Array
>
key_array
;
key_builder
.
Finish
(
&
key_array
);
arrow_array
.
push_back
(
std
::
move
(
key_array
));
std
::
shared_ptr
<
arrow
::
Array
>
value_array
;
value_builder
.
Finish
(
&
value_array
);
arrow_array
.
push_back
(
std
::
move
(
value_array
));
auto
schema
=
std
::
make_shared
<
arrow
::
Schema
>
(
schema_vector
);
// std::shared_ptr<arrow::Table>
auto
table
=
arrow
::
Table
::
Make
(
schema
,
arrow_array
);
auto
driver
=
DataDirverFactory
::
getDriver
(
"CSV"
,
"test address"
);
auto
csv_driver
=
std
::
dynamic_pointer_cast
<
CSVDriver
>
(
driver
);
auto
rtcode
=
csv_driver
->
Write
(
server_dataset_schema_
,
table
,
result_file_path_
);
if
(
rtcode
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"save PIR data to file "
<<
result_file_path_
<<
" failed."
;
return
retcode
::
FAIL
;
}
return
retcode
::
SUCCESS
;
}
retcode
PirTask
::
InitOperator
()
{
auto
type
=
static_cast
<
primihub
::
pir
::
PirType
>
(
pir_type_
);
this
->
operator_
=
primihub
::
pir
::
Factory
::
Create
(
type
,
options_
);
if
(
this
->
operator_
==
nullptr
)
{
LOG
(
ERROR
)
<<
"create pir operator failed"
;
return
retcode
::
FAIL
;
}
return
retcode
::
SUCCESS
;
}
retcode
PirTask
::
ExecuteOperator
()
{
return
operator_
->
Execute
(
elements_
,
&
result_
);
}
int
PirTask
::
execute
()
{
SCopedTimer
timer
;
auto
ret
=
LoadParams
(
task_param_
);
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"Pir load task params failed."
;
return
-
1
;
}
auto
load_params_ts
=
timer
.
timeElapse
();
VLOG
(
5
)
<<
"LoadParams time cost(ms): "
<<
load_params_ts
;
ret
=
LoadDataset
();
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"Pir load dataset failed."
;
return
-
1
;
}
auto
load_dataset_ts
=
timer
.
timeElapse
();
auto
load_dataset_time_cost
=
load_dataset_ts
-
load_params_ts
;
VLOG
(
5
)
<<
"LoadDataset time cost(ms): "
<<
load_dataset_time_cost
;
ret
=
InitOperator
();
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"Pir init operator failed."
;
return
-
1
;
}
auto
init_op_ts
=
timer
.
timeElapse
();
auto
init_op_time_cost
=
init_op_ts
-
load_dataset_ts
;
VLOG
(
5
)
<<
"InitOperator time cost(ms): "
<<
init_op_time_cost
;
ret
=
ExecuteOperator
();
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"Pir execute operator failed."
;
return
-
1
;
}
auto
exec_op_ts
=
timer
.
timeElapse
();
auto
exec_op_time_cost
=
exec_op_ts
-
init_op_ts
;
VLOG
(
5
)
<<
"ExecuteOperator time cost(ms): "
<<
exec_op_time_cost
;
ret
=
SaveResult
();
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"Pir save result failed."
;
return
-
1
;
}
auto
save_res_ts
=
timer
.
timeElapse
();
auto
save_res_time_cost
=
save_res_ts
-
exec_op_ts
;
VLOG
(
5
)
<<
"SaveResult time cost(ms): "
<<
save_res_time_cost
;
return
0
;
}
std
::
vector
<
std
::
string
>
PirTask
::
GetSelectedContent
(
std
::
shared_ptr
<
arrow
::
Table
>&
data_tbl
,
const
std
::
vector
<
int
>&
selected_col
)
{
// return std::vector<std::string>();
int
col_count
=
data_tbl
->
num_columns
();
size_t
row_count
=
data_tbl
->
num_rows
();
if
(
selected_col
.
empty
())
{
LOG
(
ERROR
)
<<
"no col selected for data"
;
return
std
::
vector
<
std
::
string
>
();
}
std
::
vector
<
std
::
string
>
content_array
;
auto
lable_ptr
=
data_tbl
->
column
(
selected_col
[
0
]);
auto
chunk_size
=
lable_ptr
->
num_chunks
();
size_t
total_row_count
=
col_count
*
chunk_size
;
content_array
.
reserve
(
total_row_count
);
for
(
int
i
=
0
;
i
<
chunk_size
;
++
i
)
{
auto
array
=
std
::
static_pointer_cast
<
arrow
::
StringArray
>
(
lable_ptr
->
chunk
(
i
));
for
(
int64_t
j
=
0
;
j
<
array
->
length
();
j
++
)
{
content_array
.
push_back
(
array
->
GetString
(
j
));
}
}
// process left colums
for
(
size_t
i
=
1
;
i
<
selected_col
.
size
();
++
i
)
{
size_t
index
{
0
};
int
col_index
=
selected_col
[
i
];
auto
lable_ptr
=
data_tbl
->
column
(
col_index
);
int
chunk_size
=
lable_ptr
->
num_chunks
();
for
(
int
j
=
0
;
j
<
chunk_size
;
++
j
)
{
auto
array
=
std
::
static_pointer_cast
<
arrow
::
StringArray
>
(
lable_ptr
->
chunk
(
j
));
for
(
int64_t
k
=
0
;
k
<
array
->
length
();
++
k
)
{
content_array
[
index
++
].
append
(
","
).
append
(
array
->
GetString
(
k
));
}
}
}
return
content_array
;
}
bool
PirTask
::
NeedSaveResult
()
{
if
(
RoleValidation
::
IsClient
(
this
->
party_name
()))
{
return
true
;
}
return
false
;
}
}
// namespace primihub::task
src/primihub/task/semantic/pir_task.h
0 → 100644
浏览文件 @
47e65456
// Copyright 2023 <PrimiHub>
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_PIR_TASK_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_PIR_TASK_H_
#include <string>
#include "src/primihub/task/semantic/task.h"
#include "src/primihub/common/common.h"
#include "src/primihub/service/dataset/service.h"
#include "src/primihub/kernel/pir/common.h"
#include "src/primihub/kernel/pir/operator/base_pir.h"
#include "src/primihub/util/util.h"
#include "src/primihub/util/file_util.h"
namespace
primihub
::
task
{
using
BasePirOperator
=
primihub
::
pir
::
BasePirOperator
;
class
PirTask
:
public
TaskBase
{
public:
PirTask
(
const
TaskParam
*
task_param
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
);
~
PirTask
()
=
default
;
int
execute
()
override
;
protected:
retcode
LoadParams
(
const
rpc
::
Task
&
task
);
retcode
GetServerDataSetSchema
(
const
rpc
::
Task
&
task
);
retcode
LoadDataset
();
retcode
ClientLoadDataset
();
retcode
ServerLoadDataset
();
std
::
shared_ptr
<
Dataset
>
LoadDataSetInternal
(
const
std
::
string
&
dataset_id
);
bool
DbCacheAvailable
(
const
std
::
string
&
db_file_cache
)
{
return
FileExists
(
db_file_cache
);
}
std
::
vector
<
std
::
string
>
GetSelectedContent
(
std
::
shared_ptr
<
arrow
::
Table
>&
data_tbl
,
const
std
::
vector
<
int
>&
selected_col
);
retcode
SaveResult
();
retcode
InitOperator
();
retcode
ExecuteOperator
();
retcode
BuildOptions
(
const
rpc
::
Task
&
task
,
primihub
::
pir
::
Options
*
option
);
bool
NeedSaveResult
();
private:
int
pir_type_
{
rpc
::
PirType
::
KEY_PIR
};
std
::
string
dataset_path_
;
std
::
string
dataset_id_
;
std
::
string
result_file_path_
;
primihub
::
pir
::
PirDataType
elements_
;
primihub
::
pir
::
PirDataType
result_
;
primihub
::
pir
::
Options
options_
;
std
::
string
db_cache_dir_
{
"data/cache"
};
std
::
unique_ptr
<
BasePirOperator
>
operator_
{
nullptr
};
std
::
vector
<
std
::
string
>
server_dataset_schema_
;
// std::string dataset_path_;
// std::string dataset_id_;
// std::string db_file_cache_;
// primihub::Node client_node_;
// std::string key{"key_pir"};
// std::string psi_params_str_;
// std::unique_ptr<apsi::oprf::OPRFKey> oprf_key_{nullptr};
// bool generate_db_offline_{false};
};
}
// namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_PIR_TASK_H_
src/primihub/task/semantic/private_server_base.cc
已删除
100644 → 0
浏览文件 @
405e10f6
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "src/primihub/task/semantic/private_server_base.h"
#include "src/primihub/data_store/factory.h"
#include <fstream>
using
arrow
::
Array
;
using
arrow
::
StringArray
;
using
arrow
::
Table
;
namespace
primihub
::
task
{
ServerTaskBase
::
ServerTaskBase
(
const
Params
*
params
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
)
:
dataset_service_
(
dataset_service
)
{
setTaskParam
(
params
);
}
Params
*
ServerTaskBase
::
getTaskParam
()
{
return
&
params_
;
}
void
ServerTaskBase
::
setTaskParam
(
const
Params
*
params
)
{
params_
.
CopyFrom
(
*
params
);
}
int
ServerTaskBase
::
loadDatasetFromSQLite
(
const
std
::
string
&
conn_str
,
int
data_col
,
std
::
vector
<
std
::
string
>&
col_array
,
int64_t
max_num
)
{
std
::
string
nodeaddr
(
"localhost"
);
// TODO
// std::shared_ptr<DataDriver>
auto
driver
=
DataDirverFactory
::
getDriver
(
"SQLITE"
,
nodeaddr
);
if
(
driver
==
nullptr
)
{
LOG
(
ERROR
)
<<
"create sqlite db driver failed"
;
return
-
1
;
}
// std::shared_ptr<Cursor> &cursor
auto
cursor
=
driver
->
read
(
conn_str
);
// std::shared_ptr<Dataset>
auto
ds
=
cursor
->
read
();
if
(
ds
==
nullptr
)
{
return
-
1
;
}
auto
table
=
std
::
get
<
std
::
shared_ptr
<
Table
>>
(
ds
->
data
);
int
num_col
=
table
->
num_columns
();
if
(
num_col
<
data_col
)
{
LOG
(
ERROR
)
<<
"psi dataset colunum number is smaller than data_col"
;
return
-
1
;
}
auto
array
=
std
::
static_pointer_cast
<
StringArray
>
(
table
->
column
(
data_col
)
->
chunk
(
0
));
for
(
int64_t
i
=
0
;
i
<
array
->
length
();
i
++
)
{
if
(
max_num
>
0
&&
max_num
==
i
)
{
break
;
}
col_array
.
push_back
(
array
->
GetString
(
i
));
}
VLOG
(
5
)
<<
"psi server loaded data records: "
<<
col_array
.
size
();
return
array
->
length
();
}
int
ServerTaskBase
::
loadDatasetFromCSV
(
const
std
::
string
&
filename
,
int
data_col
,
std
::
vector
<
std
::
string
>
&
col_array
,
int64_t
max_num
)
{
std
::
string
nodeaddr
(
"test address"
);
// TODO
std
::
shared_ptr
<
DataDriver
>
driver
=
DataDirverFactory
::
getDriver
(
"CSV"
,
nodeaddr
);
auto
cursor
=
driver
->
read
(
filename
);
auto
ds
=
cursor
->
read
();
std
::
shared_ptr
<
Table
>
table
=
std
::
get
<
std
::
shared_ptr
<
Table
>>
(
ds
->
data
);
int
num_col
=
table
->
num_columns
();
if
(
num_col
<
data_col
)
{
LOG
(
ERROR
)
<<
"psi dataset colunum number is smaller than data_col"
;
return
-
1
;
}
int64_t
num_rows
=
table
->
num_rows
();
int64_t
num_records
=
max_num
>
0
?
max_num
:
num_rows
;
col_array
.
reserve
(
num_records
);
auto
col_ptr
=
table
->
column
(
data_col
);
int
chunk_size
=
col_ptr
->
num_chunks
();
for
(
int
i
=
0
;
i
<
chunk_size
;
i
++
)
{
auto
array
=
std
::
static_pointer_cast
<
StringArray
>
(
col_ptr
->
chunk
(
i
));
for
(
size_t
j
=
0
;
j
<
array
->
length
();
j
++
)
{
col_array
.
push_back
(
array
->
GetString
(
j
));
if
(
max_num
>
0
&&
max_num
==
col_array
.
size
())
{
return
col_array
.
size
();
}
}
}
return
col_array
.
size
();
}
int
ServerTaskBase
::
loadDatasetFromTXT
(
std
::
string
&
filename
,
std
::
vector
<
std
::
string
>
&
col_array
)
{
LOG
(
INFO
)
<<
"loading file ..."
;
std
::
ifstream
infile
;
infile
.
open
(
filename
);
col_array
.
clear
();
std
::
string
tmp
;
std
::
getline
(
infile
,
tmp
);
// ignore the first line
while
(
std
::
getline
(
infile
,
tmp
))
{
col_array
.
push_back
(
tmp
);
}
infile
.
close
();
return
col_array
.
size
();
}
}
// namespace primihub::task
src/primihub/task/semantic/private_server_base.h
已删除
100644 → 0
浏览文件 @
405e10f6
/*
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef SRC_PRIMIHUB_TASK_SEMANTIC_PRIVATE_SERVER_BASE_H_
#define SRC_PRIMIHUB_TASK_SEMANTIC_PRIVATE_SERVER_BASE_H_
#include <map>
#include <memory>
#include <string>
#include <atomic>
#include <glog/logging.h>
#include "src/primihub/protos/common.pb.h"
#include "src/primihub/service/dataset/service.h"
#include "src/primihub/task/semantic/task.h"
using
primihub
::
rpc
::
Params
;
using
primihub
::
service
::
DatasetService
;
namespace
primihub
::
task
{
class
ServerTaskBase
{
public:
// using task_context_t = TaskContext<primihub::rpc::ExecuteTaskRequest, primihub::rpc::ExecuteTaskResponse>;
//
using
task_context_t
=
TaskContext
;
ServerTaskBase
(
const
Params
*
params
,
std
::
shared_ptr
<
DatasetService
>
dataset_service
);
~
ServerTaskBase
(){}
virtual
int
execute
()
=
0
;
virtual
int
loadParams
(
Params
&
params
)
=
0
;
virtual
int
loadDataset
(
void
)
=
0
;
virtual
void
kill_task
()
{
LOG
(
WARNING
)
<<
"task receives kill task request and stop stauts"
;
stop_
.
store
(
true
);
task_context_
.
clean
();
}
bool
has_stopped
()
{
return
stop_
.
load
(
std
::
memory_order_relaxed
);
}
std
::
shared_ptr
<
DatasetService
>&
getDatasetService
()
{
return
dataset_service_
;
}
void
setTaskParam
(
const
Params
*
params
);
Params
*
getTaskParam
();
inline
task_context_t
&
getTaskContext
()
{
return
task_context_
;
}
inline
task_context_t
*
getMutableTaskContext
()
{
return
&
task_context_
;
}
protected:
int
loadDatasetFromCSV
(
const
std
::
string
&
filename
,
int
data_col
,
std
::
vector
<
std
::
string
>
&
col_array
,
int64_t
max_num
=
0
);
int
loadDatasetFromSQLite
(
const
std
::
string
&
conn_str
,
int
data_col
,
std
::
vector
<
std
::
string
>&
col_array
,
int64_t
max_num
=
0
);
int
loadDatasetFromTXT
(
std
::
string
&
filename
,
std
::
vector
<
std
::
string
>
&
col_array
);
std
::
atomic
<
bool
>
stop_
{
false
};
Params
params_
;
std
::
shared_ptr
<
DatasetService
>
dataset_service_
;
task_context_t
task_context_
;
};
}
// namespace primihub::task
#endif // SRC_PRIMIHUB_TASK_SEMANTIC_PRIVATE_SERVER_BASE_H_
src/primihub/task/semantic/task.h
浏览文件 @
47e65456
/*
Copyright 2022 Primi
h
ub
Copyright 2022 Primi
H
ub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
src/primihub/util/network/link_context.cc
浏览文件 @
47e65456
//
#include "src/primihub/util/network/link_context.h"
namespace
primihub
::
network
{
void
LinkContext
::
Clean
()
{
stop_
.
store
(
true
);
LOG
(
WARNING
)
<<
"stop all in data queue"
;
{
std
::
lock_guard
<
std
::
mutex
>
lck
(
in_queue_mtx
);
for
(
auto
it
=
in_data_queue
.
begin
();
it
!=
in_data_queue
.
end
();
++
it
)
{
it
->
second
.
shutdown
();
}
}
LOG
(
WARNING
)
<<
"stop all out data queue"
;
{
std
::
lock_guard
<
std
::
mutex
>
lck
(
out_queue_mtx
);
for
(
auto
it
=
out_data_queue
.
begin
();
it
!=
out_data_queue
.
end
();
++
it
)
{
it
->
second
.
shutdown
();
}
}
LOG
(
WARNING
)
<<
"stop all complete queue"
;
{
std
::
lock_guard
<
std
::
mutex
>
lck
(
complete_queue_mtx
);
for
(
auto
it
=
complete_queue
.
begin
();
it
!=
complete_queue
.
end
();
++
it
)
{
it
->
second
.
shutdown
();
}
}
}
LinkContext
::
StringDataQueue
&
LinkContext
::
GetRecvQueue
(
const
std
::
string
&
key
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
this
->
in_queue_mtx
);
auto
it
=
in_data_queue
.
find
(
key
);
if
(
it
!=
in_data_queue
.
end
())
{
return
it
->
second
;
}
else
{
in_data_queue
[
key
];
if
(
stop_
.
load
(
std
::
memory_order
::
memory_order_relaxed
))
{
in_data_queue
[
key
].
shutdown
();
}
return
in_data_queue
[
key
];
}
}
LinkContext
::
StringDataQueue
&
LinkContext
::
GetSendQueue
(
const
std
::
string
&
key
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
this
->
out_queue_mtx
);
auto
it
=
out_data_queue
.
find
(
key
);
if
(
it
!=
out_data_queue
.
end
())
{
return
it
->
second
;
}
else
{
return
out_data_queue
[
key
];
}
}
LinkContext
::
StatusDataQueue
&
LinkContext
::
GetCompleteQueue
(
const
std
::
string
&
key
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
this
->
complete_queue_mtx
);
auto
it
=
complete_queue
.
find
(
key
);
if
(
it
!=
complete_queue
.
end
())
{
return
it
->
second
;
}
else
{
return
complete_queue
[
key
];
}
}
retcode
LinkContext
::
Send
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
const
std
::
string
&
send_buf
)
{
std
::
string_view
send_data_sv
{
send_buf
.
data
(),
send_buf
.
size
()};
return
Send
(
key
,
dest_node
,
send_data_sv
);
}
retcode
LinkContext
::
Send
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
std
::
string_view
send_buf_sv
)
{
auto
ch
=
getChannel
(
dest_node
);
return
ch
->
send
(
key
,
send_buf_sv
);
}
retcode
LinkContext
::
Send
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
char
*
send_buf
,
size_t
send_size
)
{
std
::
string_view
send_data_sv
{
send_buf
,
send_size
};
return
Send
(
key
,
dest_node
,
send_data_sv
);
}
retcode
LinkContext
::
Recv
(
const
std
::
string
&
key
,
std
::
string
*
recv_buf
)
{
std
::
string
recv_buf_tmp
;
auto
&
recv_queue
=
GetRecvQueue
(
key
);
recv_queue
.
wait_and_pop
(
recv_buf_tmp
);
*
recv_buf
=
std
::
move
(
recv_buf_tmp
);
return
retcode
::
SUCCESS
;
}
retcode
LinkContext
::
Recv
(
const
std
::
string
&
key
,
char
*
recv_buf
,
size_t
recv_size
)
{
std
::
string
recv_buf_tmp
;
auto
&
recv_queue
=
GetRecvQueue
(
key
);
recv_queue
.
wait_and_pop
(
recv_buf_tmp
);
if
(
recv_size
!=
recv_buf_tmp
.
size
())
{
LOG
(
ERROR
)
<<
"recv data does not match, expected: "
<<
recv_size
<<
" but get: "
<<
recv_buf_tmp
.
size
();
return
retcode
::
FAIL
;
}
memcpy
(
recv_buf
,
recv_buf_tmp
.
data
(),
recv_size
);
return
retcode
::
SUCCESS
;
}
retcode
LinkContext
::
SendRecv
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
std
::
string_view
send_buf
,
std
::
string
*
recv_buf
)
{
auto
channel
=
getChannel
(
dest_node
);
auto
ret
=
channel
->
sendRecv
(
key
,
send_buf
,
recv_buf
);
if
(
ret
!=
retcode
::
SUCCESS
)
{
LOG
(
ERROR
)
<<
"send data to peer: ["
<<
dest_node
.
to_string
()
<<
"] failed"
;
return
ret
;
}
return
retcode
::
SUCCESS
;
}
retcode
LinkContext
::
SendRecv
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
const
std
::
string
&
send_buf
,
std
::
string
*
recv_buf
)
{
auto
send_buf_sv
=
std
::
string_view
(
send_buf
.
data
(),
send_buf
.
size
());
return
SendRecv
(
key
,
dest_node
,
send_buf_sv
,
recv_buf
);
}
retcode
LinkContext
::
SendRecv
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
const
char
*
send_buf
,
size_t
length
,
std
::
string
*
recv_buf
)
{
auto
send_buf_sv
=
std
::
string_view
(
send_buf
,
length
);
return
SendRecv
(
key
,
dest_node
,
send_buf_sv
,
recv_buf
);
}
retcode
LinkContext
::
SendRecv
(
const
std
::
string
&
key
,
const
std
::
string
&
send_buf
,
std
::
string
*
recv_buf
)
{
std
::
string
recv_buf_tmp
;
auto
&
recv_queue
=
this
->
GetRecvQueue
(
key
);
recv_queue
.
wait_and_pop
(
recv_buf_tmp
);
*
recv_buf
=
std
::
move
(
recv_buf_tmp
);
if
(
HasStopped
())
{
LOG
(
ERROR
)
<<
"link context has been closed"
;
return
retcode
::
FAIL
;
}
auto
&
send_queue
=
this
->
GetSendQueue
(
key
);
send_queue
.
push
(
send_buf
);
auto
&
complete_queue
=
this
->
GetCompleteQueue
(
key
);
retcode
complete_flag
;
complete_queue
.
wait_and_pop
(
complete_flag
);
return
retcode
::
SUCCESS
;
}
}
// namespace primihub::network
src/primihub/util/network/link_context.h
浏览文件 @
47e65456
...
...
@@ -20,6 +20,10 @@ class IChannel;
*/
class
LinkContext
{
public:
using
StringDataQueue
=
primihub
::
ThreadSafeQueue
<
std
::
string
>
;
using
StringDataContainer
=
std
::
unordered_map
<
std
::
string
,
StringDataQueue
>
;
using
StatusDataQueue
=
primihub
::
ThreadSafeQueue
<
retcode
>
;
using
StatusDataContainer
=
std
::
unordered_map
<
std
::
string
,
StatusDataQueue
>
;
LinkContext
()
=
default
;
virtual
~
LinkContext
()
=
default
;
inline
void
setTaskInfo
(
const
std
::
string
&
job_id
,
...
...
@@ -43,11 +47,17 @@ class LinkContext {
* if channel is not exist, create
*/
virtual
std
::
shared_ptr
<
IChannel
>
getChannel
(
const
primihub
::
Node
&
node
)
=
0
;
void
setRecvTimeout
(
int32_t
recv_timeout_ms
)
{
recv_timeout_ms_
=
recv_timeout_ms
;}
void
setSendTimeout
(
int32_t
send_timeout_ms
)
{
send_timeout_ms_
=
send_timeout_ms
;}
inline
void
setRecvTimeout
(
const
int32_t
recv_timeout_ms
)
{
recv_timeout_ms_
=
recv_timeout_ms
;
}
inline
void
setSendTimeout
(
const
int32_t
send_timeout_ms
)
{
send_timeout_ms_
=
send_timeout_ms
;
}
int32_t
sendTimeout
()
const
{
return
send_timeout_ms_
;}
int32_t
recvTimeout
()
const
{
return
recv_timeout_ms_
;}
primihub
::
common
::
CertificateConfig
&
getCertificateConfig
()
{
return
*
cert_config_
;}
primihub
::
common
::
CertificateConfig
&
getCertificateConfig
()
{
return
*
cert_config_
;
}
void
initCertificate
(
const
std
::
string
&
root_ca_path
,
const
std
::
string
&
key_path
,
...
...
@@ -62,69 +72,46 @@ class LinkContext {
return
retcode
::
SUCCESS
;
}
primihub
::
ThreadSafeQueue
<
std
::
string
>&
GetRecvQueue
(
const
std
::
string
&
key
=
"default"
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
this
->
in_queue_mtx
);
auto
it
=
in_data_queue
.
find
(
key
);
if
(
it
!=
in_data_queue
.
end
())
{
return
it
->
second
;
}
else
{
in_data_queue
[
key
];
if
(
stop_
.
load
(
std
::
memory_order
::
memory_order_relaxed
))
{
in_data_queue
[
key
].
shutdown
();
}
return
in_data_queue
[
key
];
}
}
primihub
::
ThreadSafeQueue
<
std
::
string
>&
GetSendQueue
(
const
std
::
string
&
key
=
"default"
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
this
->
out_queue_mtx
);
auto
it
=
out_data_queue
.
find
(
key
);
if
(
it
!=
out_data_queue
.
end
())
{
return
it
->
second
;
}
else
{
return
out_data_queue
[
key
];
}
}
StringDataQueue
&
GetRecvQueue
(
const
std
::
string
&
key
=
"default"
);
StringDataQueue
&
GetSendQueue
(
const
std
::
string
&
key
=
"default"
);
StatusDataQueue
&
GetCompleteQueue
(
const
std
::
string
&
role
=
"default"
);
primihub
::
ThreadSafeQueue
<
retcode
>&
GetCompleteQueue
(
const
std
::
string
&
role
=
"default"
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
this
->
complete_queue_mtx
);
auto
it
=
complete_queue
.
find
(
role
);
if
(
it
!=
complete_queue
.
end
())
{
return
it
->
second
;
}
else
{
return
complete_queue
[
role
];
}
}
void
Clean
();
retcode
Send
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
const
std
::
string
&
send_buf
);
retcode
Send
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
std
::
string_view
send_buf
);
retcode
Send
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
char
*
send_buf
,
size_t
send_size
);
retcode
Recv
(
const
std
::
string
&
key
,
std
::
string
*
recv_buf
);
retcode
Recv
(
const
std
::
string
&
key
,
char
*
recv_buf
,
size_t
recv_size
);
/**
* sender to process send recv
*/
retcode
SendRecv
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
const
std
::
string
&
send_buf
,
std
::
string
*
recv_buf
);
retcode
SendRecv
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
std
::
string_view
send_buf
,
std
::
string
*
recv_buf
);
retcode
SendRecv
(
const
std
::
string
&
key
,
const
Node
&
dest_node
,
const
char
*
send_buf
,
size_t
length
,
std
::
string
*
recv_buf
);
/**
* receiver to process send recv
*/
retcode
SendRecv
(
const
std
::
string
&
key
,
const
std
::
string
&
send_buf
,
std
::
string
*
recv_buf
);
void
Clean
()
{
stop_
.
store
(
true
);
LOG
(
ERROR
)
<<
"stop all in data queue"
;
{
std
::
lock_guard
<
std
::
mutex
>
lck
(
in_queue_mtx
);
for
(
auto
it
=
in_data_queue
.
begin
();
it
!=
in_data_queue
.
end
();
++
it
)
{
it
->
second
.
shutdown
();
}
}
LOG
(
ERROR
)
<<
"stop all out data queue"
;
{
std
::
lock_guard
<
std
::
mutex
>
lck
(
out_queue_mtx
);
for
(
auto
it
=
out_data_queue
.
begin
();
it
!=
out_data_queue
.
end
();
++
it
)
{
it
->
second
.
shutdown
();
}
}
LOG
(
ERROR
)
<<
"stop all complete queue"
;
{
std
::
lock_guard
<
std
::
mutex
>
lck
(
complete_queue_mtx
);
for
(
auto
it
=
complete_queue
.
begin
();
it
!=
complete_queue
.
end
();
++
it
)
{
it
->
second
.
shutdown
();
}
}
}
protected:
bool
HasStopped
()
{
return
stop_
.
load
(
std
::
memory_order
::
memory_order_relaxed
);
}
int32_t
recv_timeout_ms_
{
-
1
};
int32_t
send_timeout_ms_
{
-
1
};
std
::
shared_mutex
connection_mgr_mtx
;
...
...
@@ -135,11 +122,13 @@ class LinkContext {
std
::
unique_ptr
<
primihub
::
common
::
CertificateConfig
>
cert_config_
{
nullptr
};
std
::
mutex
in_queue_mtx
;
std
::
unordered_map
<
std
::
string
,
primihub
::
ThreadSafeQueue
<
std
::
string
>>
in_data_queue
;
StringDataContainer
in_data_queue
;
std
::
mutex
out_queue_mtx
;
std
::
unordered_map
<
std
::
string
,
primihub
::
ThreadSafeQueue
<
std
::
string
>>
out_data_queue
;
StringDataContainer
out_data_queue
;
std
::
mutex
complete_queue_mtx
;
std
::
unordered_map
<
std
::
string
,
primihub
::
ThreadSafeQueue
<
retcode
>>
complete_queue
;
StatusDataContainer
complete_queue
;
std
::
atomic
<
bool
>
stop_
{
false
};
};
...
...
@@ -150,19 +139,26 @@ class IChannel {
virtual
~
IChannel
()
=
default
;
virtual
retcode
send
(
const
std
::
string
&
key
,
const
std
::
string
&
data
)
=
0
;
virtual
retcode
send
(
const
std
::
string
&
key
,
std
::
string_view
sv_data
)
=
0
;
virtual
bool
send_wrapper
(
const
std
::
string
&
key
,
const
std
::
string
&
data
)
=
0
;
virtual
bool
send_wrapper
(
const
std
::
string
&
key
,
std
::
string_view
sv_data
)
=
0
;
virtual
bool
send_wrapper
(
const
std
::
string
&
key
,
const
std
::
string
&
data
)
=
0
;
virtual
bool
send_wrapper
(
const
std
::
string
&
key
,
std
::
string_view
sv_data
)
=
0
;
virtual
retcode
sendRecv
(
const
std
::
string
&
key
,
const
std
::
string
&
send_data
,
std
::
string
*
recv_data
)
=
0
;
virtual
retcode
sendRecv
(
const
std
::
string
&
key
,
std
::
string_view
send_data
,
std
::
string
*
recv_data
)
=
0
;
virtual
retcode
submitTask
(
const
rpc
::
PushTaskRequest
&
request
,
rpc
::
PushTaskReply
*
reply
)
=
0
;
virtual
retcode
executeTask
(
const
rpc
::
PushTaskRequest
&
request
,
rpc
::
PushTaskReply
*
reply
)
=
0
;
virtual
retcode
killTask
(
const
rpc
::
KillTaskRequest
&
request
,
rpc
::
KillTaskResponse
*
reply
)
=
0
;
virtual
retcode
updateTaskStatus
(
const
rpc
::
TaskStatus
&
request
,
rpc
::
Empty
*
reply
)
=
0
;
virtual
retcode
fetchTaskStatus
(
const
rpc
::
TaskContext
&
request
,
rpc
::
TaskStatusReply
*
reply
)
=
0
;
virtual
retcode
submitTask
(
const
rpc
::
PushTaskRequest
&
request
,
rpc
::
PushTaskReply
*
reply
)
=
0
;
virtual
retcode
executeTask
(
const
rpc
::
PushTaskRequest
&
request
,
rpc
::
PushTaskReply
*
reply
)
=
0
;
virtual
retcode
killTask
(
const
rpc
::
KillTaskRequest
&
request
,
rpc
::
KillTaskResponse
*
reply
)
=
0
;
virtual
retcode
updateTaskStatus
(
const
rpc
::
TaskStatus
&
request
,
rpc
::
Empty
*
reply
)
=
0
;
virtual
retcode
fetchTaskStatus
(
const
rpc
::
TaskContext
&
request
,
rpc
::
TaskStatusReply
*
reply
)
=
0
;
virtual
std
::
string
forwardRecv
(
const
std
::
string
&
key
)
=
0
;
LinkContext
*
getLinkContext
()
{
return
link_ctx_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录