Unverified · Commit fc648516 authored by Xuefeng Xu, committed by GitHub

fix bug when using cuda device (#582)

Parent a470b3f3
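Every hunk in this commit fixes the same failure mode: with the model on a CUDA device, tensors created on the CPU (labels from the data loader, dummy inputs, weight vectors) flow into operations with GPU tensors, and PyTorch raises a device-mismatch error. A minimal sketch of the failure and of the two fix patterns applied below (illustrative code, not from this repository):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gpu_t = torch.ones(3).to(device)   # tensor on the model's device
    cpu_t = torch.ones(3)              # tensor left on the CPU

    # On a CUDA machine, mixing the two raises:
    #   RuntimeError: Expected all tensors to be on the same device, ...
    # torch.cat((gpu_t, cpu_t))

    a = torch.cat((gpu_t, cpu_t.to(device)))   # pattern 1: move inputs onto the device
    b = torch.cat((gpu_t.cpu(), cpu_t))        # pattern 2: move results back to the CPU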
@@ -263,6 +263,7 @@ class Plaintext_Client:
     def set_model(self, model):
         self.model.load_state_dict(model)
+        self.model.to(self.device)
 
     def send_output_dim(self, y):
         if self.task == 'regression':
@@ -350,8 +351,8 @@ class Plaintext_Client:
             pred = self.model(x)
 
             if self.task == 'classification':
-                y_true = torch.cat((y_true, y))
-                y_score = torch.cat((y_score, pred))
+                y_true = torch.cat((y_true, y.cpu()))
+                y_score = torch.cat((y_score, pred.cpu()))
                 loss += self.loss_fn(pred, y).item() * len(x)
 
                 if self.output_dim == 1:
@@ -359,8 +360,8 @@ class Plaintext_Client:
                 else:
                     acc += (pred.argmax(1) == y).type(torch.float).sum().item()
             elif self.task == 'regression':
-                mae += F.l1_loss(pred, y) * len(x)
-                mse += F.mse_loss(pred, y) * len(x)
+                mae += F.l1_loss(pred, y).cpu() * len(x)
+                mse += F.mse_loss(pred, y).cpu() * len(x)
 
         client_metrics = {}
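The two hunks above keep the forward pass on the training device but detach predictions and per-batch losses to the CPU before they are concatenated or accumulated. A hedged sketch of that pattern, with illustrative names (eval_batch is not a function in this repository):

    import torch
    import torch.nn.functional as F

    def eval_batch(model, x, y, device, y_true, y_score):
        x, y = x.to(device), y.to(device)        # inputs follow the model's device
        pred = model(x)
        # concatenate on the CPU so the accumulators never mix devices
        y_true = torch.cat((y_true, y.cpu()))
        y_score = torch.cat((y_score, pred.cpu()))
        # per-batch regression loss moved to the CPU before accumulation
        mae = F.l1_loss(pred, y).cpu() * len(x)
        return y_true, y_score, mae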
@@ -390,11 +391,11 @@ class Plaintext_Client:
         elif self.task == 'regression':
             mse /= size
             client_metrics['train_mse'] = mse
-            self.server_channel.send("mse", mse.type(torch.float64))
+            self.server_channel.send("mse", mse)
 
             mae /= size
             client_metrics['train_mae'] = mae
-            self.server_channel.send("mae", mae.type(torch.float64))
+            self.server_channel.send("mae", mae)
 
             logger.info(f"mse={mse}, mae={mae}")
@@ -420,7 +421,7 @@ class DPSGD_Client(Plaintext_Client):
         input_shape = list(self.input_shape)
         # set batch size equal to 1 to initialize lazy module
         input_shape.insert(0, 1)
-        self.model.forward(torch.ones(input_shape))
+        self.model.forward(torch.ones(input_shape).to(self.device))
         super().lazy_module_init()
 
     def enable_DP_training(self, train_dataloader):
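For context on the lazy_module_init hunks here and on the server side below: PyTorch lazy modules (e.g. torch.nn.LazyLinear) only materialize their parameters on the first forward pass, so the code runs one dummy batch through the model. That dummy tensor must live on the model's device, hence the added .to(self.device). A standalone sketch (the two-layer model is illustrative, not the repository's architecture):

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # in_features stays unknown until the first forward pass
    model = nn.Sequential(nn.LazyLinear(16), nn.ReLU(), nn.LazyLinear(2)).to(device)

    # a batch size of 1 is enough to materialize the lazy parameters;
    # a CPU dummy input here would fail against a CUDA model
    model(torch.ones([1, 8]).to(device))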
@@ -146,6 +146,7 @@ class CNNClient(BaseModel):
     model.eval()
     with torch.no_grad():
         for x, y in data_loader:
+            x = x.to(device)
             pred = model(x)
             pred_prob = torch.softmax(pred, dim=1)
             pred_y = pred_prob.argmax(1)
@@ -152,7 +152,7 @@ class Plaintext_Server:
         input_shape = list(self.input_shape)
         # set batch size equal to 1 to initialize lazy module
         input_shape.insert(0, 1)
-        self.model.forward(torch.ones(input_shape))
+        self.model.forward(torch.ones(input_shape).to(self.device))
         self.model.load_state_dict(self.model.state_dict())
 
         self.server_model_broadcast()
@@ -170,7 +170,7 @@ class Plaintext_Server:
                 np.array(self.num_positive_examples_weights)).tolist()
 
         self.num_examples_weights = torch.tensor(self.num_examples_weights,
-                                                 dtype=torch.float32)
+                                                 dtype=torch.float32).to(self.device)
         self.num_examples_weights_sum = self.num_examples_weights.sum()
 
     def client_model_aggregate(self):
@@ -207,9 +207,11 @@ class Plaintext_Server:
             raise RuntimeError(error_msg)
 
         client_metrics = self.client_channel.recv_all(metrics_name)
 
-        return np.average(client_metrics,
-                          weights=self.num_examples_weights)
+        metrics = torch.tensor(client_metrics, dtype=torch.float).to(self.device) \
+            @ self.num_examples_weights \
+            / self.num_examples_weights_sum
+        return float(metrics)
 
     def get_fpr_tpr(self):
         client_fpr = self.client_channel.recv_all('fpr')
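This last hunk replaces np.average, which would try (and fail) to convert the now-GPU weight vector into a NumPy array, with the same weighted mean computed in torch on the device: dot(metrics, weights) / sum(weights). A small equivalence check with illustrative values:

    import numpy as np
    import torch

    client_metrics = [0.9, 0.8, 0.7]
    weights = torch.tensor([10.0, 30.0, 60.0])

    np_avg = np.average(client_metrics, weights=weights.numpy())
    torch_avg = torch.tensor(client_metrics, dtype=torch.float) @ weights / weights.sum()

    assert abs(float(torch_avg) - np_avg) < 1e-6  # identical weighted mean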