Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Primihub
PrimiHub
提交
a470b3f3
P
PrimiHub
项目概览
Primihub
/
PrimiHub
9 个月 前同步成功
通知
21
Star
1
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PrimiHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
a470b3f3
编写于
7月 28, 2023
作者:
X
Xuefeng Xu
提交者:
GitHub
7月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support multiclass VFL CKKS logistic regression (#578)
* support multiclass VFL CKKS logistic regression
上级
aa373211
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
225 addition
and
93 deletion
+225
-93
example/FL/logistic_regression/vfl_multiclass_ckks.json
example/FL/logistic_regression/vfl_multiclass_ckks.json
+45
-0
python/primihub/FL/logistic_regression/base.py
python/primihub/FL/logistic_regression/base.py
+4
-5
python/primihub/FL/logistic_regression/vfl_base.py
python/primihub/FL/logistic_regression/vfl_base.py
+50
-19
python/primihub/FL/logistic_regression/vfl_coordinator.py
python/primihub/FL/logistic_regression/vfl_coordinator.py
+66
-34
python/primihub/FL/logistic_regression/vfl_guest.py
python/primihub/FL/logistic_regression/vfl_guest.py
+24
-12
python/primihub/FL/logistic_regression/vfl_host.py
python/primihub/FL/logistic_regression/vfl_host.py
+36
-23
未找到文件。
example/FL/logistic_regression/vfl_multiclass_ckks.json
0 → 100644
浏览文件 @
a470b3f3
{
"party_info"
:
{
"task_manager"
:
"127.0.0.1:50050"
},
"component_params"
:
{
"roles"
:
{
"host"
:
"Bob"
,
"guest"
:
[
"Charlie"
],
"coordinator"
:
"David"
},
"common_params"
:
{
"model"
:
"VFL_logistic_regression"
,
"method"
:
"CKKS"
,
"process"
:
"train"
,
"task_name"
:
"VFL_logistic_regression_multiclass_ckks_train"
,
"learning_rate"
:
1e-1
,
"alpha"
:
1e-4
,
"epoch"
:
2
,
"shuffle_seed"
:
0
,
"batch_size"
:
100
,
"print_metrics"
:
true
},
"role_params"
:
{
"Bob"
:
{
"data_set"
:
"multiclass_vfl_train_host"
,
"selected_column"
:
null
,
"id"
:
"id"
,
"label"
:
"y"
,
"model_path"
:
"data/result/host_model.pkl"
,
"metric_path"
:
"data/result/metrics.json"
},
"Charlie"
:
{
"data_set"
:
"multiclass_vfl_train_guest"
,
"selected_column"
:
null
,
"id"
:
"id"
,
"model_path"
:
"data/result/guest_model.pkl"
},
"David"
:
{
"data_set"
:
"fl_fake_data"
}
}
}
}
\ No newline at end of file
python/primihub/FL/logistic_regression/base.py
浏览文件 @
a470b3f3
...
...
@@ -61,12 +61,10 @@ class LogisticRegression:
error
=
self
.
predict_prob
(
x
)
idx
=
np
.
arange
(
len
(
y
))
error
[
idx
,
y
]
-=
1
dw
=
x
.
T
.
dot
(
error
)
/
x
.
shape
[
0
]
+
self
.
alpha
*
self
.
weight
db
=
error
.
mean
(
axis
=
0
,
keepdims
=
True
)
else
:
error
=
self
.
predict_prob
(
x
)
-
y
dw
=
x
.
T
.
dot
(
error
)
/
x
.
shape
[
0
]
+
self
.
alpha
*
self
.
weight
db
=
error
.
mean
(
keepdims
=
True
)
dw
=
x
.
T
.
dot
(
error
)
/
x
.
shape
[
0
]
+
self
.
alpha
*
self
.
weight
db
=
error
.
mean
(
axis
=
0
,
keepdims
=
True
)
return
dw
,
db
def
gradient_descent
(
self
,
x
,
y
):
...
...
@@ -192,7 +190,8 @@ class LogisticRegression_Paillier(LogisticRegression):
error
=
2
+
x
.
dot
(
self
.
weight
)
+
self
.
bias
-
4
*
y
factor
=
-
self
.
learning_rate
/
x
.
shape
[
0
]
self
.
weight
+=
(
factor
*
x
).
T
.
dot
(
error
)
+
self
.
alpha
*
self
.
weight
self
.
weight
+=
(
factor
*
x
).
T
.
dot
(
error
)
+
\
(
-
self
.
learning_rate
*
self
.
alpha
)
*
self
.
weight
self
.
bias
+=
factor
*
error
.
sum
(
keepdims
=
True
)
def
BCELoss
(
self
,
x
,
y
):
...
...
python/primihub/FL/logistic_regression/vfl_base.py
浏览文件 @
a470b3f3
import
tenseal
as
ts
import
numpy
as
np
from
primihub.utils.logger_util
import
logger
from
.base
import
LogisticRegression
...
...
@@ -5,6 +6,13 @@ from .base import LogisticRegression
class
LogisticRegression_Host_Plaintext
(
LogisticRegression
):
def
__init__
(
self
,
x
,
y
,
learning_rate
=
0.2
,
alpha
=
0.0001
):
super
().
__init__
(
x
,
y
,
learning_rate
,
alpha
)
if
self
.
multiclass
:
self
.
output_dim
=
self
.
weight
.
shape
[
1
]
else
:
self
.
output_dim
=
1
def
compute_z
(
self
,
x
,
guest_z
):
z
=
x
.
dot
(
self
.
weight
)
+
self
.
bias
z
+=
np
.
array
(
guest_z
).
sum
(
axis
=
0
)
...
...
@@ -46,12 +54,8 @@ class LogisticRegression_Host_Plaintext(LogisticRegression):
return
self
.
BCELoss
(
y
,
z
,
regular_loss
)
def
compute_grad
(
self
,
x
,
error
):
if
self
.
multiclass
:
dw
=
x
.
T
.
dot
(
error
)
/
x
.
shape
[
0
]
+
self
.
alpha
*
self
.
weight
db
=
error
.
mean
(
axis
=
0
,
keepdims
=
True
)
else
:
dw
=
x
.
T
.
dot
(
error
)
/
x
.
shape
[
0
]
+
self
.
alpha
*
self
.
weight
db
=
error
.
sum
(
keepdims
=
True
)
dw
=
x
.
T
.
dot
(
error
)
/
x
.
shape
[
0
]
+
self
.
alpha
*
self
.
weight
db
=
error
.
mean
(
axis
=
0
,
keepdims
=
True
)
return
dw
,
db
def
gradient_descent
(
self
,
x
,
error
):
...
...
@@ -67,25 +71,38 @@ class LogisticRegression_Host_CKKS(LogisticRegression_Host_Plaintext):
def
compute_enc_z
(
self
,
x
,
guest_z
):
z
=
self
.
weight
.
mm
(
x
.
T
)
+
self
.
bias
z
+=
np
.
array
(
guest_z
).
sum
(
axis
=
0
)
z
+=
sum
(
guest_z
)
return
z
def
compute_error
(
self
,
y
,
z
):
if
self
.
multiclass
:
error_msg
=
"CKKS method doesn't support multiclass classification"
logger
.
error
(
error_msg
)
raise
AttributeError
(
error_msg
)
error
=
z
+
1
-
self
.
output_dim
*
np
.
eye
(
self
.
output_dim
)[
y
].
T
else
:
error
=
2.
+
z
-
4
*
y
return
error
def
compute_regular_loss
(
self
,
guest_regular_loss
):
if
self
.
multiclass
and
isinstance
(
self
.
weight
,
ts
.
CKKSTensor
):
return
(
0.5
*
self
.
alpha
)
*
(
self
.
weight
**
2
).
sum
().
sum
()
\
+
guest_regular_loss
else
:
return
super
().
compute_regular_loss
(
guest_regular_loss
)
def
BCELoss
(
self
,
y
,
z
,
regular_loss
):
return
z
.
dot
((
0.5
-
y
)
/
y
.
shape
[
0
])
+
regular_loss
def
CELoss
(
self
,
y
,
z
,
regular_loss
):
error_msg
=
"CKKS method doesn't support multiclass classification"
logger
.
error
(
error_msg
)
raise
AttributeError
(
error_msg
)
factor
=
1.
/
(
y
.
shape
[
0
]
*
self
.
output_dim
)
if
isinstance
(
z
,
ts
.
CKKSTensor
):
# Todo: fix encrypted1 and encrypted2 parameter mismatch
return
(
z
*
factor
\
-
z
*
((
np
.
eye
(
self
.
output_dim
)[
y
].
T
+
np
.
random
.
normal
(
0
,
1e-4
,
(
self
.
output_dim
,
y
.
shape
[
0
])))
\
*
factor
)).
sum
().
sum
()
\
+
regular_loss
else
:
return
np
.
sum
(
np
.
sum
(
z
,
axis
=
1
)
-
z
[
np
.
arange
(
len
(
y
)),
y
])
\
*
factor
+
regular_loss
def
loss
(
self
,
y
,
z
,
regular_loss
):
if
self
.
multiclass
:
...
...
@@ -95,13 +112,13 @@ class LogisticRegression_Host_CKKS(LogisticRegression_Host_Plaintext):
def
gradient_descent
(
self
,
x
,
error
):
if
self
.
multiclass
:
error_msg
=
"CKKS method doesn't support multiclass classification"
logger
.
error
(
error_msg
)
raise
AttributeError
(
error_msg
)
factor
=
-
self
.
learning_rate
/
(
self
.
output_dim
*
x
.
shape
[
0
])
self
.
bias
+=
error
.
sum
(
axis
=
1
).
reshape
((
self
.
output_dim
,
1
))
*
factor
else
:
factor
=
-
self
.
learning_rate
/
x
.
shape
[
0
]
self
.
weight
+=
error
.
mm
(
factor
*
x
)
+
self
.
alpha
*
self
.
weight
self
.
bias
+=
error
.
sum
()
*
factor
self
.
weight
+=
error
.
mm
(
factor
*
x
)
\
+
(
-
self
.
learning_rate
*
self
.
alpha
)
*
self
.
weight
class
LogisticRegression_Guest_Plaintext
:
...
...
@@ -138,9 +155,23 @@ class LogisticRegression_Guest_Plaintext:
class
LogisticRegression_Guest_CKKS
(
LogisticRegression_Guest_Plaintext
):
def
__init__
(
self
,
x
,
learning_rate
=
0.2
,
alpha
=
0.0001
,
output_dim
=
1
):
super
().
__init__
(
x
,
learning_rate
,
alpha
,
output_dim
)
self
.
output_dim
=
output_dim
def
compute_enc_z
(
self
,
x
):
return
self
.
weight
.
mm
(
x
.
T
)
def
compute_regular_loss
(
self
):
if
self
.
multiclass
and
isinstance
(
self
.
weight
,
ts
.
CKKSTensor
):
return
(
0.5
*
self
.
alpha
)
*
(
self
.
weight
**
2
).
sum
().
sum
()
else
:
return
super
().
compute_regular_loss
()
def
gradient_descent
(
self
,
x
,
error
):
factor
=
-
self
.
learning_rate
/
x
.
shape
[
0
]
self
.
weight
+=
error
.
mm
(
factor
*
x
)
+
self
.
alpha
*
self
.
weight
if
self
.
multiclass
:
factor
=
-
self
.
learning_rate
/
(
self
.
output_dim
*
x
.
shape
[
0
])
else
:
factor
=
-
self
.
learning_rate
/
x
.
shape
[
0
]
self
.
weight
+=
error
.
mm
(
factor
*
x
)
+
\
(
-
self
.
learning_rate
*
self
.
alpha
)
*
self
.
weight
python/primihub/FL/logistic_regression/vfl_coordinator.py
浏览文件 @
a470b3f3
...
...
@@ -67,6 +67,7 @@ class CKKS:
if
isinstance
(
context
,
bytes
):
context
=
ts
.
context_from
(
context
)
self
.
context
=
context
self
.
multiply_depth
=
context
.
data
.
seal_context
().
first_context_data
().
chain_index
()
def
encrypt_vector
(
self
,
vector
,
context
=
None
):
if
context
:
...
...
@@ -74,15 +75,24 @@ class CKKS:
else
:
return
ts
.
ckks_vector
(
self
.
context
,
vector
)
def
decrypt
(
self
,
vector
,
secret_key
=
None
):
if
vector
.
context
().
has_secret_key
():
return
vector
.
decrypt
()
def
encrypt_tensor
(
self
,
tensor
,
context
=
None
):
if
context
:
return
ts
.
ckks_tensor
(
context
,
tensor
)
else
:
return
ts
.
ckks_tensor
(
self
.
context
,
tensor
)
def
decrypt
(
self
,
ciphertext
,
secret_key
=
None
):
if
ciphertext
.
context
().
has_secret_key
():
return
ciphertext
.
decrypt
()
else
:
return
vector
.
decrypt
(
secret_key
)
return
ciphertext
.
decrypt
(
secret_key
)
def
load_vector
(
self
,
vector
):
return
ts
.
ckks_vector_from
(
self
.
context
,
vector
)
def
load_tensor
(
self
,
tensor
):
return
ts
.
ckks_tensor_from
(
self
.
context
,
tensor
)
class
CKKSCoordinator
(
CKKS
):
...
...
@@ -90,19 +100,20 @@ class CKKSCoordinator(CKKS):
self
.
t
=
0
self
.
host_channel
=
host_channel
self
.
guest_channel
=
guest_channel
self
.
multiclass
=
host_channel
.
recv
(
'multiclass'
)
# set CKKS params
# use larger poly_mod_degree to support more encrypted multiplications
poly_mod_degree
=
32768
# gradient descent uses as least two multiplications per interation
multiply_per_iter
=
2
poly_mod_degree
=
8192
# the least multiplication per iteration of gradient descent
# more multiplications lead to larger context size
self
.
max_iter
=
7
multiply_per_iter
=
2
self
.
max_iter
=
1
multiply_depth
=
multiply_per_iter
*
self
.
max_iter
# sum(coeff_mod_bit_sizes) <= max coeff_modulus bit-length
fe_bits_scale
=
35
bits_scale
=
27
#
35*2 + 27*2*7 = 448 < 881 (for N = 32768
& 128 bit security)
fe_bits_scale
=
60
bits_scale
=
49
#
60*2 + 49*1*2 = 218 == 218 (for N = 8192
& 128 bit security)
coeff_mod_bit_sizes
=
[
fe_bits_scale
]
+
\
[
bits_scale
]
*
multiply_depth
+
\
[
fe_bits_scale
]
...
...
@@ -122,26 +133,28 @@ class CKKSCoordinator(CKKS):
self
.
secret_context
=
secret_context
self
.
send_public_context
()
self
.
send_max_iter
()
num_examples
=
host_channel
.
recv
(
'num_examples'
)
self
.
iter_per_epoch
=
math
.
ceil
(
num_examples
/
batch_size
)
def
send_max_iter
(
self
):
self
.
host_channel
.
send
(
"max_iter"
,
self
.
max_iter
)
self
.
guest_channel
.
send_all
(
"max_iter"
,
self
.
max_iter
)
def
send_public_context
(
self
):
serialize_context
=
self
.
context
.
serialize
()
self
.
host_channel
.
send
(
"public_context"
,
serialize_context
)
self
.
guest_channel
.
send_all
(
"public_context"
,
serialize_context
)
def
recv_model
(
self
):
host_weight
=
self
.
load_vector
(
self
.
host_channel
.
recv
(
'host_weight'
))
host_bias
=
self
.
load_vector
(
self
.
host_channel
.
recv
(
'host_bias'
))
if
self
.
multiclass
:
host_weight
=
self
.
load_tensor
(
self
.
host_channel
.
recv
(
'host_weight'
))
host_bias
=
self
.
load_tensor
(
self
.
host_channel
.
recv
(
'host_bias'
))
guest_weight
=
self
.
guest_channel
.
recv_all
(
'guest_weight'
)
guest_weight
=
[
self
.
load_vector
(
weight
)
for
weight
in
guest_weight
]
guest_weight
=
self
.
guest_channel
.
recv_all
(
'guest_weight'
)
guest_weight
=
[
self
.
load_tensor
(
weight
)
for
weight
in
guest_weight
]
else
:
host_weight
=
self
.
load_vector
(
self
.
host_channel
.
recv
(
'host_weight'
))
host_bias
=
self
.
load_vector
(
self
.
host_channel
.
recv
(
'host_bias'
))
guest_weight
=
self
.
guest_channel
.
recv_all
(
'guest_weight'
)
guest_weight
=
[
self
.
load_vector
(
weight
)
for
weight
in
guest_weight
]
return
host_weight
,
host_bias
,
guest_weight
...
...
@@ -161,10 +174,16 @@ class CKKSCoordinator(CKKS):
return
host_weight
,
host_bias
,
guest_weight
def
encrypt_model
(
self
,
host_weight
,
host_bias
,
guest_weight
):
host_weight
=
self
.
encrypt_vector
(
host_weight
)
host_bias
=
self
.
encrypt_vector
(
host_bias
)
if
self
.
multiclass
:
host_weight
=
self
.
encrypt_tensor
(
host_weight
)
host_bias
=
self
.
encrypt_tensor
(
host_bias
)
guest_weight
=
[
self
.
encrypt_vector
(
weight
)
for
weight
in
guest_weight
]
guest_weight
=
[
self
.
encrypt_tensor
(
weight
)
for
weight
in
guest_weight
]
else
:
host_weight
=
self
.
encrypt_vector
(
host_weight
)
host_bias
=
self
.
encrypt_vector
(
host_bias
)
guest_weight
=
[
self
.
encrypt_vector
(
weight
)
for
weight
in
guest_weight
]
return
host_weight
,
host_bias
,
guest_weight
...
...
@@ -187,28 +206,41 @@ class CKKSCoordinator(CKKS):
host_weight
,
host_bias
,
guest_weight
)
# list to numpy ndarrry
host_weight
=
np
.
array
(
host_weight
)
host_bias
=
np
.
array
(
host_bias
)
guest_weight
=
[
np
.
array
(
weight
)
for
weight
in
guest_weight
]
if
self
.
multiclass
:
host_weight
=
np
.
array
(
host_weight
.
tolist
()).
T
host_bias
=
np
.
array
(
host_bias
.
tolist
()).
T
guest_weight
=
[
np
.
array
(
weight
.
tolist
()).
T
for
weight
in
guest_weight
]
else
:
host_weight
=
np
.
array
(
host_weight
)
host_bias
=
np
.
array
(
host_bias
)
guest_weight
=
[
np
.
array
(
weight
)
for
weight
in
guest_weight
]
self
.
send_model
(
host_weight
,
host_bias
,
guest_weight
)
def
train
(
self
):
logger
.
info
(
f
'iteration
{
self
.
t
}
/
{
self
.
max_iter
}
'
)
self
.
t
+=
self
.
iter_per_epoch
for
i
in
range
(
self
.
t
//
self
.
max_iter
):
self
.
update_ciphertext_model
()
logger
.
warning
(
f
'decrypt model #
{
i
+
1
}
'
)
num_dec
=
self
.
t
//
self
.
max_iter
self
.
t
=
self
.
t
%
self
.
max_iter
if
self
.
t
==
0
:
num_dec
-=
1
self
.
t
=
self
.
max_iter
for
i
in
range
(
num_dec
):
logger
.
warning
(
f
'decrypt model #
{
i
+
1
}
'
)
self
.
update_ciphertext_model
()
def
compute_loss
(
self
):
logger
.
info
(
f
'iteration
{
self
.
t
}
/
{
self
.
max_iter
}
'
)
self
.
t
+=
1
if
self
.
t
>
self
.
max_iter
:
if
self
.
t
>=
self
.
max_iter
:
self
.
t
=
0
self
.
update_ciphertext_model
()
logger
.
warning
(
'decrypt model'
)
self
.
update_ciphertext_model
()
loss
=
self
.
load_vector
(
self
.
host_channel
.
recv
(
'loss'
))
loss
=
self
.
decrypt
(
loss
,
self
.
secret_context
.
secret_key
())[
0
]
if
self
.
multiclass
:
loss
=
self
.
load_tensor
(
self
.
host_channel
.
recv
(
'loss'
))
loss
=
self
.
decrypt
(
loss
,
self
.
secret_context
.
secret_key
()).
tolist
()
else
:
loss
=
self
.
load_vector
(
self
.
host_channel
.
recv
(
'loss'
))
loss
=
self
.
decrypt
(
loss
,
self
.
secret_context
.
secret_key
())[
0
]
logger
.
info
(
f
'loss=
{
loss
}
'
)
\ No newline at end of file
python/primihub/FL/logistic_regression/vfl_guest.py
浏览文件 @
a470b3f3
...
...
@@ -5,7 +5,7 @@ from primihub.FL.utils.dataset import read_data, DataLoader
from
primihub.utils.logger_util
import
logger
import
pickle
from
sklearn.preprocessing
import
MinMax
Scaler
from
sklearn.preprocessing
import
Standard
Scaler
from
.vfl_base
import
LogisticRegression_Guest_Plaintext
,
\
LogisticRegression_Guest_CKKS
...
...
@@ -70,8 +70,8 @@ class LogisticRegressionGuest(BaseModel):
raise
RuntimeError
(
error_msg
)
# data preprocessing
#
minmaxs
caler
scaler
=
MinMax
Scaler
()
#
StandardS
caler
scaler
=
Standard
Scaler
()
x
=
scaler
.
fit_transform
(
x
)
# guest training
...
...
@@ -196,8 +196,9 @@ class CKKS_Guest(Plaintext_Guest, CKKS):
alpha
,
output_dim
)
self
.
recv_public_context
(
coordinator_channel
)
self
.
max_iter
=
coordinator_channel
.
recv
(
'max_iter'
)
CKKS
.
__init__
(
self
,
self
.
context
)
multiply_per_iter
=
2
self
.
max_iter
=
self
.
multiply_depth
//
multiply_per_iter
self
.
encrypt_model
()
def
recv_public_context
(
self
,
coordinator_channel
):
...
...
@@ -205,13 +206,21 @@ class CKKS_Guest(Plaintext_Guest, CKKS):
self
.
context
=
coordinator_channel
.
recv
(
'public_context'
)
def
encrypt_model
(
self
):
self
.
model
.
weight
=
self
.
encrypt_vector
(
self
.
model
.
weight
)
if
self
.
model
.
multiclass
:
self
.
model
.
weight
=
self
.
encrypt_tensor
(
self
.
model
.
weight
.
T
)
else
:
self
.
model
.
weight
=
self
.
encrypt_vector
(
self
.
model
.
weight
)
def
update_ciphertext_model
(
self
):
self
.
coordinator_channel
.
send
(
'guest_weight'
,
self
.
model
.
weight
.
serialize
())
self
.
model
.
weight
=
self
.
load_vector
(
self
.
coordinator_channel
.
recv
(
'guest_weight'
))
if
self
.
model
.
multiclass
:
self
.
model
.
weight
=
self
.
load_tensor
(
self
.
coordinator_channel
.
recv
(
'guest_weight'
))
else
:
self
.
model
.
weight
=
self
.
load_vector
(
self
.
coordinator_channel
.
recv
(
'guest_weight'
))
def
update_plaintext_model
(
self
):
self
.
coordinator_channel
.
send
(
'guest_weight'
,
...
...
@@ -232,23 +241,26 @@ class CKKS_Guest(Plaintext_Guest, CKKS):
logger
.
info
(
f
'iteration
{
self
.
t
}
/
{
self
.
max_iter
}
'
)
if
self
.
t
>=
self
.
max_iter
:
self
.
t
=
0
self
.
update_ciphertext_model
()
logger
.
warning
(
f
'decrypt model'
)
self
.
update_ciphertext_model
()
self
.
t
+=
1
self
.
send_enc_z
(
x
)
error
=
self
.
load_vector
(
self
.
host_channel
.
recv
(
'error'
))
if
self
.
model
.
multiclass
:
error
=
self
.
load_tensor
(
self
.
host_channel
.
recv
(
'error'
))
else
:
error
=
self
.
load_vector
(
self
.
host_channel
.
recv
(
'error'
))
self
.
model
.
fit
(
x
,
error
)
def
compute_metrics
(
self
,
x
):
logger
.
info
(
f
'iteration
{
self
.
t
}
/
{
self
.
max_iter
}
'
)
self
.
t
+=
1
if
self
.
t
>
self
.
max_iter
:
if
self
.
t
>=
self
.
max_iter
:
self
.
t
=
0
self
.
update_ciphertext_model
()
logger
.
warning
(
f
'decrypt model'
)
self
.
update_ciphertext_model
()
self
.
send_enc_z
(
x
)
self
.
send_enc_regular_loss
()
...
...
python/primihub/FL/logistic_regression/vfl_host.py
浏览文件 @
a470b3f3
...
...
@@ -11,7 +11,7 @@ import numpy as np
from
sklearn
import
metrics
from
primihub.FL.metrics.hfl_metrics
import
ks_from_fpr_tpr
,
\
auc_from_fpr_tpr
from
sklearn.preprocessing
import
MinMax
Scaler
from
sklearn.preprocessing
import
Standard
Scaler
from
.vfl_base
import
LogisticRegression_Host_Plaintext
,
\
LogisticRegression_Host_CKKS
...
...
@@ -79,8 +79,8 @@ class LogisticRegressionHost(BaseModel):
raise
RuntimeError
(
error_msg
)
# data preprocessing
#
minmaxs
caler
scaler
=
MinMax
Scaler
()
#
StandardS
caler
scaler
=
Standard
Scaler
()
x
=
scaler
.
fit_transform
(
x
)
# host training
...
...
@@ -195,13 +195,7 @@ class Plaintext_Host:
def
send_output_dim
(
self
,
guest_channel
):
self
.
guest_channel
=
guest_channel
if
self
.
model
.
multiclass
:
output_dim
=
self
.
model
.
weight
.
shape
[
1
]
else
:
output_dim
=
1
guest_channel
.
send_all
(
'output_dim'
,
output_dim
)
guest_channel
.
send_all
(
'output_dim'
,
self
.
model
.
output_dim
)
def
compute_z
(
self
,
x
):
guest_z
=
self
.
guest_channel
.
recv_all
(
'guest_z'
)
...
...
@@ -274,10 +268,12 @@ class CKKS_Host(Plaintext_Host, CKKS):
learning_rate
,
alpha
)
self
.
send_output_dim
(
guest_channel
)
coordinator_channel
.
send
(
'multiclass'
,
self
.
model
.
multiclass
)
self
.
recv_public_context
(
coordinator_channel
)
self
.
max_iter
=
coordinator_channel
.
recv
(
'max_iter'
)
coordinator_channel
.
send
(
'num_examples'
,
x
.
shape
[
0
])
CKKS
.
__init__
(
self
,
self
.
context
)
multiply_per_iter
=
2
self
.
max_iter
=
self
.
multiply_depth
//
multiply_per_iter
self
.
encrypt_model
()
def
recv_public_context
(
self
,
coordinator_channel
):
...
...
@@ -285,18 +281,29 @@ class CKKS_Host(Plaintext_Host, CKKS):
self
.
context
=
coordinator_channel
.
recv
(
'public_context'
)
def
encrypt_model
(
self
):
self
.
model
.
weight
=
self
.
encrypt_vector
(
self
.
model
.
weight
)
self
.
model
.
bias
=
self
.
encrypt_vector
(
self
.
model
.
bias
)
if
self
.
model
.
multiclass
:
self
.
model
.
weight
=
self
.
encrypt_tensor
(
self
.
model
.
weight
.
T
)
self
.
model
.
bias
=
self
.
encrypt_tensor
(
self
.
model
.
bias
.
T
)
else
:
self
.
model
.
weight
=
self
.
encrypt_vector
(
self
.
model
.
weight
)
self
.
model
.
bias
=
self
.
encrypt_vector
(
self
.
model
.
bias
)
def
update_ciphertext_model
(
self
):
self
.
coordinator_channel
.
send
(
'host_weight'
,
self
.
model
.
weight
.
serialize
())
self
.
coordinator_channel
.
send
(
'host_bias'
,
self
.
model
.
bias
.
serialize
())
self
.
model
.
weight
=
self
.
load_vector
(
self
.
coordinator_channel
.
recv
(
'host_weight'
))
self
.
model
.
bias
=
self
.
load_vector
(
self
.
coordinator_channel
.
recv
(
'host_bias'
))
if
self
.
model
.
multiclass
:
self
.
model
.
weight
=
self
.
load_tensor
(
self
.
coordinator_channel
.
recv
(
'host_weight'
))
self
.
model
.
bias
=
self
.
load_tensor
(
self
.
coordinator_channel
.
recv
(
'host_bias'
))
else
:
self
.
model
.
weight
=
self
.
load_vector
(
self
.
coordinator_channel
.
recv
(
'host_weight'
))
self
.
model
.
bias
=
self
.
load_vector
(
self
.
coordinator_channel
.
recv
(
'host_bias'
))
def
update_plaintext_model
(
self
):
self
.
coordinator_channel
.
send
(
'host_weight'
,
...
...
@@ -308,13 +315,19 @@ class CKKS_Host(Plaintext_Host, CKKS):
def
compute_enc_z
(
self
,
x
):
guest_z
=
self
.
guest_channel
.
recv_all
(
'guest_z'
)
guest_z
=
[
self
.
load_vector
(
z
)
for
z
in
guest_z
]
if
self
.
model
.
multiclass
:
guest_z
=
[
self
.
load_tensor
(
z
)
for
z
in
guest_z
]
else
:
guest_z
=
[
self
.
load_vector
(
z
)
for
z
in
guest_z
]
return
self
.
model
.
compute_enc_z
(
x
,
guest_z
)
def
compute_enc_regular_loss
(
self
):
if
self
.
model
.
alpha
!=
0
:
guest_regular_loss
=
self
.
guest_channel
.
recv_all
(
'guest_regular_loss'
)
guest_regular_loss
=
[
self
.
load_vector
(
s
)
for
s
in
guest_regular_loss
]
if
self
.
model
.
multiclass
:
guest_regular_loss
=
[
self
.
load_tensor
(
s
)
for
s
in
guest_regular_loss
]
else
:
guest_regular_loss
=
[
self
.
load_vector
(
s
)
for
s
in
guest_regular_loss
]
return
self
.
model
.
compute_regular_loss
(
sum
(
guest_regular_loss
))
else
:
return
0.
...
...
@@ -323,8 +336,9 @@ class CKKS_Host(Plaintext_Host, CKKS):
logger
.
info
(
f
'iteration
{
self
.
t
}
/
{
self
.
max_iter
}
'
)
if
self
.
t
>=
self
.
max_iter
:
self
.
t
=
0
self
.
update_ciphertext_model
()
logger
.
warning
(
f
'decrypt model'
)
self
.
update_ciphertext_model
()
self
.
t
+=
1
z
=
self
.
compute_enc_z
(
x
)
...
...
@@ -336,11 +350,10 @@ class CKKS_Host(Plaintext_Host, CKKS):
def
compute_metrics
(
self
,
x
,
y
):
logger
.
info
(
f
'iteration
{
self
.
t
}
/
{
self
.
max_iter
}
'
)
self
.
t
+=
1
if
self
.
t
>
self
.
max_iter
:
if
self
.
t
>=
self
.
max_iter
:
self
.
t
=
0
self
.
update_ciphertext_model
()
logger
.
warning
(
f
'decrypt model'
)
self
.
update_ciphertext_model
()
z
=
self
.
compute_enc_z
(
x
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录