提交 889b7c28 编写于 作者: J jrzaurin

Adapted the examples to the small code changes. Replace add_ with add in...

Adapted the examples to the small code changes. Replace add_ with add in WideDeep to avoid annoying warnings. Replace output_dim with pred_dim in Wide for consistency
上级 3e963e78
...@@ -5,14 +5,14 @@ from ..wdtypes import * ...@@ -5,14 +5,14 @@ from ..wdtypes import *
class Wide(nn.Module): class Wide(nn.Module):
r"""Simple linear layer that will receive the one-hot encoded `'wide'` r"""Simple linear layer that will receive the one-hot encoded `'wide'`
input and connect it to the output neuron. input and connect it to the output neuron(s).
Parameters Parameters
----------- -----------
wide_dim: int wide_dim: int
size of the input tensor size of the input tensor
output_dim: int pred_dim: int
        size of the output tensor 	        size of the output tensor containing the predictions
Attributes Attributes
----------- -----------
...@@ -24,7 +24,7 @@ class Wide(nn.Module): ...@@ -24,7 +24,7 @@ class Wide(nn.Module):
>>> import torch >>> import torch
>>> from pytorch_widedeep.models import Wide >>> from pytorch_widedeep.models import Wide
>>> X = torch.empty(4, 4).random_(2) >>> X = torch.empty(4, 4).random_(2)
>>> wide = Wide(wide_dim=X.size(0), output_dim=1) >>> wide = Wide(wide_dim=X.size(0), pred_dim=1)
>>> wide(X) >>> wide(X)
tensor([[-0.8841], tensor([[-0.8841],
[-0.8633], [-0.8633],
...@@ -32,9 +32,9 @@ class Wide(nn.Module): ...@@ -32,9 +32,9 @@ class Wide(nn.Module):
[-0.4762]], grad_fn=<AddmmBackward>) [-0.4762]], grad_fn=<AddmmBackward>)
""" """
def __init__(self, wide_dim: int, output_dim: int = 1): def __init__(self, wide_dim: int, pred_dim: int = 1):
super(Wide, self).__init__() super(Wide, self).__init__()
self.wide_linear = nn.Linear(wide_dim, output_dim) self.wide_linear = nn.Linear(wide_dim, pred_dim)
def forward(self, X: Tensor) -> Tensor: # type: ignore def forward(self, X: Tensor) -> Tensor: # type: ignore
r"""Forward pass. Simply connecting the one-hot encoded input with the r"""Forward pass. Simply connecting the one-hot encoded input with the
......
...@@ -201,13 +201,13 @@ class WideDeep(nn.Module): ...@@ -201,13 +201,13 @@ class WideDeep(nn.Module):
if self.deepimage is not None: if self.deepimage is not None:
deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], axis=1) # type: ignore deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], axis=1) # type: ignore
deepside_out = self.deephead(deepside) deepside_out = self.deephead(deepside)
return out.add_(deepside_out) return out.add(deepside_out)
else: else:
out.add_(self.deepdense(X["deepdense"])) out.add(self.deepdense(X["deepdense"]))
if self.deeptext is not None: if self.deeptext is not None:
out.add_(self.deeptext(X["deeptext"])) out.add(self.deeptext(X["deeptext"]))
if self.deepimage is not None: if self.deepimage is not None:
out.add_(self.deepimage(X["deepimage"])) out.add(self.deepimage(X["deepimage"]))
return out return out
def compile( def compile(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册