From 889b7c28fb4f20bada8539b140dd2a11847ef33a Mon Sep 17 00:00:00 2001 From: jrzaurin Date: Fri, 10 Jul 2020 18:30:21 +0100 Subject: [PATCH] Adapted the examples to the small code changes. Replace add_ with add in WideDeep to avoid annoying warnings. Replace output_dim with pred_dim in Wide for consistentcy --- pytorch_widedeep/models/wide.py | 12 ++++++------ pytorch_widedeep/models/wide_deep.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_widedeep/models/wide.py b/pytorch_widedeep/models/wide.py index 91b8793..10cc790 100644 --- a/pytorch_widedeep/models/wide.py +++ b/pytorch_widedeep/models/wide.py @@ -5,14 +5,14 @@ from ..wdtypes import * class Wide(nn.Module): r"""Simple linear layer that will receive the one-hot encoded `'wide'` - input and connect it to the output neuron. + input and connect it to the output neuron(s). Parameters ----------- wide_dim: int size of the input tensor - output_dim: int - size of the ouput tensor + pred_dim: int + size of the ouput tensor containing the predictions Attributes ----------- @@ -24,7 +24,7 @@ class Wide(nn.Module): >>> import torch >>> from pytorch_widedeep.models import Wide >>> X = torch.empty(4, 4).random_(2) - >>> wide = Wide(wide_dim=X.size(0), output_dim=1) + >>> wide = Wide(wide_dim=X.size(0), pred_dim=1) >>> wide(X) tensor([[-0.8841], [-0.8633], @@ -32,9 +32,9 @@ class Wide(nn.Module): [-0.4762]], grad_fn=) """ - def __init__(self, wide_dim: int, output_dim: int = 1): + def __init__(self, wide_dim: int, pred_dim: int = 1): super(Wide, self).__init__() - self.wide_linear = nn.Linear(wide_dim, output_dim) + self.wide_linear = nn.Linear(wide_dim, pred_dim) def forward(self, X: Tensor) -> Tensor: # type: ignore r"""Forward pass. Simply connecting the one-hot encoded input with the diff --git a/pytorch_widedeep/models/wide_deep.py b/pytorch_widedeep/models/wide_deep.py index bc2cab6..8956381 100644 --- a/pytorch_widedeep/models/wide_deep.py +++ b/pytorch_widedeep/models/wide_deep.py @@ -201,13 +201,13 @@ class WideDeep(nn.Module): if self.deepimage is not None: deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], axis=1) # type: ignore deepside_out = self.deephead(deepside) - return out.add_(deepside_out) + return out.add(deepside_out) else: - out.add_(self.deepdense(X["deepdense"])) + out.add(self.deepdense(X["deepdense"])) if self.deeptext is not None: - out.add_(self.deeptext(X["deeptext"])) + out.add(self.deeptext(X["deeptext"])) if self.deepimage is not None: - out.add_(self.deepimage(X["deepimage"])) + out.add(self.deepimage(X["deepimage"])) return out def compile( -- GitLab