提交 845d630d 编写于 作者: J Javier Rodriguez Zaurin

#89 improved the solution for device consistency

上级 2339c3a0
......@@ -136,7 +136,6 @@ class WideDeep(nn.Module):
enforce_positive_activation: str = "softplus",
pred_dim: int = 1,
with_fds: bool = False,
device: Optional[str] = None,
**fds_config,
):
super(WideDeep, self).__init__()
......@@ -152,11 +151,9 @@ class WideDeep(nn.Module):
with_fds,
)
self.wd_device = (
device
if device is not None
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
# this attribute will be eventually over-written by the Trainer's
# device. Acts here as a 'placeholder'.
self.wd_device = "cpu"
# required as attribute just in case we pass a deephead
self.pred_dim = pred_dim
......
......@@ -238,6 +238,7 @@ class Trainer:
self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3)
self.reducing_matrix = create_explain_matrix(self.model)
self.model.to(self.device)
self.model.wd_device = self.device
self.objective = objective
self.method = _ObjectiveToMethod.get(objective)
......@@ -1474,7 +1475,7 @@ class Trainer:
if sys.platform == "darwin" and sys.version_info.minor > 7
else os.cpu_count()
)
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
default_device = "cuda" if torch.cuda.is_available() else "cpu"
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册