提交 889b7c28 编写于 作者: J jrzaurin

Adapted the examples to the small code changes. Replace add_ with add in...

Adapted the examples to the small code changes. Replace add_ with add in WideDeep to avoid annoying warnings. Replace output_dim with pred_dim in Wide for consistentcy
上级 3e963e78
......@@ -5,14 +5,14 @@ from ..wdtypes import *
class Wide(nn.Module):
r"""Simple linear layer that will receive the one-hot encoded `'wide'`
input and connect it to the output neuron.
input and connect it to the output neuron(s).
Parameters
-----------
wide_dim: int
size of the input tensor
output_dim: int
size of the ouput tensor
pred_dim: int
size of the ouput tensor containing the predictions
Attributes
-----------
......@@ -24,7 +24,7 @@ class Wide(nn.Module):
>>> import torch
>>> from pytorch_widedeep.models import Wide
>>> X = torch.empty(4, 4).random_(2)
>>> wide = Wide(wide_dim=X.size(0), output_dim=1)
>>> wide = Wide(wide_dim=X.size(0), pred_dim=1)
>>> wide(X)
tensor([[-0.8841],
[-0.8633],
......@@ -32,9 +32,9 @@ class Wide(nn.Module):
[-0.4762]], grad_fn=<AddmmBackward>)
"""
def __init__(self, wide_dim: int, output_dim: int = 1):
def __init__(self, wide_dim: int, pred_dim: int = 1):
super(Wide, self).__init__()
self.wide_linear = nn.Linear(wide_dim, output_dim)
self.wide_linear = nn.Linear(wide_dim, pred_dim)
def forward(self, X: Tensor) -> Tensor: # type: ignore
r"""Forward pass. Simply connecting the one-hot encoded input with the
......
......@@ -201,13 +201,13 @@ class WideDeep(nn.Module):
if self.deepimage is not None:
deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], axis=1) # type: ignore
deepside_out = self.deephead(deepside)
return out.add_(deepside_out)
return out.add(deepside_out)
else:
out.add_(self.deepdense(X["deepdense"]))
out.add(self.deepdense(X["deepdense"]))
if self.deeptext is not None:
out.add_(self.deeptext(X["deeptext"]))
out.add(self.deeptext(X["deeptext"]))
if self.deepimage is not None:
out.add_(self.deepimage(X["deepimage"]))
out.add(self.deepimage(X["deepimage"]))
return out
def compile(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册