未验证 提交 fc648516 编写于 作者: X Xuefeng Xu 提交者: GitHub

fix bug when using cuda device (#582)

上级 a470b3f3
...@@ -263,6 +263,7 @@ class Plaintext_Client: ...@@ -263,6 +263,7 @@ class Plaintext_Client:
def set_model(self, model): def set_model(self, model):
self.model.load_state_dict(model) self.model.load_state_dict(model)
self.model.to(self.device)
def send_output_dim(self, y): def send_output_dim(self, y):
if self.task == 'regression': if self.task == 'regression':
...@@ -350,8 +351,8 @@ class Plaintext_Client: ...@@ -350,8 +351,8 @@ class Plaintext_Client:
pred = self.model(x) pred = self.model(x)
if self.task == 'classification': if self.task == 'classification':
y_true = torch.cat((y_true, y)) y_true = torch.cat((y_true, y.cpu()))
y_score = torch.cat((y_score, pred)) y_score = torch.cat((y_score, pred.cpu()))
loss += self.loss_fn(pred, y).item() * len(x) loss += self.loss_fn(pred, y).item() * len(x)
if self.output_dim == 1: if self.output_dim == 1:
...@@ -359,8 +360,8 @@ class Plaintext_Client: ...@@ -359,8 +360,8 @@ class Plaintext_Client:
else: else:
acc += (pred.argmax(1) == y).type(torch.float).sum().item() acc += (pred.argmax(1) == y).type(torch.float).sum().item()
elif self.task == 'regression': elif self.task == 'regression':
mae += F.l1_loss(pred, y) * len(x) mae += F.l1_loss(pred, y).cpu() * len(x)
mse += F.mse_loss(pred, y) * len(x) mse += F.mse_loss(pred, y).cpu() * len(x)
client_metrics = {} client_metrics = {}
...@@ -390,11 +391,11 @@ class Plaintext_Client: ...@@ -390,11 +391,11 @@ class Plaintext_Client:
elif self.task == 'regression': elif self.task == 'regression':
mse /= size mse /= size
client_metrics['train_mse'] = mse client_metrics['train_mse'] = mse
self.server_channel.send("mse", mse.type(torch.float64)) self.server_channel.send("mse", mse)
mae /= size mae /= size
client_metrics['train_mae'] = mae client_metrics['train_mae'] = mae
self.server_channel.send("mae", mae.type(torch.float64)) self.server_channel.send("mae", mae)
logger.info(f"mse={mse}, mae={mae}") logger.info(f"mse={mse}, mae={mae}")
...@@ -420,7 +421,7 @@ class DPSGD_Client(Plaintext_Client): ...@@ -420,7 +421,7 @@ class DPSGD_Client(Plaintext_Client):
input_shape = list(self.input_shape) input_shape = list(self.input_shape)
# set batch size equals to 1 to initilize lazy module # set batch size equals to 1 to initilize lazy module
input_shape.insert(0, 1) input_shape.insert(0, 1)
self.model.forward(torch.ones(input_shape)) self.model.forward(torch.ones(input_shape).to(self.device))
super().lazy_module_init() super().lazy_module_init()
def enable_DP_training(self, train_dataloader): def enable_DP_training(self, train_dataloader):
......
...@@ -146,6 +146,7 @@ class CNNClient(BaseModel): ...@@ -146,6 +146,7 @@ class CNNClient(BaseModel):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
for x, y in data_loader: for x, y in data_loader:
x = x.to(device)
pred = model(x) pred = model(x)
pred_prob = torch.softmax(pred, dim=1) pred_prob = torch.softmax(pred, dim=1)
pred_y = pred_prob.argmax(1) pred_y = pred_prob.argmax(1)
......
...@@ -152,7 +152,7 @@ class Plaintext_Server: ...@@ -152,7 +152,7 @@ class Plaintext_Server:
input_shape = list(self.input_shape) input_shape = list(self.input_shape)
# set batch size equals to 1 to initilize lazy module # set batch size equals to 1 to initilize lazy module
input_shape.insert(0, 1) input_shape.insert(0, 1)
self.model.forward(torch.ones(input_shape)) self.model.forward(torch.ones(input_shape).to(self.device))
self.model.load_state_dict(self.model.state_dict()) self.model.load_state_dict(self.model.state_dict())
self.server_model_broadcast() self.server_model_broadcast()
...@@ -170,7 +170,7 @@ class Plaintext_Server: ...@@ -170,7 +170,7 @@ class Plaintext_Server:
np.array(self.num_positive_examples_weights)).tolist() np.array(self.num_positive_examples_weights)).tolist()
self.num_examples_weights = torch.tensor(self.num_examples_weights, self.num_examples_weights = torch.tensor(self.num_examples_weights,
dtype=torch.float32) dtype=torch.float32).to(self.device)
self.num_examples_weights_sum = self.num_examples_weights.sum() self.num_examples_weights_sum = self.num_examples_weights.sum()
def client_model_aggregate(self): def client_model_aggregate(self):
...@@ -207,9 +207,11 @@ class Plaintext_Server: ...@@ -207,9 +207,11 @@ class Plaintext_Server:
raise RuntimeError(error_msg) raise RuntimeError(error_msg)
client_metrics = self.client_channel.recv_all(metrics_name) client_metrics = self.client_channel.recv_all(metrics_name)
return np.average(client_metrics, metrics = torch.tensor(client_metrics, dtype=torch.float).to(self.device) \
weights=self.num_examples_weights) @ self.num_examples_weights \
/ self.num_examples_weights_sum
return float(metrics)
def get_fpr_tpr(self): def get_fpr_tpr(self):
client_fpr = self.client_channel.recv_all('fpr') client_fpr = self.client_channel.recv_all('fpr')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册