Commit e32993bc authored by: jrzaurin

changed the dense layer to be almost identical to that of fastai, which I...

changed the dense layer to be almost identical to that of fastai, which I really like. Changed the code accordingly. Changed the name of DeepDense and DeepDenseResnet to TabMlp and TabResnet. Changed the tests accordingly
Parent 902c5987
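The commit message refers to a dense layer modelled on fastai's `LinBnDrop`. As a rough, torch-free sketch of how the new `dense_layer` in this diff orders its operations, the snippet below mirrors that logic with string labels. The helper name `dense_layer_order` and the labels `LIN`, `ACT`, `BN` and `DP` are illustrative assumptions, not part of the package.

```python
from typing import List


def dense_layer_order(bn: bool, p: float, linear_first: bool) -> List[str]:
    """Return the operations of one dense layer, in the order they are applied.

    Hypothetical helper: it only mirrors the list-building logic of the
    ``dense_layer`` function shown in this diff, using labels instead of modules.
    """
    # Optional pieces: batch norm (if requested), then dropout (only if p != 0).
    extras = ["BN"] if bn else []
    if p != 0:
        extras.append("DP")
    # The linear layer is always followed by the activation.
    core = ["LIN", "ACT"]
    return core + extras if linear_first else extras + core


print(dense_layer_order(bn=True, p=0.2, linear_first=True))    # ['LIN', 'ACT', 'BN', 'DP']
print(dense_layer_order(bn=True, p=0.2, linear_first=False))   # ['BN', 'DP', 'LIN', 'ACT']
print(dense_layer_order(bn=False, p=0.0, linear_first=False))  # ['LIN', 'ACT']
```

With `linear_first=False` (the default for `TabMlp` in this commit) normalization and dropout act on the layer input, which is why the batch norm in that case is sized with `inp` rather than `out`.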
......@@ -39,3 +39,4 @@ Dependencies
* tqdm
* torch
* torchvision
* einops
......@@ -3,15 +3,13 @@ Quick Start
This is an example of a binary classification with the `adult census
<https://www.kaggle.com/wenruliu/adult-income-dataset?select=adult.csv>`__
dataset using a combination of a ``Wide`` and ``DeepDense`` model with
defaults settings.
dataset using a combination of a wide and deep model (in this case a so-called
``DeepDense`` model) with default settings.
Read and split the dataset
--------------------------
The following code snippet is not directly related to ``pytorch-widedeep``.
.. code-block:: python
import pandas as pd
......@@ -30,7 +28,7 @@ Prepare the wide and deep columns
.. code-block:: python
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import Accuracy
......@@ -62,15 +60,15 @@ Preprocessing and model components definition
.. code-block:: python
# wide
# wide (linear) model
preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = preprocess_wide.fit_transform(df_train)
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
# deepdense
preprocess_deep = DensePreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
X_deep = preprocess_deep.fit_transform(df_train)
deepdense = DeepDense(
# deeptabular component is a DeepDense model
preprocess_deep = TabPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
X_tab = preprocess_deep.fit_transform(df_train)
deeptabular = DeepDense(
hidden_layers=[64, 32],
deep_column_idx=preprocess_deep.deep_column_idx,
embed_input=preprocess_deep.embeddings_input,
......@@ -84,11 +82,11 @@ Build, compile, fit and predict
.. code-block:: python
# build, compile and fit
model = WideDeep(wide=wide, deepdense=deepdense)
model = WideDeep(wide=wide, deeptabular=deeptabular)
model.compile(method="binary", metrics=[Accuracy])
model.fit(
X_wide=X_wide,
X_deep=X_deep,
X_tab=X_tab,
target=target,
n_epochs=5,
batch_size=256,
......@@ -97,14 +95,14 @@ Build, compile, fit and predict
# predict
X_wide_te = preprocess_wide.transform(df_test)
X_deep_te = preprocess_deep.transform(df_test)
preds = model.predict(X_wide=X_wide_te, X_deep=X_deep_te)
X_tab_te = preprocess_deep.transform(df_test)
preds = model.predict(X_wide=X_wide_te, X_tab=X_tab_te)
Of course, one can do much more, such as using different initializations,
optimizers or learning rate schedulers for each component of the overall
model. Adding FC-Heads to the Text and Image components. Using the Focal Loss,
warming up individual components before joined training, etc. See the
`examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/build_docs/examples>`__
warming up individual components before joined training, using the
`TabTransformer <https://arxiv.org/pdf/2012.06678.pdf>`__, etc. See the
`examples <https://github.com/jrzaurin/pytorch-widedeep/tree/build_docs/examples>`__
directory for a better understanding of the content of the package and its
functionalities.
The ``utils`` module
====================
Initially the intention was for the ``utils`` module to be hidden from the
user. However, there are a series of utilities there that might be useful for
a number of preprocessing tasks. All the classes and functions discussed here
are available directly from the ``utils`` module. For example, the
``LabelEncoder`` within the ``deeptabular_utils`` submodule can be imported as:
These are a series of utilities that might be useful for a number of
preprocessing tasks. All the classes and functions discussed here are
available directly from the ``utils`` module. For example, the
``LabelEncoder`` within the ``deeptabular_utils`` submodule can be imported
as:
.. code-block:: python
......
......@@ -3,7 +3,7 @@ import torch
import pandas as pd
from pytorch_widedeep.optim import RAdam
from pytorch_widedeep.models import Wide, WideDeep, DeepDense, DeepDenseResnet
from pytorch_widedeep.models import Wide, WideDeep, TabMlp, TabResnet # noqa: F401
from pytorch_widedeep.metrics import Accuracy, Precision
from pytorch_widedeep.callbacks import (
LRHistory,
......@@ -55,16 +55,16 @@ if __name__ == "__main__":
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
deepdense = DeepDense(
hidden_layers=[64, 32],
deepdense = TabMlp(
mlp_hidden_dims=[64, 32],
dropout=[0.2, 0.2],
deep_column_idx=prepare_deep.deep_column_idx,
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
)
# # To use DeepDenseResnet as the deepdense component simply:
# deepdense = DeepDenseResnet(
# # To use TabResnet as the deepdense component simply:
# deepdense = TabResnet(
# blocks=[64, 32],
# deep_column_idx=prepare_deep.deep_column_idx,
# embed_input=prepare_deep.embeddings_input,
......
......@@ -61,7 +61,7 @@ if __name__ == "__main__":
continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deepdense=deeptabular)
model = WideDeep(wide=wide, deeptabular=deeptabular)
wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
deep_opt = RAdam(model.deeptabular.parameters())
......
......@@ -8,9 +8,9 @@ from pytorch_widedeep.models import (
Wide,
DeepText,
WideDeep,
DeepDense,
TabMlp,
DeepImage,
DeepDenseResnet,
TabResnet, # noqa: F401
)
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_widedeep.initializers import KaimingNormal
......@@ -69,15 +69,15 @@ if __name__ == "__main__":
X_images = image_processor.fit_transform(df)
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
deepdense = DeepDense(
hidden_layers=[64, 32],
deepdense = TabMlp(
mlp_hidden_dims=[64, 32],
dropout=[0.2, 0.2],
deep_column_idx=prepare_deep.deep_column_idx,
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
)
# # To use DeepDenseResnet as the deepdense component simply:
# deepdense = DeepDenseResnet(
# # To use TabResnet as the deepdense component simply:
# deepdense = TabResnet(
# blocks=[64, 32],
# dropout=0.2,
# deep_column_idx=prepare_deep.deep_column_idx,
......@@ -92,7 +92,7 @@ if __name__ == "__main__":
padding_idx=1,
embed_matrix=text_processor.embedding_matrix,
)
deepimage = DeepImage(pretrained=True, head_layers=None)
deepimage = DeepImage(pretrained=True, head_layers_dim=None)
model = WideDeep(
wide=wide, deeptabular=deepdense, deeptext=deeptext, deepimage=deepimage
)
......
......@@ -2,7 +2,7 @@ import numpy as np
import torch
import pandas as pd
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.models import Wide, WideDeep, TabMlp
from pytorch_widedeep.metrics import F1Score, Accuracy
from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor
......@@ -41,8 +41,8 @@ if __name__ == "__main__":
X_deep = prepare_deep.fit_transform(df)
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=3)
deepdense = DeepDense(
hidden_layers=[64, 32],
deepdense = TabMlp(
mlp_hidden_dims=[64, 32],
dropout=[0.2, 0.2],
deep_column_idx=prepare_deep.deep_column_idx,
embed_input=prepare_deep.embeddings_input,
......
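The example scripts above are updated by renaming classes and keyword arguments. A minimal sketch of the keyword renames visible in this diff is given below. The mapping is collected from the changed lines here, and the helper `migrate_kwargs` is a hypothetical convenience for illustration; the package itself does not ship it.

```python
from typing import Any, Dict

# Old keyword -> new keyword, as they appear in this diff (not an exhaustive list).
RENAMED_KWARGS: Dict[str, str] = {
    "hidden_layers": "mlp_hidden_dims",  # DeepDense -> TabMlp
    "batchnorm": "mlp_batchnorm",        # DeepDense -> TabMlp
    "head_layers": "head_layers_dim",    # DeepText / DeepImage
    "blocks": "blocks_dim",              # DeepDenseResnet -> TabResnet
    "deepdense": "deeptabular",          # WideDeep
    "X_deep": "X_tab",                   # fit / predict
}


def migrate_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``kwargs`` with old keyword names replaced by new ones."""
    return {RENAMED_KWARGS.get(key, key): value for key, value in kwargs.items()}


old_call = {"hidden_layers": [64, 32], "dropout": [0.2, 0.2]}
print(migrate_kwargs(old_call))  # {'mlp_hidden_dims': [64, 32], 'dropout': [0.2, 0.2]}
```

Keywords that are not in the mapping, such as `dropout`, pass through unchanged.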
from .wide import Wide
from .tab_mlp import TabMlp
from .deep_text import DeepText
from .wide_deep import WideDeep
from .deep_dense import DeepDense
from .deep_image import DeepImage
from .tab_resnet import TabResnet
from .tab_transformer import TabTransformer
from .deep_dense_resnet import DeepDenseResnet
from torch import nn
from torchvision import models
from .tab_mlp import MLP
from ..wdtypes import * # noqa: F403
from .deep_dense import dense_layer
def conv_layer(
......@@ -57,7 +57,7 @@ class DeepImage(nn.Module):
freeze_n: int, default = 6
number of layers to freeze. Must be less than or equal to 8. If 8
the entire 'backbone' of the network will be frozen
head_layers: List, Optional
head_layers_dim: List, Optional
List with the sizes of the stacked dense layers in the head
e.g: [128, 64]
head_dropout: List, Optional
......@@ -81,7 +81,7 @@ class DeepImage(nn.Module):
>>> import torch
>>> from pytorch_widedeep.models import DeepImage
>>> X_img = torch.rand((2,3,224,224))
>>> model = DeepImage(head_layers=[512, 64, 8])
>>> model = DeepImage(head_layers_dim=[512, 64, 8])
>>> out = model(X_img)
"""
......@@ -90,9 +90,12 @@ class DeepImage(nn.Module):
pretrained: bool = True,
resnet_architecture: int = 18,
freeze_n: int = 6,
head_layers: Optional[List[int]] = None,
head_dropout: Optional[List[float]] = None,
head_layers_dim: Optional[List[int]] = None,
head_activation: Optional[str] = "relu",
head_dropout: Optional[Union[float, List[float]]] = None,
head_batchnorm: Optional[bool] = False,
head_batchnorm_last: Optional[bool] = False,
head_linear_first: Optional[bool] = False,
):
super(DeepImage, self).__init__()
......@@ -100,9 +103,12 @@ class DeepImage(nn.Module):
self.pretrained = pretrained
self.resnet_architecture = resnet_architecture
self.freeze_n = freeze_n
self.head_layers = head_layers
self.head_layers_dim = head_layers_dim
self.head_activation = head_activation
self.head_dropout = head_dropout
self.head_batchnorm = head_batchnorm
self.head_batchnorm_last = head_batchnorm_last
self.head_linear_first = head_linear_first
if pretrained:
vision_model = self.select_resnet_architecture(resnet_architecture)
......@@ -114,33 +120,28 @@ class DeepImage(nn.Module):
# the output_dim attribute will be used as input_dim when "merging" the models
self.output_dim = 512
if self.head_layers is not None:
assert self.head_layers[0] == self.output_dim, (
if self.head_layers_dim is not None:
assert self.head_layers_dim[0] == self.output_dim, (
"The output dimension from the backbone ({}) is not consistent with "
"the expected input dimension ({}) of the fc-head".format(
self.output_dim, self.head_layers[0]
self.output_dim, self.head_layers_dim[0]
)
)
if not head_dropout:
head_dropout = [0.0] * len(head_layers)
self.imagehead = nn.Sequential()
for i in range(1, len(head_layers)):
self.imagehead.add_module(
"dense_layer_{}".format(i - 1),
dense_layer(
head_layers[i - 1],
head_layers[i],
head_dropout[i - 1],
head_batchnorm,
),
)
self.output_dim = head_layers[-1]
self.imagehead = MLP(
head_layers_dim,
head_activation,
head_dropout,
head_batchnorm,
head_batchnorm_last,
head_linear_first,
)
self.output_dim = head_layers_dim[-1]
def forward(self, x: Tensor) -> Tensor: # type: ignore
r"""Forward pass connecting the `'backbone'` with the `'head layers'`"""
x = self.backbone(x)
x = x.view(x.size(0), -1)
if self.head_layers is not None:
if self.head_layers_dim is not None:
out = self.imagehead(x)
return out
else:
......
......@@ -4,8 +4,8 @@ import numpy as np
import torch
from torch import nn
from .tab_mlp import MLP
from ..wdtypes import * # noqa: F403
from .deep_dense import dense_layer
class DeepText(nn.Module):
......@@ -37,7 +37,7 @@ class DeepText(nn.Module):
Pretrained word embeddings
embed_trainable: bool, Optional. default = False
Boolean indicating if the pretrained embeddings are trainable
head_layers: List, Optional
head_layers_dim: List, Optional
List with the sizes of the stacked dense layers in the head
e.g: [128, 64]
head_dropout: List, Optional
......@@ -54,7 +54,7 @@ class DeepText(nn.Module):
Stack of LSTMs
texthead: :obj:`nn.Sequential`
Stack of dense layers on top of the RNN. This will only exist if
`head_layers` is not `None`
`head_layers_dim` is not `None`
output_dim: :obj:`int`
The output dimension of the model. This is a required attribute
necessary to build the WideDeep class
......@@ -79,9 +79,12 @@ class DeepText(nn.Module):
embed_dim: Optional[int] = None,
embed_matrix: Optional[np.ndarray] = None,
embed_trainable: Optional[bool] = True,
head_layers: Optional[List[int]] = None,
head_dropout: Optional[List[float]] = None,
head_layers_dim: Optional[List[int]] = None,
head_activation: Optional[str] = "relu",
head_dropout: Optional[Union[float, List[float]]] = None,
head_batchnorm: Optional[bool] = False,
head_batchnorm_last: Optional[bool] = False,
head_linear_first: Optional[bool] = False,
):
super(DeepText, self).__init__()
......@@ -108,9 +111,12 @@ class DeepText(nn.Module):
self.padding_idx = padding_idx
self.embed_dim = embed_dim
self.embed_trainable = embed_trainable
self.head_layers = head_layers
self.head_layers_dim = head_layers_dim
self.head_activation = head_activation
self.head_dropout = head_dropout
self.head_batchnorm = head_batchnorm
self.head_batchnorm_last = head_batchnorm_last
self.head_linear_first = head_linear_first
# Pre-trained Embeddings
if isinstance(embed_matrix, np.ndarray):
......@@ -149,27 +155,22 @@ class DeepText(nn.Module):
# the output_dim attribute will be used as input_dim when "merging" the models
self.output_dim = hidden_dim * 2 if bidirectional else hidden_dim
if self.head_layers is not None:
assert self.head_layers[0] == self.output_dim, (
if self.head_layers_dim is not None:
assert self.head_layers_dim[0] == self.output_dim, (
"The hidden dimension from the stack or RNNs ({}) is not consistent with "
"the expected input dimension ({}) of the fc-head".format(
self.output_dim, self.head_layers[0]
self.output_dim, self.head_layers_dim[0]
)
)
if not head_dropout:
head_dropout = [0.0] * len(head_layers)
self.texthead = nn.Sequential()
for i in range(1, len(head_layers)):
self.texthead.add_module(
"dense_layer_{}".format(i - 1),
dense_layer(
head_layers[i - 1],
head_layers[i],
head_dropout[i - 1],
head_batchnorm,
),
)
self.output_dim = head_layers[-1]
self.texthead = MLP(
head_layers_dim,
head_activation,
head_dropout,
head_batchnorm,
head_batchnorm_last,
head_linear_first,
)
self.output_dim = head_layers_dim[-1]
def forward(self, X: Tensor) -> Tensor: # type: ignore
r"""Forward pass that is simply a standard RNN-based
......@@ -181,7 +182,7 @@ class DeepText(nn.Module):
last_h = torch.cat((h[-2], h[-1]), dim=1)
else:
last_h = h[-1]
if self.head_layers is not None:
if self.head_layers_dim is not None:
out = self.texthead(last_h)
return out
else:
......
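Both `DeepImage` and `DeepText` require the first entry of `head_layers_dim` to match the output dimension of the component underneath (512 for the image backbone, `hidden_dim * 2` or `hidden_dim` for the RNN). The sketch below restates that rule in plain Python. The function name `resolve_output_dim` is an assumption for illustration, and it raises `ValueError` where the diff uses `assert`.

```python
from typing import List, Optional


def resolve_output_dim(base_output_dim: int, head_layers_dim: Optional[List[int]]) -> int:
    """Return the component's final output dimension, validating the head sizes.

    Hypothetical helper mirroring the checks in DeepImage/DeepText in this diff.
    """
    if head_layers_dim is None:
        # No fc-head: the component outputs whatever the backbone/RNN outputs.
        return base_output_dim
    if head_layers_dim[0] != base_output_dim:
        raise ValueError(
            "The output dimension ({}) is not consistent with the expected "
            "input dimension ({}) of the fc-head".format(
                base_output_dim, head_layers_dim[0]
            )
        )
    # With an fc-head, the last head size becomes the output dimension.
    return head_layers_dim[-1]


print(resolve_output_dim(512, [512, 64, 8]))  # 8
print(resolve_output_dim(512, None))          # 512
```

For a bidirectional `DeepText` with `hidden_dim=64`, the first head size would therefore have to be 128.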
......@@ -4,17 +4,73 @@ from torch import nn
from ..wdtypes import * # noqa: F403
def dense_layer(inp: int, out: int, p: float = 0.0, bn=False):
layers = [nn.Linear(inp, out), nn.LeakyReLU(inplace=True)]
if bn:
layers.append(nn.BatchNorm1d(out))
layers.append(nn.Dropout(p))
allowed_activations = ["relu", "leaky_relu", "gelu"]
def _get_activation_fn(activation):
if activation == "relu":
return nn.ReLU(inplace=True)
if activation == "leaky_relu":
return nn.LeakyReLU(inplace=True)
elif activation == "gelu":
return nn.GELU()  # nn.GELU does not accept an ``inplace`` argument
def dense_layer(
inp: int,
out: int,
activation: str,
p: float,
bn: bool,
linear_first: bool,
):
# This is basically the LinBnDrop class in the fastai library
act_fn = _get_activation_fn(activation)
layers = [nn.BatchNorm1d(out if linear_first else inp)] if bn else []
if p != 0:
layers.append(nn.Dropout(p)) # type: ignore[arg-type]
lin = [nn.Linear(inp, out, bias=not bn), act_fn]
layers = lin + layers if linear_first else layers + lin
return nn.Sequential(*layers)
class DeepDense(nn.Module):
r"""Defines a so-called ``DeepDense`` model that can be used as the
class MLP(nn.Module):
def __init__(
self,
d_hidden: List[int],
activation: str,
dropout: Optional[Union[float, List[float]]],
batchnorm: bool,
batchnorm_last: bool,
linear_first: bool,
):
super(MLP, self).__init__()
if not dropout:
dropout = [0.0] * len(d_hidden)
elif isinstance(dropout, float):
dropout = [dropout] * len(d_hidden)
self.mlp = nn.Sequential()
for i in range(1, len(d_hidden)):
self.mlp.add_module(
"dense_layer_{}".format(i - 1),
dense_layer(
d_hidden[i - 1],
d_hidden[i],
activation,
dropout[i - 1],
batchnorm and (i != len(d_hidden) - 1 or batchnorm_last),
linear_first,
),
)
def forward(self, X: Tensor) -> Tensor:
return self.mlp(X)
class TabMlp(nn.Module):
r"""Defines a so-called ``TabMlp`` model that can be used as the
``deeptabular`` component of a Wide & Deep model.
This class combines embedding representations of the categorical features
......@@ -25,15 +81,25 @@ class DeepDense(nn.Module):
----------
deep_column_idx: Dict
Dict containing the index of the columns that will be passed through
the DeepDense model. Required to slice the tensors. e.g. {'education':
the TabMlp model. Required to slice the tensors. e.g. {'education':
0, 'relationship': 1, 'workclass': 2, ...}
hidden_layers: List
mlp_hidden_dims: List
List with the number of neurons per dense layer. e.g: [64,32]
batchnorm: bool
Boolean indicating whether or not to include batch normalization in the
activation: str, default = "relu"
Activation function for the dense layers of the MLP
dropout: float or List[float], Optional, default = 0.0
float or List of floats with the dropout between the dense layers.
e.g: [0.5,0.5]
mlp_batchnorm: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
dense layers
dropout: List, Optional
List with the dropout between the dense layers. e.g: [0.5,0.5]
mlp_batchnorm_last: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
last of the dense layers
linear_first: bool, default = False
Boolean indicating the order of the operations in the dense layer.
If 'True': [LIN -> ACT -> BN -> DP]. If 'False': [BN -> DP ->
LIN -> ACT]
embed_input: List, Optional
List of Tuples with the column name, number of unique values and
embedding dimension. e.g. [(education, 11, 32), ...]
......@@ -41,15 +107,18 @@ class DeepDense(nn.Module):
embeddings dropout
continuous_cols: List, Optional
List with the name of the numeric (aka continuous) columns
batchnorm_cont: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
continuous input
.. note:: Either ``embed_input`` or ``continuous_cols`` (or both) should be passed to the
model
Attributes
----------
dense: :obj:`nn.Sequential`
deep dense model that will receive the concatenation of the
embeddings and the continuous columns
tab_mlp: :obj:`nn.Sequential`
mlp model that will receive the concatenation of the embeddings and
the continuous columns
embed_layers: :obj:`nn.ModuleDict`
:obj:`ModuleDict` with the embedding
output_dim: :obj:`int`
......@@ -59,35 +128,49 @@ class DeepDense(nn.Module):
Example
--------
>>> import torch
>>> from pytorch_widedeep.models import DeepDense
>>> from pytorch_widedeep.models import TabMlp
>>> X_deep = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
>>> colnames = ['a', 'b', 'c', 'd', 'e']
>>> embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]
>>> deep_column_idx = {k:v for v,k in enumerate(colnames)}
>>> model = DeepDense(hidden_layers=[8,4], deep_column_idx=deep_column_idx, embed_input=embed_input)
>>> model = TabMlp(mlp_hidden_dims=[8,4], deep_column_idx=deep_column_idx, embed_input=embed_input)
>>> out = model(X_deep)
"""
def __init__(
self,
deep_column_idx: Dict[str, int],
hidden_layers: List[int],
batchnorm: bool = False,
dropout: Optional[List[float]] = None,
mlp_hidden_dims: List[int],
activation: str = "relu",
dropout: Optional[Union[float, List[float]]] = 0.0,
mlp_batchnorm: bool = False,
mlp_batchnorm_last: bool = False,
linear_first: bool = False,
embed_input: Optional[List[Tuple[str, int, int]]] = None,
embed_dropout: float = 0.0,
continuous_cols: Optional[List[str]] = None,
batchnorm_cont: Optional[bool] = False,
):
super(DeepDense, self).__init__()
super(TabMlp, self).__init__()
self.deep_column_idx = deep_column_idx
self.hidden_layers = hidden_layers
self.batchnorm = batchnorm
self.mlp_hidden_dims = mlp_hidden_dims
self.activation = activation
self.dropout = dropout
self.mlp_batchnorm = mlp_batchnorm
self.linear_first = linear_first
self.embed_input = embed_input
self.embed_dropout = embed_dropout
self.continuous_cols = continuous_cols
self.batchnorm_cont = batchnorm_cont
if self.activation not in allowed_activations:
raise ValueError(
"the activation function for the dense layers must be one of {}. Got {} instead".format(
", ".join(allowed_activations), self.activation
)
)
# Embeddings: val + 1 because 0 is reserved for padding/unseen categories.
if self.embed_input is not None:
......@@ -105,25 +188,25 @@ class DeepDense(nn.Module):
# Continuous
if self.continuous_cols is not None:
cont_inp_dim = len(self.continuous_cols)
if self.batchnorm_cont:
self.norm = nn.BatchNorm1d(cont_inp_dim)
else:
cont_inp_dim = 0
# Dense Layers
# MLP
input_dim = emb_inp_dim + cont_inp_dim
hidden_layers = [input_dim] + hidden_layers
if not dropout:
dropout = [0.0] * len(hidden_layers)
self.dense = nn.Sequential()
for i in range(1, len(hidden_layers)):
self.dense.add_module(
"dense_layer_{}".format(i - 1),
dense_layer(
hidden_layers[i - 1], hidden_layers[i], dropout[i - 1], batchnorm
),
)
mlp_hidden_dims = [input_dim] + mlp_hidden_dims
self.tab_mlp = MLP(
mlp_hidden_dims,
activation,
dropout,
mlp_batchnorm,
mlp_batchnorm_last,
linear_first,
)
# the output_dim attribute will be used as input_dim when "merging" the models
self.output_dim = hidden_layers[-1]
self.output_dim = mlp_hidden_dims[-1]
def forward(self, X: Tensor) -> Tensor: # type: ignore
r"""Forward pass that concatenates the continuous features with the
......@@ -141,5 +224,7 @@ class DeepDense(nn.Module):
if self.continuous_cols is not None:
cont_idx = [self.deep_column_idx[col] for col in self.continuous_cols]
x_cont = X[:, cont_idx].float()
if self.batchnorm_cont:
x_cont = self.norm(x_cont)
x = torch.cat([x, x_cont], 1) if self.embed_input is not None else x_cont
return self.dense(x)
return self.tab_mlp(x)
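The new `MLP` class accepts `dropout` as `None`, a single float or a list, and applies batch normalization to the last layer only when `batchnorm_last` is set. A torch-free sketch of how those arguments expand into per-layer settings follows. The function name `mlp_layer_specs` and the tuple layout are assumptions made for illustration.

```python
from typing import List, Optional, Tuple, Union


def mlp_layer_specs(
    d_hidden: List[int],
    dropout: Optional[Union[float, List[float]]],
    batchnorm: bool,
    batchnorm_last: bool,
) -> List[Tuple[int, int, float, bool]]:
    """Return (in_dim, out_dim, dropout, use_batchnorm) for each dense layer.

    Hypothetical helper that mirrors the loop in the ``MLP`` class of this diff.
    """
    # Normalize ``dropout`` to one value per entry of ``d_hidden``.
    if not dropout:
        dropout = [0.0] * len(d_hidden)
    elif isinstance(dropout, float):
        dropout = [dropout] * len(d_hidden)
    specs = []
    for i in range(1, len(d_hidden)):
        is_last = i == len(d_hidden) - 1
        # The last layer only gets batch norm when ``batchnorm_last`` is True.
        use_bn = bool(batchnorm and (not is_last or batchnorm_last))
        specs.append((d_hidden[i - 1], d_hidden[i], dropout[i - 1], use_bn))
    return specs


print(mlp_layer_specs([16, 8, 4], 0.1, batchnorm=True, batchnorm_last=False))
# [(16, 8, 0.1, True), (8, 4, 0.1, False)]
```

Note that `d_hidden` includes the input dimension as its first entry, so a list of three sizes yields two dense layers.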
......@@ -5,6 +5,7 @@ import torch
from torch import nn
from torch.nn import Module
from .tab_mlp import MLP
from ..wdtypes import * # noqa: F403
......@@ -46,8 +47,43 @@ class BasicBlock(nn.Module):
return out
class DeepDenseResnet(nn.Module):
r"""Defines a so-called ``DeepDenseResnet`` model that can be used as the
class DenseResnet(nn.Module):
def __init__(self, input_dim: int, blocks_dim: List[int], dropout: float):
super(DenseResnet, self).__init__()
self.input_dim = input_dim
self.blocks_dim = blocks_dim
self.dropout = dropout
if input_dim != blocks_dim[0]:
self.dense_resnet = nn.Sequential(
OrderedDict(
[
("lin1", nn.Linear(input_dim, blocks_dim[0])),
("bn1", nn.BatchNorm1d(blocks_dim[0])),
]
)
)
else:
self.dense_resnet = nn.Sequential()
for i in range(1, len(blocks_dim)):
resize = None
if blocks_dim[i - 1] != blocks_dim[i]:
resize = nn.Sequential(
nn.Linear(blocks_dim[i - 1], blocks_dim[i]),
nn.BatchNorm1d(blocks_dim[i]),
)
self.dense_resnet.add_module(
"block_{}".format(i - 1),
BasicBlock(blocks_dim[i - 1], blocks_dim[i], dropout, resize),
)
def forward(self, X: Tensor) -> Tensor:
return self.dense_resnet(X)
class TabResnet(nn.Module):
r"""Defines a so-called ``TabResnet`` model that can be used as the
``deeptabular`` component of a Wide & Deep model.
This class combines embedding representations of the categorical
......@@ -58,30 +94,54 @@ class DeepDenseResnet(nn.Module):
Parameters
----------
embed_input: List
List of Tuples with the column name, number of unique values and
embedding dimension. e.g. [(education, 11, 32), ...].
deep_column_idx: Dict
Dict containing the index of the columns that will be passed through
the DeepDense model. Required to slice the tensors. e.g. {'education':
the TabResnet model. Required to slice the tensors. e.g. {'education':
0, 'relationship': 1, 'workclass': 2, ...}
blocks: List
blocks_dim: List
List of integers that define the input and output units of each block.
For example: ``[128, 64, 32]`` will generate 2 blocks. The first will
For example: ``[128, 64, 32]`` will generate 2 blocks. The first will
receive a tensor of size 128 and output a tensor of size 64, and the
second will receive a tensor of size 64 and output a tensor of size
32. See ``pytorch_widedeep.models.tab_resnet.BasicBlock`` for
details on the structure of each block.
activation: str, default = "relu"
Activation function for the dense layers of the MLP
dropout: float, default = 0.0
Block's `"internal"` dropout. This dropout is applied to the first of
the two dense layers that comprise each ``BasicBlock``
mlp_batchnorm: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
dense layers
mlp_batchnorm_last: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
last of the dense layers
linear_first: bool, default = False
Boolean indicating the order of the operations in the dense layer.
If 'True': [LIN -> ACT -> BN -> DP]. If 'False': [BN -> DP ->
LIN -> ACT]
embed_dropout: float, Optional, default = 0.0
embeddings dropout
continuous_cols: List, Optional
List with the name of the numeric (aka continuous) columns
batchnorm_cont: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
continuous input
concat_cont_first: bool, Optional, default = True
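Boolean indicating whether the continuous columns are concatenated with
the embeddings before being passed through the dense Resnet blocks
(``True``), or concatenated with the output of the blocks (``False``)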
dropout: float, default = 0.0
Block's `"internal"` dropout. This dropout is applied to the first of
the two dense layers that comprise each ``BasicBlock``
embed_input: List, Optional
List of Tuples with the column name, number of unique values and
embedding dimension. e.g. [(education, 11, 32), ...]
embed_dropout: float
embeddings dropout
continuous_cols: List, Optional
List with the name of the numeric (aka continuous) columns
.. note:: Either ``embed_input`` or ``continuous_cols`` (or both) should be passed to the
model
.. note:: Unlike ``TabMlp``, ``TabResnet`` assumes that there are categorical
columns
Attributes
----------
......@@ -93,104 +153,135 @@ class DeepDenseResnet(nn.Module):
output_dim: :obj:`int`
The output dimension of the model. This is a required attribute
necessary to build the WideDeep class
tab_resnet_mlp: :obj:`nn.Sequential`
if ``mlp_hidden_dims`` is not ``None``, this attribute will be an mlp
model that will receive i) the result of concatenating the embeddings
and the continuous columns (if present) and passing them through
the ``dense_resnet``, or ii) the result of passing the embeddings
through the ``dense_resnet`` and then concatenating the output with
the continuous columns (if present)
Example
--------
>>> import torch
>>> from pytorch_widedeep.models import DeepDenseResnet
>>> from pytorch_widedeep.models import TabResnet
>>> X_deep = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
>>> colnames = ['a', 'b', 'c', 'd', 'e']
>>> embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]
>>> deep_column_idx = {k:v for v,k in enumerate(colnames)}
>>> model = DeepDenseResnet(blocks=[16,4], deep_column_idx=deep_column_idx, embed_input=embed_input)
>>> model = TabResnet(blocks_dim=[16,4], deep_column_idx=deep_column_idx, embed_input=embed_input)
>>> out = model(X_deep)
"""
def __init__(
self,
embed_input: List[Tuple[str, int, int]],
deep_column_idx: Dict[str, int],
blocks: List[int],
blocks_dim: List[int],
activation: str = "relu",
dropout: float = 0.0,
embed_dropout: float = 0.0,
embed_input: Optional[List[Tuple[str, int, int]]] = None,
mlp_hidden_dims: Optional[List[int]] = None,
mlp_batchnorm: Optional[bool] = False,
mlp_batchnorm_last: Optional[bool] = False,
linear_first: Optional[bool] = False,
embed_dropout: Optional[float] = 0.0,
continuous_cols: Optional[List[str]] = None,
batchnorm_cont: Optional[bool] = False,
concat_cont_first: Optional[bool] = True,
):
super(DeepDenseResnet, self).__init__()
if len(blocks) < 2:
raise ValueError(
"'blocks' must contain at least two elements, e.g. [256, 128]"
)
super(TabResnet, self).__init__()
self.embed_input = embed_input
self.deep_column_idx = deep_column_idx
self.blocks = blocks
self.blocks_dim = blocks_dim
self.activation = activation
self.dropout = dropout
self.mlp_hidden_dims = mlp_hidden_dims
self.mlp_batchnorm = mlp_batchnorm
self.mlp_batchnorm_last = mlp_batchnorm_last
self.linear_first = linear_first
self.embed_dropout = embed_dropout
self.embed_input = embed_input
self.continuous_cols = continuous_cols
self.batchnorm_cont = batchnorm_cont
self.concat_cont_first = concat_cont_first
# Embeddings: val + 1 because 0 is reserved for padding/unseen categories.
if self.embed_input is not None:
self.embed_layers = nn.ModuleDict(
{
"emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
for col, val, dim in self.embed_input
}
if len(self.blocks_dim) < 2:
raise ValueError(
"'blocks' must contain at least two elements, e.g. [256, 128]"
)
self.embedding_dropout = nn.Dropout(embed_dropout)
emb_inp_dim = np.sum([embed[2] for embed in self.embed_input])
else:
emb_inp_dim = 0
# Embeddings: val + 1 because 0 is reserved for padding/unseen categories.
self.embed_layers = nn.ModuleDict(
{
"emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
for col, val, dim in self.embed_input
}
)
self.embedding_dropout = nn.Dropout(embed_dropout)
emb_inp_dim = np.sum([embed[2] for embed in self.embed_input])
# Continuous
if self.continuous_cols is not None:
cont_inp_dim = len(self.continuous_cols)
if self.batchnorm_cont:
self.norm = nn.BatchNorm1d(cont_inp_dim)
else:
cont_inp_dim = 0
# Dense Resnet
input_dim = emb_inp_dim + cont_inp_dim
if input_dim != blocks[0]:
self.dense_resnet = nn.Sequential(
OrderedDict(
[
("lin1", nn.Linear(input_dim, blocks[0])),
("bn1", nn.BatchNorm1d(blocks[0])),
]
)
)
# DenseResnet
if self.concat_cont_first:
dense_resnet_input_dim = emb_inp_dim + cont_inp_dim
self.output_dim = blocks_dim[-1]
else:
self.dense_resnet = nn.Sequential()
for i in range(1, len(blocks)):
resize = None
if blocks[i - 1] != blocks[i]:
resize = nn.Sequential(
nn.Linear(blocks[i - 1], blocks[i]), nn.BatchNorm1d(blocks[i])
)
self.dense_resnet.add_module(
"block_{}".format(i - 1),
BasicBlock(blocks[i - 1], blocks[i], dropout, resize),
)
dense_resnet_input_dim = emb_inp_dim
self.output_dim = cont_inp_dim + blocks_dim[-1]
self.tab_resnet = DenseResnet(dense_resnet_input_dim, blocks_dim, dropout)
# the output_dim attribute will be used as input_dim when "merging" the models
self.output_dim = blocks[-1]
# MLP
if self.mlp_hidden_dims is not None:
if self.concat_cont_first:
mlp_input_dim = blocks_dim[-1]
else:
mlp_input_dim = cont_inp_dim + blocks_dim[-1]
mlp_hidden_dims = [mlp_input_dim] + mlp_hidden_dims
self.tab_resnet_mlp = MLP(
mlp_hidden_dims,
activation,
dropout,
mlp_batchnorm,
mlp_batchnorm_last,
linear_first,
)
self.output_dim = mlp_hidden_dims[-1]
def forward(self, X: Tensor) -> Tensor: # type: ignore
r"""Forward pass that concatenates the continuous features with the
embeddings. The result is then passed through a series of dense Resnet
blocks"""
if self.embed_input is not None:
embed = [
self.embed_layers["emb_layer_" + col](
X[:, self.deep_column_idx[col]].long()
)
for col, _, _ in self.embed_input
]
x = torch.cat(embed, 1)
x = self.embedding_dropout(x)
embed = [
self.embed_layers["emb_layer_" + col](
X[:, self.deep_column_idx[col]].long()
)
for col, _, _ in self.embed_input
]
x = torch.cat(embed, 1)
x = self.embedding_dropout(x)
if self.continuous_cols is not None:
cont_idx = [self.deep_column_idx[col] for col in self.continuous_cols]
x_cont = X[:, cont_idx].float()
x = torch.cat([x, x_cont], 1) if self.embed_input is not None else x_cont
return self.dense_resnet(x)
if self.batchnorm_cont:
x_cont = self.norm(x_cont)
if self.concat_cont_first:
x = torch.cat([x, x_cont], 1)
out = self.tab_resnet(x)
else:
out = torch.cat([self.tab_resnet(x), x_cont], 1)
else:
out = self.tab_resnet(x)
if self.mlp_hidden_dims is not None:
out = self.tab_resnet_mlp(out)
return out
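The `output_dim` of the new `TabResnet` depends on `concat_cont_first` and on whether an MLP is stacked on top. The following sketch restates that bookkeeping without torch. The function name `tab_resnet_dims` and its argument names `emb_inp_dim` and `cont_inp_dim` are chosen for illustration, following the local variable names in the diff.

```python
from typing import List, Optional, Tuple


def tab_resnet_dims(
    emb_inp_dim: int,
    cont_inp_dim: int,
    blocks_dim: List[int],
    concat_cont_first: bool = True,
    mlp_hidden_dims: Optional[List[int]] = None,
) -> Tuple[int, int]:
    """Return (dense_resnet_input_dim, output_dim) as computed in TabResnet.__init__.

    Hypothetical helper: it only reproduces the dimension arithmetic of this diff.
    """
    if concat_cont_first:
        # Embeddings and continuous columns go through the Resnet blocks together.
        resnet_input_dim = emb_inp_dim + cont_inp_dim
        output_dim = blocks_dim[-1]
    else:
        # Only the embeddings go through the blocks; the continuous columns
        # are concatenated with the blocks' output afterwards.
        resnet_input_dim = emb_inp_dim
        output_dim = cont_inp_dim + blocks_dim[-1]
    if mlp_hidden_dims is not None:
        # With an MLP on top, its last layer defines the output dimension.
        output_dim = mlp_hidden_dims[-1]
    return resnet_input_dim, output_dim


print(tab_resnet_dims(32, 2, [16, 4]))                           # (34, 4)
print(tab_resnet_dims(32, 2, [16, 4], concat_cont_first=False))  # (32, 6)
print(tab_resnet_dims(32, 2, [16, 4], mlp_hidden_dims=[8, 2]))   # (34, 2)
```

Keeping the continuous columns out of the Resnet blocks (`concat_cont_first=False`) widens the tensor handed to the next stage by `cont_inp_dim`, which is why the MLP input dimension differs between the two modes.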
......@@ -20,6 +20,7 @@ import torch
import einops
from torch import nn, einsum
from .tab_mlp import MLP
from ..wdtypes import * # noqa: F403
......@@ -160,36 +161,6 @@ class TransformerEncoder(nn.Module):
return self.ff_addnorm(Y, self.feed_forward(Y))
class MLP(nn.Module):
def __init__(
self,
d_hidden: List[int],
dropout: Optional[Union[float, List[float]]],
activation: str = "relu",
):
super(MLP, self).__init__()
if not dropout:
dropout = [0.0] * len(d_hidden)
elif isinstance(dropout, float):
dropout = [dropout] * len(d_hidden)
self.mlp = nn.Sequential()
for i in range(1, len(d_hidden)):
self.mlp.add_module(
"dense_layer_{}".format(i - 1),
_dense_layer(
d_hidden[i - 1],
d_hidden[i],
activation,
dropout[i - 1],
),
)
def forward(self, X: Tensor) -> Tensor:
return self.mlp(X)
class SharedEmbeddings(nn.Module):
def __init__(
self,
......@@ -240,8 +211,8 @@ class TabTransformer(nn.Module):
embed_dropout: float, default = 0.
embeddings dropout.
shared_embed: bool, default = False
The idea behind `shared_embed` is described in the Appendix A in the paper: `
'The goal of having column embedding is to enable the model to distinguish the
The idea behind `shared_embed` is described in the Appendix A in the paper:
`'The goal of having column embedding is to enable the model to distinguish the
classes in one column from those in the other columns'`. In other words, the idea
is to let the model learn which column it is embedding.
add_shared_embed: bool, default = False,
......@@ -325,8 +296,11 @@ class TabTransformer(nn.Module):
num_cat_columns: Optional[int] = None,
ff_hidden_dim: int = 32 * 4,
transformer_activation: str = "gelu",
mlp_activation: str = "relu",
mlp_hidden_dims: Optional[List[int]] = None,
mlp_activation: Optional[str] = "relu",
mlp_batchnorm: Optional[bool] = False,
mlp_batchnorm_last: Optional[bool] = False,
mlp_linear_first: Optional[bool] = True,
):
super(TabTransformer, self).__init__()
......@@ -347,8 +321,11 @@ class TabTransformer(nn.Module):
self.num_cat_columns = num_cat_columns
self.ff_hidden_dim = ff_hidden_dim
self.transformer_activation = transformer_activation
self.mlp_activation = mlp_activation
self.mlp_hidden_dims = mlp_hidden_dims
self.mlp_activation = mlp_activation
self.mlp_batchnorm = mlp_batchnorm
self.mlp_batchnorm_last = mlp_batchnorm_last
self.mlp_linear_first = mlp_linear_first
# Embeddings: val + 1 because 0 is reserved for padding/unseen categories.
if shared_embed:
......@@ -403,7 +380,14 @@ class TabTransformer(nn.Module):
mlp_inp_l = len(embed_input) * input_dim + cont_inp_dim
mlp_hidden_dims = [mlp_inp_l, mlp_inp_l * 4, mlp_inp_l * 2]
self.mlp = MLP(mlp_hidden_dims, dropout, mlp_activation)
self.tab_transformer_mlp = MLP(
mlp_hidden_dims,
mlp_activation,
dropout,
mlp_batchnorm,
mlp_batchnorm_last,
mlp_linear_first,
)
# the output_dim attribute will be used as input_dim when "merging" the models
self.output_dim = mlp_hidden_dims[-1]
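As a worked example of the sizing in the hunk above: the helper name `transformer_mlp_dims` is hypothetical, and the assumption is that `input_dim` is the per-column embedding size and `cont_inp_dim` the number of continuous columns.

```python
from typing import List


def transformer_mlp_dims(
    n_embed_cols: int, input_dim: int, cont_inp_dim: int = 0
) -> List[int]:
    # flattened transformer output plus the continuous columns
    mlp_inp_l = n_embed_cols * input_dim + cont_inp_dim
    # same pattern as in the diff: [l, 4 * l, 2 * l]
    return [mlp_inp_l, mlp_inp_l * 4, mlp_inp_l * 2]


example_dims = transformer_mlp_dims(n_embed_cols=5, input_dim=32, cont_inp_dim=5)
example_output_dim = example_dims[-1]  # what `output_dim` would be set to
```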
......@@ -431,4 +415,4 @@ class TabTransformer(nn.Module):
x_cont = X[:, cont_idx].float()
x = torch.cat([x, x_cont], 1)
return self.mlp(x)
return self.tab_transformer_mlp(x)
......@@ -12,10 +12,10 @@ from sklearn.model_selection import train_test_split
from ..losses import FocalLoss
from ._warmup import WarmUp
from .tab_mlp import MLP
from ..metrics import Metric, MetricCallback, MultipleMetrics
from ..wdtypes import * # noqa: F403
from ..callbacks import History, Callback, CallbackContainer
from .deep_dense import dense_layer
from ._wd_dataset import WideDeepDataset
from ..initializers import Initializer, MultipleInitializer
from ._multiple_optimizer import MultipleOptimizer
......@@ -71,25 +71,25 @@ class WideDeep(nn.Module):
Parameters
----------
wide: nn.Module
wide: nn.Module, Optional
Wide model. We recommend using the ``Wide`` class in this package.
However, it is possible to use a custom model as long as is consistent
with the required architecture, see
:class:`pytorch_widedeep.models.wide.Wide`
deeptabular: nn.Module
deeptabular: nn.Module, Optional
currently we offer three possible architectures for the `deeptabular` component
implemented in this package. These are: ``DeepDense``, ``DeepDenseResnet`` and `
implemented in this package. These are: ``TabMlp``, ``TabResnet`` and
``TabTransformer``.
1. ``DeepDense`` is simply an embedding layer encoding the categorical
1. ``TabMlp`` is simply an embedding layer encoding the categorical
features that are then concatenated and passed through a series of
dense layers.
See: ``pytorch_widedeep.models.deep_dense.DeepDense``
dense (hidden) layers (i.e. an MLP).
See: ``pytorch_widedeep.models.tab_mlp.TabMlp``
2. ``DeepDenseResnet`` is an embedding layer encoding the categorical
2. ``TabResnet`` is an embedding layer encoding the categorical
features that are then concatenated and passed through a series of
`"dense"` ResNet blocks.
See ``pytorch_widedeep.models.deep_dense_resnet.DeepDenseResnet``
See ``pytorch_widedeep.models.deep_dense_resnet.TabResnet``
3. ``TabTransformer`` is detailed in `TabTransformer: Tabular Data Modeling
Using Contextual Embeddings <https://arxiv.org/pdf/2012.06678.pdf>`_.
......@@ -108,14 +108,25 @@ class WideDeep(nn.Module):
required architecture. See
:class:`pytorch_widedeep.models.deep_dense.DeepImage`
deephead: nn.Module, Optional
`Dense` model consisting in a stack of dense layers. The FC-Head.
head_layers: List, Optional
Alternatively, we can use ``head_layers`` to specify the sizes of the
Custom model defined by the user that will receive the output of the deep
component. Typically an FC-Head
head_layers_dim: List, Optional
Alternatively, we can use ``head_layers_dim`` to specify the sizes of the
stacked dense layers in the fc-head e.g: ``[128, 64]``
head_dropout: List, Optional
Dropout between the layers in ``head_layers``. e.g: ``[0.5, 0.5]``
head_dropout: float or List, Optional
Dropout between the layers in ``head_layers_dim``. e.g: ``[0.5, 0.5]``
head_activation: str, default = "relu"
Activation function of the head layers. One of "relu", "gelu" or
"leaky_relu"
head_batchnorm: bool, Optional
Specifies if batch normalizatin should be included in the dense layers
Specifies if batch normalization should be included in the head layers
head_batchnorm_last: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
last of the dense layers
head_linear_first: bool, default = False
Boolean indicating the order of the operations in the dense
layer. If 'True': [LIN -> ACT -> BN -> DP]. If 'False': [BN -> DP ->
LIN -> ACT]
pred_dim: int
Size of the final wide and deep output layer containing the
predictions. `1` for regression and binary classification or `number
......@@ -146,7 +157,7 @@ class WideDeep(nn.Module):
that these activations are collected by ``WideDeep`` and combined
accordingly. In addition, the models MUST also contain an attribute
``output_dim`` with the size of these last layers of activations. See
for example :class:`pytorch_widedeep.models.deep_dense.DeepDense`
for example :class:`pytorch_widedeep.models.tab_mlp.TabMlp`
"""
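The two orderings described for `head_linear_first` in the `WideDeep` docstring can be sketched as below. This is a minimal sketch in the spirit of fastai's `LinBnDrop`, not the package's actual dense layer; the function name `sketch_dense_layer` and its argument names are hypothetical.

```python
import torch
from torch import nn


def sketch_dense_layer(
    inp: int,
    out: int,
    activation: str = "relu",
    p: float = 0.0,
    bn: bool = False,
    linear_first: bool = False,
) -> nn.Sequential:
    acts = {"relu": nn.ReLU(), "leaky_relu": nn.LeakyReLU(), "gelu": nn.GELU()}
    if activation not in acts:
        raise ValueError("activation must be one of {}".format(list(acts)))
    # batchnorm sees `out` features when it follows the linear layer,
    # and `inp` features when it precedes it
    layers = [nn.BatchNorm1d(out if linear_first else inp)] if bn else []
    if p != 0:
        layers.append(nn.Dropout(p))
    lin = [nn.Linear(inp, out, bias=not bn), acts[activation]]
    # True:  [LIN -> ACT -> BN -> DP]; False: [BN -> DP -> LIN -> ACT]
    layers = lin + layers if linear_first else layers + lin
    return nn.Sequential(*layers)


lin_first = sketch_dense_layer(8, 4, p=0.1, bn=True, linear_first=True)
bn_first = sketch_dense_layer(8, 4, p=0.1, bn=True, linear_first=False)
order_lin_first = [type(m).__name__ for m in lin_first]
order_bn_first = [type(m).__name__ for m in bn_first]
out_shape = tuple(bn_first(torch.rand(10, 8)).shape)
```

Raising `ValueError` on an unknown activation matches what `test_act_fn_ValueError` in this commit expects from `TabMlp`.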
......@@ -158,9 +169,12 @@ class WideDeep(nn.Module):
deeptext: Optional[nn.Module] = None,
deepimage: Optional[nn.Module] = None,
deephead: Optional[nn.Module] = None,
head_layers: Optional[List[int]] = None,
head_dropout: Optional[List] = None,
head_batchnorm: Optional[bool] = None,
head_layers_dim: Optional[List[int]] = None,
head_activation: Optional[str] = "relu",
head_dropout: Optional[Union[float, List[float]]] = None,
head_batchnorm: Optional[bool] = False,
head_batchnorm_last: Optional[bool] = False,
head_linear_first: Optional[bool] = False,
pred_dim: int = 1,
):
......@@ -172,8 +186,7 @@ class WideDeep(nn.Module):
deeptext,
deepimage,
deephead,
head_layers,
head_dropout,
head_layers_dim,
pred_dim,
)
......@@ -188,30 +201,25 @@ class WideDeep(nn.Module):
self.deephead = deephead
if self.deephead is None:
if head_layers is not None:
input_dim = 0
if head_layers_dim is not None:
deep_dim = 0
if self.deeptabular is not None:
input_dim += self.deeptabular.output_dim # type:ignore
deep_dim += self.deeptabular.output_dim # type:ignore
if self.deeptext is not None:
input_dim += self.deeptext.output_dim # type:ignore
deep_dim += self.deeptext.output_dim # type:ignore
if self.deepimage is not None:
input_dim += self.deepimage.output_dim # type:ignore
head_layers = [input_dim] + head_layers
if not head_dropout:
head_dropout = [0.0] * (len(head_layers) - 1)
self.deephead = nn.Sequential()
for i in range(1, len(head_layers)):
self.deephead.add_module(
"head_layer_{}".format(i - 1),
dense_layer(
head_layers[i - 1],
head_layers[i],
head_dropout[i - 1],
head_batchnorm,
),
)
deep_dim += self.deepimage.output_dim # type:ignore
head_layers_dim = [deep_dim] + head_layers_dim
self.deephead = MLP(
head_layers_dim,
head_activation,
head_dropout,
head_batchnorm,
head_batchnorm_last,
head_linear_first,
)
self.deephead.add_module(
"head_out", nn.Linear(head_layers[-1], pred_dim)
"head_out", nn.Linear(head_layers_dim[-1], pred_dim)
)
else:
if self.deeptabular is not None:
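The way the head input size is derived in the hunk above can be sketched as follows, assuming each deep component exposes an integer `output_dim`. The helper name `head_dims` is hypothetical and only reproduces the summation shown in the diff.

```python
from typing import List, Optional


def head_dims(
    head_layers_dim: List[int],
    tab_dim: Optional[int] = None,
    text_dim: Optional[int] = None,
    image_dim: Optional[int] = None,
) -> List[int]:
    # sum the output_dim of every deep component that is present
    deep_dim = sum(d for d in (tab_dim, text_dim, image_dim) if d is not None)
    # prepend it so the first head layer maps deep_dim -> head_layers_dim[0]
    return [deep_dim] + head_layers_dim


only_tab = head_dims([8, 4], tab_dim=8)
all_three = head_dims([128, 64], tab_dim=16, text_dim=64, image_dim=512)
```

The final `head_out` linear layer then maps `head_layers_dim[-1]` to `pred_dim`.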
......@@ -359,12 +367,12 @@ class WideDeep(nn.Module):
>>>
>>> from pytorch_widedeep.callbacks import EarlyStopping, LRHistory
>>> from pytorch_widedeep.initializers import KaimingNormal, KaimingUniform, Normal, Uniform
>>> from pytorch_widedeep.models import DeepDenseResnet, DeepImage, DeepText, Wide, WideDeep
>>> from pytorch_widedeep.models import TabResnet, DeepImage, DeepText, Wide, WideDeep
>>> from pytorch_widedeep.optim import RAdam
>>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
>>> deep_column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
>>> wide = Wide(10, 1)
>>> deeptabular = DeepDenseResnet(blocks=[8, 4], deep_column_idx=deep_column_idx, embed_input=embed_input)
>>> deeptabular = TabResnet(blocks_dim=[8, 4], deep_column_idx=deep_column_idx, embed_input=embed_input)
>>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
>>> deepimage = DeepImage(pretrained=False)
>>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)
......@@ -1112,8 +1120,7 @@ class WideDeep(nn.Module):
deeptext,
deepimage,
deephead,
head_layers,
head_dropout,
head_layers_dim,
pred_dim,
):
......@@ -1139,23 +1146,19 @@ class WideDeep(nn.Module):
"deepimage model must have an 'output_dim' attribute. "
"See pytorch-widedeep.models.deep_dense.DeepText"
)
if deephead is not None and head_layers is not None:
if deephead is not None and head_layers_dim is not None:
raise ValueError(
"both 'deephead' and 'head_layers' are not None. Use one of the other, but not both"
"both 'deephead' and 'head_layers_dim' are not None. Use one or the other, but not both"
)
if (
head_layers is not None
head_layers_dim is not None
and not deeptabular
and not deeptext
and not deepimage
):
raise ValueError(
"if 'head_layers' is not None, at least one deep component must be used"
"if 'head_layers_dim' is not None, at least one deep component must be used"
)
if head_layers is not None and head_dropout is not None:
assert len(head_layers) == len(
head_dropout
), "'head_layers' and 'head_dropout' must have the same length"
if deephead is not None:
deephead_inp_feat = next(deephead.parameters()).size(1)
output_dim = 0
......
......@@ -54,7 +54,7 @@ def test_deep_image_resnet_34():
###############################################################################
# Testing the head
###############################################################################
model6 = DeepImage(head_layers=[512, 256, 128], head_dropout=[0.0, 0.0])
model6 = DeepImage(head_layers_dim=[512, 256, 128], head_dropout=[0.0, 0.5])
def test_deep_image_2():
......
......@@ -57,7 +57,7 @@ def test_catch_warning():
###############################################################################
model4 = DeepText(
vocab_size=vocab_size, embed_dim=32, padding_idx=0, head_layers=[64, 16]
vocab_size=vocab_size, embed_dim=32, padding_idx=0, head_layers_dim=[64, 16]
)
......
......@@ -2,8 +2,9 @@ import string
import numpy as np
import torch
import pytest
from pytorch_widedeep.models import DeepDense
from pytorch_widedeep.models import TabMlp
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 10) for _ in range(5)]
......@@ -18,8 +19,8 @@ X_deep_cont = X_deep[:, 5:]
# Embeddings and NO continuous_cols
###############################################################################
embed_input = [(u, i, j) for u, i, j in zip(colnames[:5], [5] * 5, [16] * 5)]
model1 = DeepDense(
hidden_layers=[32, 16],
model1 = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.2],
deep_column_idx={k: v for v, k in enumerate(colnames[:5])},
embed_input=embed_input,
......@@ -35,8 +36,8 @@ def test_deep_dense_embed():
# Continuous cols but NO embeddings
###############################################################################
continuous_cols = colnames[-5:]
model2 = DeepDense(
hidden_layers=[32, 16],
model2 = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.2],
deep_column_idx={k: v for v, k in enumerate(colnames[5:])},
continuous_cols=continuous_cols,
......@@ -49,17 +50,39 @@ def test_deep_dense_cont():
###############################################################################
# Continous Cols and Embeddings
# All parameters
###############################################################################
model3 = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.2],
model3 = TabMlp(
deep_column_idx={k: v for v, k in enumerate(colnames)},
mlp_hidden_dims=[32, 16, 8],
dropout=0.1,
mlp_batchnorm=True,
mlp_batchnorm_last=False,
linear_first=False,
embed_input=embed_input,
embed_dropout=0.1,
continuous_cols=continuous_cols,
batchnorm_cont=True,
)
def test_deep_dense():
out = model3(X_deep)
assert out.size(0) == 10 and out.size(1) == 16
assert out.size(0) == 10 and out.size(1) == 8
###############################################################################
# Test raise ValueError
###############################################################################
def test_act_fn_ValueError():
with pytest.raises(ValueError):
model4 = TabMlp( # noqa: F841
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.2],
activation="javier",
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=continuous_cols,
)
......@@ -2,64 +2,121 @@ import string
import numpy as np
import torch
import pytest
from pytorch_widedeep.models import DeepDenseResnet
from pytorch_widedeep.models import TabResnet
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 10) for _ in range(5)]
cont_cols = [np.random.rand(10) for _ in range(5)]
X_deep = torch.from_numpy(np.vstack(embed_cols + cont_cols).transpose())
X_deep_emb = X_deep[:, :5]
X_deep_cont = X_deep[:, 5:]
X_tab = torch.from_numpy(np.vstack(embed_cols + cont_cols).transpose())
X_tab_emb = X_tab[:, :5]
X_tab_cont = X_tab[:, 5:]
###############################################################################
# Embeddings and NO continuous_cols
# Embeddings and no continuous_cols
###############################################################################
embed_input = [(u, i, j) for u, i, j in zip(colnames[:5], [5] * 5, [16] * 5)]
model1 = DeepDenseResnet(
blocks=[32, 16],
model1 = TabResnet(
blocks_dim=[32, 16],
dropout=0.5,
deep_column_idx={k: v for v, k in enumerate(colnames[:5])},
embed_input=embed_input,
)
def test_deep_dense_resnet_embed():
out = model1(X_deep_emb)
def test_tab_resnet_embed():
out = model1(X_tab_emb)
assert out.size(0) == 10 and out.size(1) == 16
###############################################################################
# Continous cols but NO embeddings
# Continuous Cols and Embeddings
###############################################################################
continuous_cols = colnames[-5:]
model2 = DeepDenseResnet(
blocks=[32, 16, 16],
model2 = TabResnet(
blocks_dim=[32, 16, 8],
dropout=0.5,
deep_column_idx={k: v for v, k in enumerate(colnames[5:])},
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=continuous_cols,
)
def test_deep_dense_resnet_cont():
out = model2(X_deep_cont)
assert out.size(0) == 10 and out.size(1) == 16
def test_tab_resnet_dense():
out = model2(X_tab)
assert out.size(0) == 10 and out.size(1) == 8
###############################################################################
# Continous Cols and Embeddings
# Continuous Cols concatenated with Embeddings or with the output of the
# dense_resnet
###############################################################################
model3 = DeepDenseResnet(
blocks=[32, 16, 8],
dropout=0.5,
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=continuous_cols,
continuous_cols = colnames[-5:]
@pytest.mark.parametrize(
"concat_cont_first",
[
True,
False,
],
)
def test_cont_contat(concat_cont_first):
model3 = TabResnet(
blocks_dim=[32, 16, 8],
dropout=0.5,
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=continuous_cols,
concat_cont_first=concat_cont_first,
)
out = model3(X_tab)
assert out.size(0) == 10 and out.size(1) == model3.output_dim
def test_deep_dense_resnet_dense():
out = model3(X_deep)
assert out.size(0) == 10 and out.size(1) == 8
###############################################################################
# Test full set up
###############################################################################
@pytest.mark.parametrize(
"concat_cont_first",
[
True,
False,
],
)
def test_full_setup(concat_cont_first):
model4 = TabResnet(
embed_input=embed_input,
deep_column_idx={k: v for v, k in enumerate(colnames)},
blocks_dim=[32, 16, 8],
dropout=0.1,
mlp_hidden_dims=[32, 16],
mlp_batchnorm=True,
mlp_batchnorm_last=False,
embed_dropout=0.1,
continuous_cols=continuous_cols,
batchnorm_cont=True,
concat_cont_first=concat_cont_first,
)
out = model4(X_tab)
true_mlp_inp_dim = list(model4.tab_resnet_mlp.mlp.dense_layer_0.parameters())[
2
].size(1)
if concat_cont_first:
expected_mlp_inp_dim = model4.blocks_dim[-1]
else:
expected_mlp_inp_dim = model4.blocks_dim[-1] + len(continuous_cols)
assert (
out.size(0) == 10
and out.size(1) == model4.output_dim
and expected_mlp_inp_dim == true_mlp_inp_dim
)
......@@ -3,19 +3,13 @@ from copy import deepcopy
import pytest
from torch import nn
from pytorch_widedeep.models import (
Wide,
DeepText,
WideDeep,
DeepDense,
DeepImage,
)
from pytorch_widedeep.models import Wide, TabMlp, DeepText, WideDeep, DeepImage
embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
deep_column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
wide = Wide(10, 1)
deepdense = DeepDense(
hidden_layers=[16, 8], deep_column_idx=deep_column_idx, embed_input=embed_input
deepdense = TabMlp(
mlp_hidden_dims=[16, 8], deep_column_idx=deep_column_idx, embed_input=embed_input
)
deeptext = DeepText(vocab_size=100, embed_dim=8)
deepimage = DeepImage(pretrained=False)
......@@ -49,26 +43,28 @@ def test_history_callback(deepcomponent, component_name):
###############################################################################
#  test warning on head_layers and deephead
#  test warning on head_layers_dim and deephead
###############################################################################
def test_deephead_and_head_layers():
def test_deephead_and_head_layers_dim():
deephead = nn.Sequential(nn.Linear(32, 16), nn.Linear(16, 8))
with pytest.raises(ValueError):
model = WideDeep( # noqa: F841
wide=wide, deeptabular=deepdense, head_layers=[16, 8], deephead=deephead
wide=wide, deeptabular=deepdense, head_layers_dim=[16, 8], deephead=deephead
)
###############################################################################
#  test deephead is None and head_layers is not None
#  test deephead is None and head_layers_dim is not None
###############################################################################
def test_no_deephead_and_head_layers():
def test_no_deephead_and_head_layers_dim():
out = []
model = WideDeep(wide=wide, deeptabular=deepdense, head_layers=[8, 4]) # noqa: F841
model = WideDeep(
wide=wide, deeptabular=deepdense, head_layers_dim=[8, 4]
) # noqa: F841
for n, p in model.named_parameters():
if n == "deephead.head_layer_0.0.weight":
out.append(p.size(0) == 8 and p.size(1) == 8)
......
......@@ -8,7 +8,7 @@ import pytest
from torch.optim.lr_scheduler import StepLR, CyclicLR
from pytorch_widedeep.optim import RAdam
from pytorch_widedeep.models import Wide, WideDeep, DeepDense, TabTransformer
from pytorch_widedeep.models import Wide, TabMlp, WideDeep, TabTransformer
from pytorch_widedeep.callbacks import (
LRHistory,
EarlyStopping,
......@@ -39,8 +39,8 @@ target = np.random.choice(2, 32)
# Test that history saves the information adequately
###############################################################################
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
......@@ -129,8 +129,8 @@ def test_history_callback(
###############################################################################
def test_early_stop():
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
......@@ -159,8 +159,8 @@ def test_early_stop():
)
def test_model_checkpoint(save_best_only, max_save, n_files):
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
......@@ -185,8 +185,8 @@ def test_model_checkpoint(save_best_only, max_save, n_files):
def test_filepath_error():
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = DeepDense(
hidden_layers=[16, 4],
deeptabular = TabMlp(
mlp_hidden_dims=[16, 4],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
continuous_cols=colnames[-5:],
......
......@@ -6,13 +6,7 @@ from torch import nn
from torchvision.transforms import ToTensor, Normalize
from sklearn.model_selection import train_test_split
from pytorch_widedeep.models import (
Wide,
DeepText,
WideDeep,
DeepDense,
DeepImage,
)
from pytorch_widedeep.models import Wide, TabMlp, DeepText, WideDeep, DeepImage
np.random.seed(1)
......@@ -54,8 +48,8 @@ target = np.random.choice(2, 32)
# build model components
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
......@@ -402,7 +396,7 @@ def test_head_layers_individual_components(
deeptabular=deeptabular,
deeptext=deeptext,
deepimage=deepimage,
head_layers=[8, 4],
head_layers_dim=[8, 4],
) # noqa: F841
model.compile(method="binary", verbose=0)
model.fit(
......
......@@ -4,7 +4,7 @@ import numpy as np
import pytest
from torch import nn
from pytorch_widedeep.models import Wide, WideDeep, DeepDense, TabTransformer
from pytorch_widedeep.models import Wide, TabMlp, WideDeep, TabTransformer
# Wide array
X_wide = np.random.choice(50, (32, 10))
......@@ -54,8 +54,8 @@ def test_fit_methods(
probs_dim,
):
wide = Wide(np.unique(X_wide).shape[0], pred_dim)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
......@@ -77,8 +77,8 @@ def test_fit_methods(
##############################################################################
def test_fit_with_deephead():
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
continuous_cols=colnames[-5:],
......
......@@ -3,7 +3,7 @@ import string
import numpy as np
import pytest
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
# Wide array
X_wide = np.random.choice(50, (100, 10))
......@@ -33,8 +33,8 @@ target_multic = np.random.choice(3, 100)
)
def test_focal_loss(X_wide, X_tab, target, method, pred_dim, probs_dim):
wide = Wide(np.unique(X_wide).shape[0], pred_dim)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
......
......@@ -5,13 +5,7 @@ import numpy as np
import torch
import pytest
from pytorch_widedeep.models import (
Wide,
DeepText,
WideDeep,
DeepDense,
DeepImage,
)
from pytorch_widedeep.models import Wide, TabMlp, DeepText, WideDeep, DeepImage
from pytorch_widedeep.initializers import (
Normal,
Uniform,
......@@ -79,8 +73,8 @@ test_layers = [
def test_initializers_1(initializers, test_layers):
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
......@@ -130,8 +124,8 @@ initializers_2 = {
def test_initializers_with_pattern():
wide = Wide(100, 1)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
......
......@@ -7,12 +7,12 @@ from sklearn.model_selection import train_test_split
from pytorch_widedeep.models import (
Wide,
TabMlp,
DeepText,
WideDeep,
DeepDense,
DeepImage,
TabResnet,
TabTransformer,
DeepDenseResnet,
)
from pytorch_widedeep.metrics import Accuracy, Precision
from pytorch_widedeep.callbacks import EarlyStopping
......@@ -57,15 +57,15 @@ target_multi = np.random.choice(3, 32)
# build model components
wide = Wide(np.unique(X_wide).shape[0], 1)
deepdense = DeepDense(
hidden_layers=[32, 16],
deepdense = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[-5:],
)
deepdenseresnet = DeepDenseResnet(
blocks=[32, 16],
deepdenseresnet = TabResnet(
blocks_dim=[32, 16],
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[-5:],
......@@ -158,8 +158,8 @@ def test_basic_run_with_metrics_binary(wide, deeptabular):
def test_basic_run_with_metrics_multiclass():
wide = Wide(np.unique(X_wide).shape[0], 3)
deeptabular = DeepDense(
hidden_layers=[32, 16],
deeptabular = TabMlp(
mlp_hidden_dims=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
......
......@@ -8,7 +8,7 @@ from torch import nn
from sklearn.utils import Bunch
from torch.utils.data import Dataset, DataLoader
from pytorch_widedeep.models import Wide, DeepDense
from pytorch_widedeep.models import Wide, TabMlp
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.models._warmup import WarmUp
from pytorch_widedeep.models.deep_image import conv_layer
......@@ -112,8 +112,8 @@ if use_cuda:
wide.cuda()
# deep
deeptabular = DeepDense(
hidden_layers=[16, 8],
deeptabular = TabMlp(
mlp_hidden_dims=[16, 8],
dropout=[0.5, 0.2],
deep_column_idx=deep_column_idx,
embed_input=embed_input,
......