未验证 提交 d749eca8 编写于 作者: X Xuefeng Xu 提交者: GitHub

use roc vertical averaging in HFL to save communication overhead (#518)

上级 9f052f37
......@@ -279,7 +279,7 @@ class Plaintext_Client:
drop_intermediate=False)
self.server_channel.send("fpr", fpr)
self.server_channel.send("tpr", tpr)
self.server_channel.send("thresholds", thresholds)
# self.server_channel.send("thresholds", thresholds)
# ks
ks = ks_from_fpr_tpr(fpr, tpr)
......
......@@ -6,7 +6,7 @@ from primihub.utils.logger_util import logger
import json
import numpy as np
from phe import paillier
from primihub.FL.metrics.hfl_metrics import fpr_tpr_merge2,\
from primihub.FL.metrics.hfl_metrics import roc_vertical_avg,\
ks_from_fpr_tpr,\
auc_from_fpr_tpr
from .base import PaillierFunc
......@@ -157,21 +157,15 @@ class Plaintext_DPSGD_Server:
def get_fpr_tpr(self):
client_fpr = self.client_channel.recv_all('fpr')
client_tpr = self.client_channel.recv_all('tpr')
client_thresholds = self.client_channel.recv_all('thresholds')
# client_thresholds = self.client_channel.recv_all('thresholds')
# fpr & tpr
# Note: fpr_tpr_merge2 only support two clients
# use ROC averaging when for multiple clients
# roc_vertical_avg: sample = 0.1 * n
samples = int(0.1 * sum(self.num_examples_weights))
fpr,\
tpr,\
thresholds = fpr_tpr_merge2(client_fpr[0],
client_tpr[0],
client_thresholds[0],
client_fpr[1],
client_tpr[1],
client_thresholds[1],
self.num_positive_examples_weights,
self.num_negtive_examples_weights)
tpr = roc_vertical_avg(samples,
client_fpr,
client_tpr)
return fpr, tpr
def get_metrics(self):
......
......@@ -317,7 +317,7 @@ class Plaintext_Client:
drop_intermediate=False)
self.server_channel.send("fpr", fpr)
self.server_channel.send("tpr", tpr)
self.server_channel.send("thresholds", thresholds)
# self.server_channel.send("thresholds", thresholds)
client_metrics['train_fpr'] = fpr
client_metrics['train_tpr'] = tpr
......
......@@ -6,9 +6,9 @@ from primihub.utils.logger_util import logger
import json
import numpy as np
import torch
from primihub.FL.metrics.hfl_metrics import fpr_tpr_merge2,\
ks_from_fpr_tpr,\
auc_from_fpr_tpr
from primihub.FL.metrics.hfl_metrics import roc_vertical_avg,\
ks_from_fpr_tpr,\
auc_from_fpr_tpr
from .base import create_model
......@@ -203,21 +203,15 @@ class Plaintext_Server:
def get_fpr_tpr(self):
client_fpr = self.client_channel.recv_all('fpr')
client_tpr = self.client_channel.recv_all('tpr')
client_thresholds = self.client_channel.recv_all('thresholds')
# client_thresholds = self.client_channel.recv_all('thresholds')
# fpr & tpr
# Note: fpr_tpr_merge2 only support two clients
# use ROC averaging when for multiple clients
# roc_vertical_avg: sample = 0.1 * n
samples = int(0.1 * sum(self.num_examples_weights))
fpr,\
tpr,\
thresholds = fpr_tpr_merge2(client_fpr[0],
client_tpr[0],
client_thresholds[0],
client_fpr[1],
client_tpr[1],
client_thresholds[1],
self.num_positive_examples_weights,
self.num_negtive_examples_weights)
tpr = roc_vertical_avg(samples,
client_fpr,
client_tpr)
return fpr, tpr
def get_metrics(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册