未验证 提交 a470b3f3 编写于 作者: X Xuefeng Xu 提交者: GitHub

support multiclass VFL CKKS logistic regression (#578)

* support multiclass VFL CKKS logistic regression
上级 aa373211
{
"party_info": {
"task_manager": "127.0.0.1:50050"
},
"component_params": {
"roles": {
"host": "Bob",
"guest": [
"Charlie"
],
"coordinator": "David"
},
"common_params": {
"model": "VFL_logistic_regression",
"method": "CKKS",
"process": "train",
"task_name": "VFL_logistic_regression_multiclass_ckks_train",
"learning_rate": 1e-1,
"alpha": 1e-4,
"epoch": 2,
"shuffle_seed": 0,
"batch_size": 100,
"print_metrics": true
},
"role_params": {
"Bob": {
"data_set": "multiclass_vfl_train_host",
"selected_column": null,
"id": "id",
"label": "y",
"model_path": "data/result/host_model.pkl",
"metric_path": "data/result/metrics.json"
},
"Charlie": {
"data_set": "multiclass_vfl_train_guest",
"selected_column": null,
"id": "id",
"model_path": "data/result/guest_model.pkl"
},
"David": {
"data_set": "fl_fake_data"
}
}
}
}
\ No newline at end of file
......@@ -61,12 +61,10 @@ class LogisticRegression:
error = self.predict_prob(x)
idx = np.arange(len(y))
error[idx, y] -= 1
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(axis=0, keepdims=True)
else:
error = self.predict_prob(x) - y
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(keepdims=True)
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(axis=0, keepdims=True)
return dw, db
def gradient_descent(self, x, y):
......@@ -192,7 +190,8 @@ class LogisticRegression_Paillier(LogisticRegression):
error = 2 + x.dot(self.weight) + self.bias - 4 * y
factor = -self.learning_rate / x.shape[0]
self.weight += (factor * x).T.dot(error) + self.alpha * self.weight
self.weight += (factor * x).T.dot(error) + \
(-self.learning_rate * self.alpha) * self.weight
self.bias += factor * error.sum(keepdims=True)
def BCELoss(self, x, y):
......
import tenseal as ts
import numpy as np
from primihub.utils.logger_util import logger
from .base import LogisticRegression
......@@ -5,6 +6,13 @@ from .base import LogisticRegression
class LogisticRegression_Host_Plaintext(LogisticRegression):
def __init__(self, x, y, learning_rate=0.2, alpha=0.0001):
super().__init__(x, y, learning_rate, alpha)
if self.multiclass:
self.output_dim = self.weight.shape[1]
else:
self.output_dim = 1
def compute_z(self, x, guest_z):
z = x.dot(self.weight) + self.bias
z += np.array(guest_z).sum(axis=0)
......@@ -46,12 +54,8 @@ class LogisticRegression_Host_Plaintext(LogisticRegression):
return self.BCELoss(y, z, regular_loss)
def compute_grad(self, x, error):
if self.multiclass:
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(axis=0, keepdims=True)
else:
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.sum(keepdims=True)
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(axis=0, keepdims=True)
return dw, db
def gradient_descent(self, x, error):
......@@ -67,25 +71,38 @@ class LogisticRegression_Host_CKKS(LogisticRegression_Host_Plaintext):
def compute_enc_z(self, x, guest_z):
z = self.weight.mm(x.T) + self.bias
z += np.array(guest_z).sum(axis=0)
z += sum(guest_z)
return z
def compute_error(self, y, z):
if self.multiclass:
error_msg = "CKKS method doesn't support multiclass classification"
logger.error(error_msg)
raise AttributeError(error_msg)
error = z + 1 - self.output_dim * np.eye(self.output_dim)[y].T
else:
error = 2. + z - 4 * y
return error
def compute_regular_loss(self, guest_regular_loss):
if self.multiclass and isinstance(self.weight, ts.CKKSTensor):
return (0.5 * self.alpha) * (self.weight ** 2).sum().sum() \
+ guest_regular_loss
else:
return super().compute_regular_loss(guest_regular_loss)
def BCELoss(self, y, z, regular_loss):
return z.dot((0.5 - y) / y.shape[0]) + regular_loss
def CELoss(self, y, z, regular_loss):
error_msg = "CKKS method doesn't support multiclass classification"
logger.error(error_msg)
raise AttributeError(error_msg)
factor = 1. / (y.shape[0] * self.output_dim)
if isinstance(z, ts.CKKSTensor):
# Todo: fix encrypted1 and encrypted2 parameter mismatch
return (z * factor \
- z * ((np.eye(self.output_dim)[y].T
+ np.random.normal(0, 1e-4, (self.output_dim, y.shape[0]))) \
* factor)).sum().sum() \
+ regular_loss
else:
return np.sum(np.sum(z, axis=1) - z[np.arange(len(y)), y]) \
* factor + regular_loss
def loss(self, y, z, regular_loss):
if self.multiclass:
......@@ -95,13 +112,13 @@ class LogisticRegression_Host_CKKS(LogisticRegression_Host_Plaintext):
def gradient_descent(self, x, error):
if self.multiclass:
error_msg = "CKKS method doesn't support multiclass classification"
logger.error(error_msg)
raise AttributeError(error_msg)
factor = -self.learning_rate / (self.output_dim * x.shape[0])
self.bias += error.sum(axis=1).reshape((self.output_dim, 1)) * factor
else:
factor = -self.learning_rate / x.shape[0]
self.weight += error.mm(factor * x) + self.alpha * self.weight
self.bias += error.sum() * factor
self.weight += error.mm(factor * x) \
+ (-self.learning_rate * self.alpha) * self.weight
class LogisticRegression_Guest_Plaintext:
......@@ -138,9 +155,23 @@ class LogisticRegression_Guest_Plaintext:
class LogisticRegression_Guest_CKKS(LogisticRegression_Guest_Plaintext):
def __init__(self, x, learning_rate=0.2, alpha=0.0001, output_dim=1):
super().__init__(x, learning_rate, alpha, output_dim)
self.output_dim = output_dim
def compute_enc_z(self, x):
return self.weight.mm(x.T)
def compute_regular_loss(self):
if self.multiclass and isinstance(self.weight, ts.CKKSTensor):
return (0.5 * self.alpha) * (self.weight ** 2).sum().sum()
else:
return super().compute_regular_loss()
def gradient_descent(self, x, error):
factor = -self.learning_rate / x.shape[0]
self.weight += error.mm(factor * x) + self.alpha * self.weight
if self.multiclass:
factor = -self.learning_rate / (self.output_dim * x.shape[0])
else:
factor = -self.learning_rate / x.shape[0]
self.weight += error.mm(factor * x) + \
(-self.learning_rate * self.alpha) * self.weight
......@@ -67,6 +67,7 @@ class CKKS:
if isinstance(context, bytes):
context = ts.context_from(context)
self.context = context
self.multiply_depth = context.data.seal_context().first_context_data().chain_index()
def encrypt_vector(self, vector, context=None):
if context:
......@@ -74,15 +75,24 @@ class CKKS:
else:
return ts.ckks_vector(self.context, vector)
def decrypt(self, vector, secret_key=None):
if vector.context().has_secret_key():
return vector.decrypt()
def encrypt_tensor(self, tensor, context=None):
if context:
return ts.ckks_tensor(context, tensor)
else:
return ts.ckks_tensor(self.context, tensor)
def decrypt(self, ciphertext, secret_key=None):
if ciphertext.context().has_secret_key():
return ciphertext.decrypt()
else:
return vector.decrypt(secret_key)
return ciphertext.decrypt(secret_key)
def load_vector(self, vector):
return ts.ckks_vector_from(self.context, vector)
def load_tensor(self, tensor):
return ts.ckks_tensor_from(self.context, tensor)
class CKKSCoordinator(CKKS):
......@@ -90,19 +100,20 @@ class CKKSCoordinator(CKKS):
self.t = 0
self.host_channel = host_channel
self.guest_channel = guest_channel
self.multiclass = host_channel.recv('multiclass')
# set CKKS params
# use larger poly_mod_degree to support more encrypted multiplications
poly_mod_degree = 32768
# gradient descent uses as least two multiplications per interation
multiply_per_iter = 2
poly_mod_degree = 8192
# the least multiplication per iteration of gradient descent
# more multiplications lead to larger context size
self.max_iter = 7
multiply_per_iter = 2
self.max_iter = 1
multiply_depth = multiply_per_iter * self.max_iter
# sum(coeff_mod_bit_sizes) <= max coeff_modulus bit-length
fe_bits_scale = 35
bits_scale = 27
# 35*2 + 27*2*7 = 448 < 881 (for N = 32768 & 128 bit security)
fe_bits_scale = 60
bits_scale = 49
# 60*2 + 49*1*2 = 218 == 218 (for N = 8192 & 128 bit security)
coeff_mod_bit_sizes = [fe_bits_scale] + \
[bits_scale] * multiply_depth + \
[fe_bits_scale]
......@@ -122,26 +133,28 @@ class CKKSCoordinator(CKKS):
self.secret_context = secret_context
self.send_public_context()
self.send_max_iter()
num_examples = host_channel.recv('num_examples')
self.iter_per_epoch = math.ceil(num_examples / batch_size)
def send_max_iter(self):
self.host_channel.send("max_iter", self.max_iter)
self.guest_channel.send_all("max_iter", self.max_iter)
def send_public_context(self):
serialize_context = self.context.serialize()
self.host_channel.send("public_context", serialize_context)
self.guest_channel.send_all("public_context", serialize_context)
def recv_model(self):
host_weight = self.load_vector(self.host_channel.recv('host_weight'))
host_bias = self.load_vector(self.host_channel.recv('host_bias'))
if self.multiclass:
host_weight = self.load_tensor(self.host_channel.recv('host_weight'))
host_bias = self.load_tensor(self.host_channel.recv('host_bias'))
guest_weight = self.guest_channel.recv_all('guest_weight')
guest_weight = [self.load_vector(weight) for weight in guest_weight]
guest_weight = self.guest_channel.recv_all('guest_weight')
guest_weight = [self.load_tensor(weight) for weight in guest_weight]
else:
host_weight = self.load_vector(self.host_channel.recv('host_weight'))
host_bias = self.load_vector(self.host_channel.recv('host_bias'))
guest_weight = self.guest_channel.recv_all('guest_weight')
guest_weight = [self.load_vector(weight) for weight in guest_weight]
return host_weight, host_bias, guest_weight
......@@ -161,10 +174,16 @@ class CKKSCoordinator(CKKS):
return host_weight, host_bias, guest_weight
def encrypt_model(self, host_weight, host_bias, guest_weight):
host_weight = self.encrypt_vector(host_weight)
host_bias = self.encrypt_vector(host_bias)
if self.multiclass:
host_weight = self.encrypt_tensor(host_weight)
host_bias = self.encrypt_tensor(host_bias)
guest_weight = [self.encrypt_vector(weight) for weight in guest_weight]
guest_weight = [self.encrypt_tensor(weight) for weight in guest_weight]
else:
host_weight = self.encrypt_vector(host_weight)
host_bias = self.encrypt_vector(host_bias)
guest_weight = [self.encrypt_vector(weight) for weight in guest_weight]
return host_weight, host_bias, guest_weight
......@@ -187,28 +206,41 @@ class CKKSCoordinator(CKKS):
host_weight, host_bias, guest_weight)
# list to numpy ndarrry
host_weight = np.array(host_weight)
host_bias = np.array(host_bias)
guest_weight = [np.array(weight) for weight in guest_weight]
if self.multiclass:
host_weight = np.array(host_weight.tolist()).T
host_bias = np.array(host_bias.tolist()).T
guest_weight = [np.array(weight.tolist()).T for weight in guest_weight]
else:
host_weight = np.array(host_weight)
host_bias = np.array(host_bias)
guest_weight = [np.array(weight) for weight in guest_weight]
self.send_model(host_weight, host_bias, guest_weight)
def train(self):
logger.info(f'iteration {self.t} / {self.max_iter}')
self.t += self.iter_per_epoch
for i in range(self.t // self.max_iter):
self.update_ciphertext_model()
logger.warning(f'decrypt model #{i+1}')
num_dec = self.t // self.max_iter
self.t = self.t % self.max_iter
if self.t == 0:
num_dec -= 1
self.t = self.max_iter
for i in range(num_dec):
logger.warning(f'decrypt model #{i+1}')
self.update_ciphertext_model()
def compute_loss(self):
logger.info(f'iteration {self.t} / {self.max_iter}')
self.t += 1
if self.t > self.max_iter:
if self.t >= self.max_iter:
self.t = 0
self.update_ciphertext_model()
logger.warning('decrypt model')
self.update_ciphertext_model()
loss = self.load_vector(self.host_channel.recv('loss'))
loss = self.decrypt(loss, self.secret_context.secret_key())[0]
if self.multiclass:
loss = self.load_tensor(self.host_channel.recv('loss'))
loss = self.decrypt(loss, self.secret_context.secret_key()).tolist()
else:
loss = self.load_vector(self.host_channel.recv('loss'))
loss = self.decrypt(loss, self.secret_context.secret_key())[0]
logger.info(f'loss={loss}')
\ No newline at end of file
......@@ -5,7 +5,7 @@ from primihub.FL.utils.dataset import read_data, DataLoader
from primihub.utils.logger_util import logger
import pickle
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from .vfl_base import LogisticRegression_Guest_Plaintext,\
LogisticRegression_Guest_CKKS
......@@ -70,8 +70,8 @@ class LogisticRegressionGuest(BaseModel):
raise RuntimeError(error_msg)
# data preprocessing
# minmaxscaler
scaler = MinMaxScaler()
# StandardScaler
scaler = StandardScaler()
x = scaler.fit_transform(x)
# guest training
......@@ -196,8 +196,9 @@ class CKKS_Guest(Plaintext_Guest, CKKS):
alpha,
output_dim)
self.recv_public_context(coordinator_channel)
self.max_iter = coordinator_channel.recv('max_iter')
CKKS.__init__(self, self.context)
multiply_per_iter = 2
self.max_iter = self.multiply_depth // multiply_per_iter
self.encrypt_model()
def recv_public_context(self, coordinator_channel):
......@@ -205,13 +206,21 @@ class CKKS_Guest(Plaintext_Guest, CKKS):
self.context = coordinator_channel.recv('public_context')
def encrypt_model(self):
self.model.weight = self.encrypt_vector(self.model.weight)
if self.model.multiclass:
self.model.weight = self.encrypt_tensor(self.model.weight.T)
else:
self.model.weight = self.encrypt_vector(self.model.weight)
def update_ciphertext_model(self):
self.coordinator_channel.send('guest_weight',
self.model.weight.serialize())
self.model.weight = self.load_vector(
self.coordinator_channel.recv('guest_weight'))
if self.model.multiclass:
self.model.weight = self.load_tensor(
self.coordinator_channel.recv('guest_weight'))
else:
self.model.weight = self.load_vector(
self.coordinator_channel.recv('guest_weight'))
def update_plaintext_model(self):
self.coordinator_channel.send('guest_weight',
......@@ -232,23 +241,26 @@ class CKKS_Guest(Plaintext_Guest, CKKS):
logger.info(f'iteration {self.t} / {self.max_iter}')
if self.t >= self.max_iter:
self.t = 0
self.update_ciphertext_model()
logger.warning(f'decrypt model')
self.update_ciphertext_model()
self.t += 1
self.send_enc_z(x)
error = self.load_vector(self.host_channel.recv('error'))
if self.model.multiclass:
error = self.load_tensor(self.host_channel.recv('error'))
else:
error = self.load_vector(self.host_channel.recv('error'))
self.model.fit(x, error)
def compute_metrics(self, x):
logger.info(f'iteration {self.t} / {self.max_iter}')
self.t += 1
if self.t > self.max_iter:
if self.t >= self.max_iter:
self.t = 0
self.update_ciphertext_model()
logger.warning(f'decrypt model')
self.update_ciphertext_model()
self.send_enc_z(x)
self.send_enc_regular_loss()
......
......@@ -11,7 +11,7 @@ import numpy as np
from sklearn import metrics
from primihub.FL.metrics.hfl_metrics import ks_from_fpr_tpr,\
auc_from_fpr_tpr
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from .vfl_base import LogisticRegression_Host_Plaintext,\
LogisticRegression_Host_CKKS
......@@ -79,8 +79,8 @@ class LogisticRegressionHost(BaseModel):
raise RuntimeError(error_msg)
# data preprocessing
# minmaxscaler
scaler = MinMaxScaler()
# StandardScaler
scaler = StandardScaler()
x = scaler.fit_transform(x)
# host training
......@@ -195,13 +195,7 @@ class Plaintext_Host:
def send_output_dim(self, guest_channel):
self.guest_channel = guest_channel
if self.model.multiclass:
output_dim = self.model.weight.shape[1]
else:
output_dim = 1
guest_channel.send_all('output_dim', output_dim)
guest_channel.send_all('output_dim', self.model.output_dim)
def compute_z(self, x):
guest_z = self.guest_channel.recv_all('guest_z')
......@@ -274,10 +268,12 @@ class CKKS_Host(Plaintext_Host, CKKS):
learning_rate,
alpha)
self.send_output_dim(guest_channel)
coordinator_channel.send('multiclass', self.model.multiclass)
self.recv_public_context(coordinator_channel)
self.max_iter = coordinator_channel.recv('max_iter')
coordinator_channel.send('num_examples', x.shape[0])
CKKS.__init__(self, self.context)
multiply_per_iter = 2
self.max_iter = self.multiply_depth // multiply_per_iter
self.encrypt_model()
def recv_public_context(self, coordinator_channel):
......@@ -285,18 +281,29 @@ class CKKS_Host(Plaintext_Host, CKKS):
self.context = coordinator_channel.recv('public_context')
def encrypt_model(self):
self.model.weight = self.encrypt_vector(self.model.weight)
self.model.bias = self.encrypt_vector(self.model.bias)
if self.model.multiclass:
self.model.weight = self.encrypt_tensor(self.model.weight.T)
self.model.bias = self.encrypt_tensor(self.model.bias.T)
else:
self.model.weight = self.encrypt_vector(self.model.weight)
self.model.bias = self.encrypt_vector(self.model.bias)
def update_ciphertext_model(self):
self.coordinator_channel.send('host_weight',
self.model.weight.serialize())
self.coordinator_channel.send('host_bias',
self.model.bias.serialize())
self.model.weight = self.load_vector(
self.coordinator_channel.recv('host_weight'))
self.model.bias = self.load_vector(
self.coordinator_channel.recv('host_bias'))
if self.model.multiclass:
self.model.weight = self.load_tensor(
self.coordinator_channel.recv('host_weight'))
self.model.bias = self.load_tensor(
self.coordinator_channel.recv('host_bias'))
else:
self.model.weight = self.load_vector(
self.coordinator_channel.recv('host_weight'))
self.model.bias = self.load_vector(
self.coordinator_channel.recv('host_bias'))
def update_plaintext_model(self):
self.coordinator_channel.send('host_weight',
......@@ -308,13 +315,19 @@ class CKKS_Host(Plaintext_Host, CKKS):
def compute_enc_z(self, x):
guest_z = self.guest_channel.recv_all('guest_z')
guest_z = [self.load_vector(z) for z in guest_z]
if self.model.multiclass:
guest_z = [self.load_tensor(z) for z in guest_z]
else:
guest_z = [self.load_vector(z) for z in guest_z]
return self.model.compute_enc_z(x, guest_z)
def compute_enc_regular_loss(self):
if self.model.alpha != 0:
guest_regular_loss = self.guest_channel.recv_all('guest_regular_loss')
guest_regular_loss = [self.load_vector(s) for s in guest_regular_loss]
if self.model.multiclass:
guest_regular_loss = [self.load_tensor(s) for s in guest_regular_loss]
else:
guest_regular_loss = [self.load_vector(s) for s in guest_regular_loss]
return self.model.compute_regular_loss(sum(guest_regular_loss))
else:
return 0.
......@@ -323,8 +336,9 @@ class CKKS_Host(Plaintext_Host, CKKS):
logger.info(f'iteration {self.t} / {self.max_iter}')
if self.t >= self.max_iter:
self.t = 0
self.update_ciphertext_model()
logger.warning(f'decrypt model')
self.update_ciphertext_model()
self.t += 1
z = self.compute_enc_z(x)
......@@ -336,11 +350,10 @@ class CKKS_Host(Plaintext_Host, CKKS):
def compute_metrics(self, x, y):
logger.info(f'iteration {self.t} / {self.max_iter}')
self.t += 1
if self.t > self.max_iter:
if self.t >= self.max_iter:
self.t = 0
self.update_ciphertext_model()
logger.warning(f'decrypt model')
self.update_ciphertext_model()
z = self.compute_enc_z(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册